Commit 04c7c1f5 authored by Benjamin's avatar Benjamin
Browse files

add test_ens_aai

parent f3d290e3
...@@ -109,8 +109,8 @@ class AnimalAIEnvironment(UnityEnvironment): ...@@ -109,8 +109,8 @@ class AnimalAIEnvironment(UnityEnvironment):
def reset(self, arenas_configurations: ArenaConfig = None) -> None: def reset(self, arenas_configurations: ArenaConfig = None) -> None:
if arenas_configurations: if arenas_configurations:
arenas_configurations_proto = arenas_configurations.to_proto() arenas_configurations_proto = arenas_configurations.to_proto()
arenas_configurations_proto_string = ( arenas_configurations_proto_string = arenas_configurations_proto.SerializeToString(
arenas_configurations_proto.SerializeToString() deterministic=True
) )
self.arenas_parameters_side_channel.send_raw_data( self.arenas_parameters_side_channel.send_raw_data(
bytearray(arenas_configurations_proto_string) bytearray(arenas_configurations_proto_string)
......
...@@ -112,7 +112,7 @@ def test_rgb(): ...@@ -112,7 +112,7 @@ def test_rgb():
def test_item(): def test_item():
item: Item = yaml.load(item_yaml,Loader=yaml.Loader) item: Item = yaml.load(item_yaml, Loader=yaml.Loader)
assert item.name == "Wall" assert item.name == "Wall"
assert len(item.positions) == 2 assert len(item.positions) == 2
...@@ -129,7 +129,7 @@ def test_item(): ...@@ -129,7 +129,7 @@ def test_item():
def test_arena(): def test_arena():
arena: Arena = yaml.load(arena_yaml,Loader=yaml.Loader) arena: Arena = yaml.load(arena_yaml, Loader=yaml.Loader)
assert arena.pass_mark == 0 assert arena.pass_mark == 0
assert arena.t == 250 assert arena.t == 250
assert len(arena.blackouts) == 9 assert len(arena.blackouts) == 9
......
from unittest.mock import patch, mock_open
import pytest
from animalai.envs.environment import AnimalAIEnvironment
from animalai.envs.arena_config import ArenaConfig
from mlagents_envs.mock_communicator import MockCommunicator
from mlagents_envs.base_env import BatchedStepResult
arena_config_yaml = """
!ArenaConfig
arenas:
0: !Arena
pass_mark: 2
t: 250
items:
- !Item
name: GoodGoalMulti
1: !Arena
pass_mark: -1
t: 250
items:
- !Item
name: BadGoal
"""
@patch("animalai.envs.environment.AnimalAIEnvironment.reset")
@patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
def test_basic_initialization(mock_communicator, mock_launcher, mock_reset):
mock_communicator.return_value = MockCommunicator(
discrete_action=True, visual_inputs=1, num_agents=32, vec_obs_size=2
)
env = AnimalAIEnvironment(
file_name=" ", n_arenas=32, camera_height=126, camera_width=512
)
assert env.get_agent_groups() == ["RealFakeBrain"]
mock_launcher.assert_called_once()
launcher_args, _ = mock_launcher.call_args
executable_args = launcher_args[3]
assert executable_args == [
"--playerMode",
"0",
"--numberOfArenas",
"32",
"--cameraWidth",
"512",
"--cameraHeight",
"126",
]
env.close()
@patch("animalai.envs.environment.AnimalAIEnvironment.reset")
@patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
def test_play_initialization(mock_communicator, mock_launcher, mock_reset):
mock_communicator.return_value = MockCommunicator()
env = AnimalAIEnvironment(file_name=" ", n_arenas=1, play=True)
mock_launcher.assert_called_once()
launcher_args, _ = mock_launcher.call_args
executable_args = launcher_args[3]
assert executable_args == ["--playerMode", "1", "--numberOfArenas", "1"]
env.close()
@patch("builtins.open", new_callable=mock_open, read_data=arena_config_yaml)
@patch("mlagents_envs.side_channel.raw_bytes_channel.RawBytesChannel.send_raw_data")
@patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
def test_reset_arena_config(
mock_communicator, mock_launcher, mock_byte_channel, mock_yaml
):
mock_communicator.return_value = MockCommunicator(
discrete_action=True, visual_inputs=0, num_agents=2, vec_obs_size=2
)
arena_config = ArenaConfig(" ")
env = AnimalAIEnvironment(
file_name=" ", n_arenas=2, arenas_configurations=arena_config,
)
mock_byte_channel.assert_called_once()
bytes_arg = bytes(arena_config.to_proto().SerializeToString(deterministic=True))
# we cannot call assert_called_with
mock_byte_channel.assert_called_with(bytes_arg)
batched_step_result = env.get_step_result("RealFakeBrain")
spec = env.get_agent_group_spec("RealFakeBrain")
env.close()
assert isinstance(batched_step_result, BatchedStepResult)
assert len(spec.observation_shapes) == len(batched_step_result.obs)
n_agents = batched_step_result.n_agents()
for shape, obs in zip(spec.observation_shapes, batched_step_result.obs):
assert (n_agents,) + shape == obs.shape
if __name__ == "__main__":
pytest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment