# curriculum.py: defines the Curriculum class (lesson progression loaded from a JSON file).
import os
import json
import math

from .exception import CurriculumError
from animalai.envs.arena_config import ArenaConfig

import logging

# Module-level logger, registered under the 'mlagents.trainers' logger name.
logger = logging.getLogger('mlagents.trainers')


class Curriculum(object):
    """
    Lesson progression for a single brain, loaded from a JSON curriculum file.

    The JSON file must define the fields 'configuration_files', 'measure',
    'thresholds', 'min_lesson_length' and 'signal_smoothing'. There is one
    more lesson than there are thresholds: exceeding ``thresholds[i]`` while
    in lesson ``i`` moves the curriculum to lesson ``i + 1``.
    """

    def __init__(self, location, yaml_files):
        """
        Initializes a Curriculum object.
        :param location: Path to JSON defining curriculum.
        :param yaml_files: A list of configuration files for each lesson.
               Each entry is joined to the folder containing ``location``.
        :raises CurriculumError: If the JSON file cannot be read or decoded,
                if a required field is missing, if the number of
                configuration files is not ``len(thresholds) + 1``, or if a
                listed configuration file is absent from the curriculum folder.
        """
        self.max_lesson_num = 0
        self.measure = None
        self._lesson_num = 0
        # The name of the brain should be the basename of the file without the
        # extension.
        self._brain_name = os.path.basename(location).split('.')[0]

        try:
            with open(location) as data_file:
                self.data = json.load(data_file)
        except IOError:
            raise CurriculumError(
                'The file {0} could not be found.'.format(location))
        except UnicodeDecodeError:
            raise CurriculumError('There was an error decoding {}'
                                  .format(location))

        for key in ['configuration_files', 'measure', 'thresholds',
                    'min_lesson_length', 'signal_smoothing']:
            if key not in self.data:
                raise CurriculumError("{0} does not contain a "
                                      "{1} field."
                                      .format(location, key))

        self.smoothing_value = 0
        self.measure = self.data['measure']
        self.min_lesson_length = self.data['min_lesson_length']
        self.max_lesson_num = len(self.data['thresholds'])

        configuration_files = self.data['configuration_files']
        expected_count = self.max_lesson_num + 1
        if len(configuration_files) != expected_count:
            # Name the offending field explicitly rather than reusing the
            # leftover loop variable from the validation loop above.
            raise CurriculumError(
                'The parameter {0} in Curriculum {1} must have {2} values '
                'but {3} were found'.format('configuration_files', location,
                                            expected_count,
                                            len(configuration_files)))

        # `location` is the JSON file itself; the YAML configurations live in
        # the directory that contains it. abspath() keeps dirname() non-empty
        # when `location` is a bare file name.
        folder = os.path.dirname(os.path.abspath(location))
        available_files = set(os.listdir(folder))
        if not all(name in available_files for name in configuration_files):
            raise CurriculumError(
                'One or more configuration file(s) in curriculum {0} could not be found'.format(location)
            )

        # One ArenaConfig per lesson, indexed by lesson number in get_config().
        # NOTE(review): the list is built from `yaml_files`, not from the
        # validated `configuration_files`; confirm the caller passes one file
        # per lesson in lesson order.
        self.configurations = [ArenaConfig(os.path.join(folder, yaml_file))
                               for yaml_file in yaml_files]

    @property
    def lesson_num(self):
        """Current lesson number, always within [0, max_lesson_num]."""
        return self._lesson_num

    @lesson_num.setter
    def lesson_num(self, lesson_num):
        # Clamp so the lesson number always stays within the valid range.
        self._lesson_num = max(0, min(lesson_num, self.max_lesson_num))

    def increment_lesson(self, measure_val):
        """
        Increments the lesson number depending on the progress given.
        :param measure_val: Measure of progress (either reward or percentage
               steps completed).
        :return Whether the lesson was incremented.
        """
        # A falsy measure (None or 0) or NaN never advances the lesson.
        if not self.data or not measure_val or math.isnan(measure_val):
            return False
        if self.data['signal_smoothing']:
            # Blend the new measure with the previously smoothed value.
            measure_val = self.smoothing_value * 0.25 + 0.75 * measure_val
            self.smoothing_value = measure_val
        if self.lesson_num < self.max_lesson_num:
            if measure_val > self.data['thresholds'][self.lesson_num]:
                self.lesson_num += 1
                logger.info('{0} lesson changed. Now in lesson {1}'
                            .format(self._brain_name,
                                    self.lesson_num))
                return True
        return False

    def get_config(self, lesson=None):
        """
        Returns the configuration which corresponds to the lesson.
        :param lesson: The lesson you want to get the config of. If None, the
               current lesson is returned. Out-of-range values are clamped
               to [0, max_lesson_num].
        :return: The configuration stored for that lesson, or an empty dict
                 if no curriculum data is loaded.
        """
        if not self.data:
            return {}
        if lesson is None:
            lesson = self.lesson_num
        lesson = max(0, min(lesson, self.max_lesson_num))
        return self.configurations[lesson]