Commit beec4fc7 authored by Benjamin Beyret

remove OTC specific code

parent b43bfc45
@@ -13,11 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Obstacle Tower-specific utilities including Atari-specific network architectures.
This includes a class implementing minimal preprocessing, which
is in charge of:
. Converting observations to greyscale.
"""
Code adapted from the Obstacle Tower competition
"""
from __future__ import absolute_import
@@ -26,13 +23,10 @@ from __future__ import print_function
import math
from animalai.envs.gym.environment import AnimalAIEnv
import numpy as np
import tensorflow as tf
import gin.tf
import cv2
slim = tf.contrib.slim
@@ -41,18 +35,6 @@ NATURE_DQN_DTYPE = tf.uint8  # DType of Atari 2600 observations.
NATURE_DQN_STACK_SIZE = 4 # Number of frames in the state stack.
@gin.configurable
def create_animalai_environment(environment_path=None):
"""Wraps the Animal AI environment with some basic preprocessing.
Returns:
An Animal AI environment with some standard preprocessing.
"""
assert environment_path is not None
env = AnimalAIEnv(environment_path, 0, n_arenas=1, retro=True)
env = OTCPreprocessing(env)
return env
@gin.configurable
def nature_dqn_network(num_actions, network_type, state):
"""The convolutional network used to compute the agent's Q-values.
@@ -75,6 +57,7 @@ def nature_dqn_network(num_actions, network_type, state):
q_values = slim.fully_connected(net, num_actions, activation_fn=None)
return network_type(q_values)
@gin.configurable
def rainbow_network(num_actions, num_atoms, support, network_type, state):
"""The convolutional network used to compute agent's Q-value distributions.
@@ -114,6 +97,7 @@ def rainbow_network(num_actions, num_atoms, support, network_type, state):
q_values = tf.reduce_sum(support * probabilities, axis=2)
return network_type(q_values, logits, probabilities)
@gin.configurable
def implicit_quantile_network(num_actions, quantile_embedding_dim,
network_type, state, num_quantiles):
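The rainbow_network hunk above keeps the line that collapses the categorical return distribution into scalar Q-values (q_values = tf.reduce_sum(support * probabilities, axis=2)). A small hedged numpy illustration of that reduction, with made-up shapes and numbers:
import numpy as np

num_actions, num_atoms = 2, 5
# support: fixed atom locations z_i, shared by all actions.
support = np.linspace(-10.0, 10.0, num_atoms)                # shape (num_atoms,)
# probabilities: per-action softmax over atoms, shape (batch, num_actions, num_atoms).
probabilities = np.array([[[0.10, 0.20, 0.40, 0.20, 0.10],
                           [0.05, 0.05, 0.10, 0.30, 0.50]]])
# Same reduction as in rainbow_network: Q(s, a) = sum_i z_i * p_i(s, a).
q_values = np.sum(support * probabilities, axis=2)           # shape (batch, num_actions)
print(q_values)  # [[0.   5.75]]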
@@ -171,100 +155,3 @@ def implicit_quantile_network(num_actions, quantile_embedding_dim,
weights_initializer=weights_initializer)
return network_type(quantile_values=quantile_values, quantiles=quantiles)
#
# @gin.configurable
# class AAIPreprocessing(object):
# """A class implementing image preprocessing for OTC agents.
#
# Specifically, this converts observations to greyscale. It doesn't
# do anything else to the environment.
# """
#
# def __init__(self, environment):
# """Constructor for an Obstacle Tower preprocessor.
#
# Args:
# environment: Gym environment whose observations are preprocessed.
#
# """
# self.environment = environment
#
# self.game_over = False
# self.lives = 0 # Will need to be set by reset().
#
# @property
# def observation_space(self):
# return self.environment.observation_space
#
# @property
# def action_space(self):
# return self.environment.action_space
#
# @property
# def reward_range(self):
# return self.environment.reward_range
#
# @property
# def metadata(self):
# return self.environment.metadata
#
# def reset(self):
# """Resets the environment. Converts the observation to greyscale,
# if it is not.
#
# Returns:
# observation: numpy array, the initial observation emitted by the
# environment.
# """
# observation = self.environment.reset()
# if (len(observation.shape) > 2):
# observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
#
# return observation
#
# def render(self, mode):
# """Renders the current screen, before preprocessing.
#
# This calls the Gym API's render() method.
#
# Args:
# mode: Mode argument for the environment's render() method.
# Valid values (str) are:
# 'rgb_array': returns the raw ALE image.
# 'human': renders to display via the Gym renderer.
#
# Returns:
# if mode='rgb_array': numpy array, the most recent screen.
# if mode='human': bool, whether the rendering was successful.
# """
# return self.environment.render(mode)
#
# def step(self, action):
# """Applies the given action in the environment. Converts the observation to
# greyscale, if it is not.
#
# Remarks:
#
# * If a terminal state (from life loss or episode end) is reached, this may
# execute fewer than self.frame_skip steps in the environment.
# * Furthermore, in this case the returned observation may not contain valid
# image data and should be ignored.
#
# Args:
# action: The action to be executed.
#
# Returns:
# observation: numpy array, the observation following the action.
# reward: float, the reward following the action.
# is_terminal: bool, whether the environment has reached a terminal state.
# This is true when a life is lost and terminal_on_life_loss, or when the
# episode is over.
# info: Gym API's info data structure.
# """
#
# observation, reward, game_over, info = self.environment.step(action)
# self.game_over = game_over
# if (len(observation.shape) > 2):
# observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
# return observation, reward, game_over, info
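For reference, the behaviour of the commented-out wrapper deleted above boils down to an RGB-to-greyscale conversion applied to observations on reset() and step(). A minimal hedged sketch of that idea, with an illustrative class name and only the behaviour that can be read off the removed code:
import cv2


class GreyscaleWrapper(object):
    """Minimal stand-in for the removed preprocessing: RGB observations to greyscale."""

    def __init__(self, environment):
        self.environment = environment
        self.game_over = False

    def _to_greyscale(self, observation):
        # Convert only if the observation still has a colour channel dimension.
        if len(observation.shape) > 2:
            observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        return observation

    def reset(self):
        return self._to_greyscale(self.environment.reset())

    def step(self, action):
        observation, reward, game_over, info = self.environment.step(action)
        self.game_over = game_over
        return self._to_greyscale(observation), reward, game_over, info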
@@ -2,7 +2,7 @@ from setuptools import setup
setup(
name='animalai_train',
-version='1.0.2',
+version='1.0.3',
description='Animal AI competition training library',
url='https://github.com/beyretb/AnimalAI-Olympics',
author='Benjamin Beyret',
@@ -20,7 +20,7 @@ setup(
zip_safe=False,
install_requires=[
-'animalai>=1.0.2',
+'animalai>=1.0.3',
'dopamine-rl',
'tensorflow==1.12.2',
'matplotlib',