Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Ozan Catal
Animal Ai Env
Commits
04c7c1f5
Commit
04c7c1f5
authored
Apr 23, 2020
by
Benjamin
Browse files
add test_ens_aai
parent
f3d290e3
Changes
3
Show whitespace changes
Inline
Side-by-side
animalai/animalai/envs/environment.py
View file @
04c7c1f5
...
...
@@ -109,8 +109,8 @@ class AnimalAIEnvironment(UnityEnvironment):
def
reset
(
self
,
arenas_configurations
:
ArenaConfig
=
None
)
->
None
:
if
arenas_configurations
:
arenas_configurations_proto
=
arenas_configurations
.
to_proto
()
arenas_configurations_proto_string
=
(
arenas_configurations_proto
.
SerializeToString
()
arenas_configurations_proto_string
=
arenas_configurations_proto
.
SerializeToString
(
deterministic
=
True
)
self
.
arenas_parameters_side_channel
.
send_raw_data
(
bytearray
(
arenas_configurations_proto_string
)
...
...
animalai/animalai/envs/tests/test_arena_config.py
View file @
04c7c1f5
...
...
@@ -112,7 +112,7 @@ def test_rgb():
def
test_item
():
item
:
Item
=
yaml
.
load
(
item_yaml
,
Loader
=
yaml
.
Loader
)
item
:
Item
=
yaml
.
load
(
item_yaml
,
Loader
=
yaml
.
Loader
)
assert
item
.
name
==
"Wall"
assert
len
(
item
.
positions
)
==
2
...
...
@@ -129,7 +129,7 @@ def test_item():
def
test_arena
():
arena
:
Arena
=
yaml
.
load
(
arena_yaml
,
Loader
=
yaml
.
Loader
)
arena
:
Arena
=
yaml
.
load
(
arena_yaml
,
Loader
=
yaml
.
Loader
)
assert
arena
.
pass_mark
==
0
assert
arena
.
t
==
250
assert
len
(
arena
.
blackouts
)
==
9
...
...
animalai/animalai/envs/tests/test_envs_aai.py
View file @
04c7c1f5
from
unittest.mock
import
patch
,
mock_open
import
pytest
from
animalai.envs.environment
import
AnimalAIEnvironment
from
animalai.envs.arena_config
import
ArenaConfig
from
mlagents_envs.mock_communicator
import
MockCommunicator
from
mlagents_envs.base_env
import
BatchedStepResult
arena_config_yaml
=
"""
!ArenaConfig
arenas:
0: !Arena
pass_mark: 2
t: 250
items:
- !Item
name: GoodGoalMulti
1: !Arena
pass_mark: -1
t: 250
items:
- !Item
name: BadGoal
"""
@
patch
(
"animalai.envs.environment.AnimalAIEnvironment.reset"
)
@
patch
(
"mlagents_envs.environment.UnityEnvironment.executable_launcher"
)
@
patch
(
"mlagents_envs.environment.UnityEnvironment.get_communicator"
)
def
test_basic_initialization
(
mock_communicator
,
mock_launcher
,
mock_reset
):
mock_communicator
.
return_value
=
MockCommunicator
(
discrete_action
=
True
,
visual_inputs
=
1
,
num_agents
=
32
,
vec_obs_size
=
2
)
env
=
AnimalAIEnvironment
(
file_name
=
" "
,
n_arenas
=
32
,
camera_height
=
126
,
camera_width
=
512
)
assert
env
.
get_agent_groups
()
==
[
"RealFakeBrain"
]
mock_launcher
.
assert_called_once
()
launcher_args
,
_
=
mock_launcher
.
call_args
executable_args
=
launcher_args
[
3
]
assert
executable_args
==
[
"--playerMode"
,
"0"
,
"--numberOfArenas"
,
"32"
,
"--cameraWidth"
,
"512"
,
"--cameraHeight"
,
"126"
,
]
env
.
close
()
@
patch
(
"animalai.envs.environment.AnimalAIEnvironment.reset"
)
@
patch
(
"mlagents_envs.environment.UnityEnvironment.executable_launcher"
)
@
patch
(
"mlagents_envs.environment.UnityEnvironment.get_communicator"
)
def
test_play_initialization
(
mock_communicator
,
mock_launcher
,
mock_reset
):
mock_communicator
.
return_value
=
MockCommunicator
()
env
=
AnimalAIEnvironment
(
file_name
=
" "
,
n_arenas
=
1
,
play
=
True
)
mock_launcher
.
assert_called_once
()
launcher_args
,
_
=
mock_launcher
.
call_args
executable_args
=
launcher_args
[
3
]
assert
executable_args
==
[
"--playerMode"
,
"1"
,
"--numberOfArenas"
,
"1"
]
env
.
close
()
@
patch
(
"builtins.open"
,
new_callable
=
mock_open
,
read_data
=
arena_config_yaml
)
@
patch
(
"mlagents_envs.side_channel.raw_bytes_channel.RawBytesChannel.send_raw_data"
)
@
patch
(
"mlagents_envs.environment.UnityEnvironment.executable_launcher"
)
@
patch
(
"mlagents_envs.environment.UnityEnvironment.get_communicator"
)
def
test_reset_arena_config
(
mock_communicator
,
mock_launcher
,
mock_byte_channel
,
mock_yaml
):
mock_communicator
.
return_value
=
MockCommunicator
(
discrete_action
=
True
,
visual_inputs
=
0
,
num_agents
=
2
,
vec_obs_size
=
2
)
arena_config
=
ArenaConfig
(
" "
)
env
=
AnimalAIEnvironment
(
file_name
=
" "
,
n_arenas
=
2
,
arenas_configurations
=
arena_config
,
)
mock_byte_channel
.
assert_called_once
()
bytes_arg
=
bytes
(
arena_config
.
to_proto
().
SerializeToString
(
deterministic
=
True
))
# we cannot call assert_called_with
mock_byte_channel
.
assert_called_with
(
bytes_arg
)
batched_step_result
=
env
.
get_step_result
(
"RealFakeBrain"
)
spec
=
env
.
get_agent_group_spec
(
"RealFakeBrain"
)
env
.
close
()
assert
isinstance
(
batched_step_result
,
BatchedStepResult
)
assert
len
(
spec
.
observation_shapes
)
==
len
(
batched_step_result
.
obs
)
n_agents
=
batched_step_result
.
n_agents
()
for
shape
,
obs
in
zip
(
spec
.
observation_shapes
,
batched_step_result
.
obs
):
assert
(
n_agents
,)
+
shape
==
obs
.
shape
if
__name__
==
"__main__"
:
pytest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment