Revision 84ea7aa1fd83966d6a121e60e14dfaac4675c767 authored by AurelianTactics on 24 October 2018, 16:59:46 UTC, committed by pzhokhov on 24 October 2018, 16:59:46 UTC
DeepQ, PPO2, ACER, trpo_mpi, A2C, and ACKTR all contain the following code: ``` from baselines.common import set_global_seeds ... def learn(...): ... set_global_seeds(seed) ``` DDPG accepts the argument 'seed=None' but is missing these two lines of code, so it never sets the global seeds.
1 parent 88300ed
models.py
import tensorflow as tf
from baselines.common.models import get_network_builder
class Model(object):
    """Named wrapper around a network builder and its TensorFlow variables.

    The instance stores a scope name and a network-building callable; the
    properties look up variables in the graph collections filtered by that
    scope name.
    """

    def __init__(self, name, network='mlp', **network_kwargs):
        """Store the scope name and construct the network builder.

        Args:
            name: Scope name used when querying variable collections.
            network: Identifier passed to ``get_network_builder``.
            **network_kwargs: Forwarded to the factory that
                ``get_network_builder`` returns.
        """
        self.name = name
        builder_factory = get_network_builder(network)
        self.network_builder = builder_factory(**network_kwargs)

    @property
    def vars(self):
        """Return the GLOBAL_VARIABLES collection filtered by ``self.name``."""
        collection_key = tf.GraphKeys.GLOBAL_VARIABLES
        return tf.get_collection(collection_key, scope=self.name)

    @property
    def trainable_vars(self):
        """Return the TRAINABLE_VARIABLES collection filtered by ``self.name``."""
        collection_key = tf.GraphKeys.TRAINABLE_VARIABLES
        return tf.get_collection(collection_key, scope=self.name)

    @property
    def perturbable_vars(self):
        """Return the trainable variables whose name lacks ``'LayerNorm'``."""
        selected = []
        for candidate in self.trainable_vars:
            # Anything with 'LayerNorm' in its name is excluded.
            if 'LayerNorm' in candidate.name:
                continue
            selected.append(candidate)
        return selected
class Actor(Model):
    """Network mapping observations to ``nb_actions`` tanh-squashed outputs."""

    def __init__(self, nb_actions, name='actor', network='mlp', **network_kwargs):
        """Initialise the base model and remember the output width.

        Args:
            nb_actions: Number of units in the final dense layer.
            name: Variable scope name.
            network: Identifier of the network builder to use.
            **network_kwargs: Forwarded to the base-class constructor.
        """
        super(Actor, self).__init__(name=name, network=network, **network_kwargs)
        self.nb_actions = nb_actions

    def __call__(self, obs, reuse=False):
        """Build (or reuse) the network under ``self.name`` and apply it.

        Args:
            obs: Input tensor handed to the network builder.
            reuse: Not read by this method; the variable scope is always
                opened with ``tf.AUTO_REUSE``.

        Returns:
            The tanh of the final dense layer's output.
        """
        # Final-layer kernel starts from a small uniform range.
        final_init = tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3)
        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            features = self.network_builder(obs)
            pre_activation = tf.layers.dense(
                features,
                self.nb_actions,
                kernel_initializer=final_init,
            )
            action = tf.nn.tanh(pre_activation)
        return action
class Critic(Model):
    """Network mapping an (observation, action) pair to a single output."""

    def __init__(self, name='critic', network='mlp', **network_kwargs):
        """Initialise the base model.

        Args:
            name: Variable scope name.
            network: Identifier of the network builder to use.
            **network_kwargs: Forwarded to the base-class constructor.
        """
        super().__init__(name=name, network=network, **network_kwargs)
        # Kept for callers that may read it; nothing in this class uses it.
        self.layer_norm = True

    def __call__(self, obs, action, reuse=False):
        """Build (or reuse) the network under ``self.name`` and apply it.

        Args:
            obs: Observation tensor.
            action: Action tensor; concatenated with ``obs`` on the last axis.
            reuse: Not read by this method; the variable scope is always
                opened with ``tf.AUTO_REUSE``.

        Returns:
            The output of the final dense layer, which has one unit.
        """
        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            x = tf.concat([obs, action], axis=-1)  # this assumes observation and action can be concatenated
            x = self.network_builder(x)
            # The layer must be named 'output': ``output_vars`` selects
            # variables by that substring, and an unnamed layer would get a
            # default name that the filter does not match.
            x = tf.layers.dense(
                x,
                1,
                kernel_initializer=tf.random_uniform_initializer(minval=-3e-3, maxval=3e-3),
                name='output',
            )
        return x

    @property
    def output_vars(self):
        """Return the trainable variables whose name contains ``'output'``.

        This matches the final dense layer created in ``__call__``, which is
        named ``'output'``. NOTE(review): naming that layer changes its
        variable names, so checkpoints saved before the rename will not map
        onto it -- confirm whether any such checkpoints need to be loaded.
        """
        return [var for var in self.trainable_vars if 'output' in var.name]
Computing file changes ...