Commit b6ee4065 authored by Benjamin's avatar Benjamin
Browse files

add typing to arena_config

parent a105d636
......@@ -3,6 +3,7 @@ import jsonpickle
import yaml
import copy
from typing import List
from animalai.communicator_objects import (
ArenasConfigurationsProto,
ArenaConfigurationProto,
......@@ -16,7 +17,7 @@ yaml.Dumper.ignore_aliases = lambda *args: True
class Vector3(yaml.YAMLObject):
yaml_tag = u"!Vector3"
def __init__(self, x=0, y=0, z=0):
def __init__(self, x: float = 0, y: float = 0, z: float = 0):
self.x = x
self.y = y
self.z = z
......@@ -33,7 +34,7 @@ class Vector3(yaml.YAMLObject):
class RGB(yaml.YAMLObject):
yaml_tag = u"!RGB"
def __init__(self, r=0, g=0, b=0):
def __init__(self, r: float = 0, g: float = 0, b: float = 0):
self.r = r
self.g = g
self.b = b
......@@ -51,7 +52,12 @@ class Item(yaml.YAMLObject):
yaml_tag = u"!Item"
def __init__(
self, name="", positions=None, rotations=None, sizes=None, colors=None
self,
name: str = "",
positions: List[Vector3] = None,
rotations: List[float] = None,
sizes: List[Vector3] = None,
colors: List[RGB] = None,
):
self.name = name
self.positions = positions if positions is not None else []
......@@ -73,7 +79,13 @@ class Item(yaml.YAMLObject):
class Arena(yaml.YAMLObject):
yaml_tag = u"!Arena"
def __init__(self, t=1000, items=None, pass_mark=0, blackouts=None):
def __init__(
self,
t: int = 1000,
items: List[Item] = None,
pass_mark: float = 0,
blackouts: List[int] = None,
):
self.t = t
self.items = items if items is not None else {}
self.pass_mark = pass_mark
......@@ -92,19 +104,19 @@ class Arena(yaml.YAMLObject):
class ArenaConfig(yaml.YAMLObject):
yaml_tag = u"!ArenaConfig"
def __init__(self, yaml_path=None):
def __init__(self, yaml_path: str = None):
if yaml_path is not None:
self.arenas = yaml.load(open(yaml_path, "r"), Loader=yaml.Loader).arenas
else:
self.arenas = {}
def save_config(self, json_path):
def save_config(self, json_path: str) -> None:
out = jsonpickle.encode(self.arenas)
out = json.loads(out)
json.dump(out, open(json_path, "w"), indent=4)
def to_proto(self, seed=-1) -> ArenasConfigurationsProto:
def to_proto(self, seed: int = -1) -> ArenasConfigurationsProto:
arenas_configurations_proto = ArenasConfigurationsProto()
arenas_configurations_proto.seed = seed
......@@ -113,11 +125,11 @@ class ArenaConfig(yaml.YAMLObject):
return arenas_configurations_proto
def update(self, arenas_configurations):
if arenas_configurations is not None:
for arena_i in arenas_configurations.arenas:
self.arenas[arena_i] = copy.copy(arenas_configurations.arenas[arena_i])
# def update(self, arenas_configurations:):
#
# if arenas_configurations is not None:
# for arena_i in arenas_configurations.arenas:
# self.arenas[arena_i] = copy.copy(arenas_configurations.arenas[arena_i])
def constructor_arena(loader, node):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment