


Abstract base class for TF Policies.

    __init__(
        time_step_spec, action_spec, policy_state_spec=(), info_spec=(), clip=True,
        emit_log_probability=False, automatic_state_reset=True,
        observation_and_action_constraint_splitter=None, name=None
    )

Example of simple use in TF:

tf_env = SomeTFEnvironment()
policy = SomeTFPolicy()

time_step, step_state, reset_env = tf_env.reset()
policy_state = policy.get_initial_state(batch_size=tf_env.batch_size)
action_step = policy.action(time_step, policy_state)
next_time_step, _ = tf_env.step(action_step.action, step_state)

sess.run([time_step, action_step, next_time_step])

Example of using the same policy for several steps:

tf_env = SomeTFEnvironment()
policy = SomeTFPolicy()

exp_policy = SomeTFPolicy()
update_policy = exp_policy.update(policy)
policy_state = exp_policy.get_initial_state(tf_env.batch_size)

time_step, step_state, _ = tf_env.reset()
action_step, policy_state, _ = exp_policy.action(time_step, policy_state)
next_time_step, step_state = tf_env.step(action_step.action, step_state)

for j in range(num_episodes):
  sess.run(update_policy)
  for i in range(num_steps):
    sess.run([time_step, action_step, next_time_step])

Example with multiple steps:

tf_env = SomeTFEnvironment() policy = SomeTFPolicy()

# reset() creates the initial time_step and step_state, plus a reset_op.
time_step, step_state, reset_op = tf_env.reset()
policy_state = policy.get_initial_state(tf_env.batch_size)
n_step = [time_step]
for i in range(n):
  action_step = policy.action(time_step, policy_state)
  policy_state = action_step.state
  n_step.append(action_step)
  time_step, step_state = tf_env.step(action_step.action, step_state)
  n_step.append(time_step)

# n_step contains [time_step, action_step, time_step, action_step, ...]
sess.run(n_step)

Example with explicit resets:

tf_env = SomeTFEnvironment()
policy = SomeTFPolicy()
policy_state = policy.get_initial_state(tf_env.batch_size)

time_step, step_state, reset_env = tf_env.reset()
action_step = policy.action(time_step, policy_state)
# step() applies the action and returns the new TimeStep.
next_time_step, _ = tf_env.step(action_step.action, step_state)
next_action_step = policy.action(next_time_step, policy_state)

# The Environment and the Policy are reset before starting.
sess.run([time_step, action_step, next_time_step, next_action_step])
# Force-reset the Environment and the Policy.
sess.run([reset_env])
sess.run([time_step, action_step, next_time_step, next_action_step])


Args:

  • time_step_spec: A TimeStep spec of the expected time_steps. Usually provided by the user to the subclass.
  • action_spec: A nest of BoundedTensorSpec representing the actions. Usually provided by the user to the subclass.
  • policy_state_spec: A nest of TensorSpec representing the policy_state. Provided by the subclass, not directly by the user.
  • info_spec: A nest of TensorSpec representing the policy info. Provided by the subclass, not directly by the user.
  • clip: Whether to clip actions to spec before returning them. Default True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped continuous actions for training.
  • emit_log_probability: Emit log-probabilities of actions, if supported. If True, policy_step.info will have CommonFields.LOG_PROBABILITY set. Please consult utility methods provided in policy_step for setting and retrieving these. When working with custom policies, either provide a dictionary info_spec or a namedtuple with the field 'log_probability'.
  • automatic_state_reset: If True, then get_initial_policy_state is used to clear the state in action() and distribution() for time steps where time_step.is_first().
  • observation_and_action_constraint_splitter: A function used to process observations with action constraints. These constraints can indicate, for example, a mask of valid/invalid actions for a given state of the environment. The function takes in a full observation and returns a tuple consisting of 1) the part of the observation intended as input to the network and 2) the constraint. An example observation_and_action_constraint_splitter could be as simple as:

        def observation_and_action_constraint_splitter(observation):
          return observation['network_input'], observation['constraint']

    Note: when using observation_and_action_constraint_splitter, make sure the provided q_network is compatible with the network-specific half of the output of observation_and_action_constraint_splitter. In particular, observation_and_action_constraint_splitter will be called on the observation before it is passed to the network. If observation_and_action_constraint_splitter is None, action constraints are not applied.
  • name: A name for this module. Defaults to the class name.
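To make the splitter's role concrete, here is a minimal, self-contained sketch of how the constraint half of its output can be used to mask out invalid actions before a greedy argmax. The observation layout, Q-values, and masking scheme are assumptions for illustration, not the tf_agents implementation:

```python
# Hypothetical splitter: first element feeds the network, second is the constraint.
def observation_and_action_constraint_splitter(observation):
    return observation['network_input'], observation['constraint']

observation = {
    'network_input': [0.1, 0.2],
    'constraint': [1, 0, 1],   # actions 0 and 2 are valid; action 1 is masked out
}
network_input, mask = observation_and_action_constraint_splitter(observation)

# Pretend network output for the 3 actions, then mask invalid actions to -inf
# so the argmax can never pick them.
q_values = [0.3, 0.9, 0.5]
NEG_INF = float('-inf')
masked_q = [q if m else NEG_INF for q, m in zip(q_values, mask)]
best_action = max(range(len(masked_q)), key=masked_q.__getitem__)
```

Even though action 1 has the highest raw Q-value, the mask forces the choice to action 2.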


Attributes:

  • action_spec: Describes the TensorSpecs of the Tensors expected by step(action).

    action can be a single Tensor, or a nested dict, list or tuple of Tensors.

  • emit_log_probability: Whether this policy instance emits log probabilities or not.

  • info_spec: Describes the Tensors emitted as info by action and distribution.

    info can be an empty tuple, a single Tensor, or a nested dict, list or tuple of Tensors.

  • name: Returns the name of this module as passed or determined in the ctor.

    NOTE: This is not the same as the self.name_scope.name which includes parent module names.

  • name_scope: Returns a tf.name_scope instance for this class.

  • observation_and_action_constraint_splitter: The observation_and_action_constraint_splitter function passed to the constructor, or None.

  • policy_state_spec: Describes the Tensors expected by step(_, policy_state).

    policy_state can be an empty tuple, a single Tensor, or a nested dict, list or tuple of Tensors.

  • policy_step_spec: Describes the output of action().

  • submodules: Sequence of all sub-modules.

    Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

a = tf.Module()
b = tf.Module()
c = tf.Module()
a.b = b
b.c = c
assert list(a.submodules) == [b, c]
assert list(b.submodules) == [c]
assert list(c.submodules) == []
  • time_step_spec: Describes the TimeStep tensors returned by step().

  • trainable_variables: Sequence of trainable variables owned by this module and its submodules.

  • trajectory_spec: Describes the Tensors written when using this policy with an environment.



action

    action(
        time_step, policy_state=(), seed=None
    )

Generates next action given the time_step and policy_state.


Args:

  • time_step: A TimeStep tuple corresponding to time_step_spec().
  • policy_state: A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state.
  • seed: Seed to use if action performs sampling (optional).


Returns:

A PolicyStep named tuple containing:

  • action: An action Tensor matching the action_spec().
  • state: A policy state tensor to be fed into the next call to action.
  • info: Optional side information such as action log probabilities.
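The (action, state, info) triple can be illustrated with a stdlib-only mock. The counter-as-state policy below is entirely made up; it only mirrors the shape of the contract, where the returned state is fed back into the next call to action:

```python
from collections import namedtuple
import random

# Stand-in for tf_agents' PolicyStep namedtuple (action, state, info).
PolicyStep = namedtuple('PolicyStep', ['action', 'state', 'info'])

def action(time_step, policy_state=(), seed=None):
    # Hypothetical policy: sample one of 3 actions, count steps in the state.
    rng = random.Random(seed)
    chosen = rng.randrange(3)
    next_state = (policy_state or 0) + 1
    return PolicyStep(action=chosen, state=next_state, info=())

step = action(time_step=None, policy_state=0, seed=42)
# step.state is 1 here and would be passed as policy_state on the next call.
```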


Raises:

  • RuntimeError: If the subclass's __init__ didn't call super().__init__.


distribution

    distribution(
        time_step, policy_state=()
    )

Generates the distribution over next actions given the time_step.


Args:

  • time_step: A TimeStep tuple corresponding to time_step_spec().
  • policy_state: A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state.


Returns:

A PolicyStep named tuple containing:

  • action: A tf.distribution capturing the distribution of next actions.
  • state: A policy state tensor for the next call to distribution.
  • info: Optional side information such as action log probabilities.
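The key difference from action() is that the step carries a distribution object rather than a concrete action. A tiny categorical stand-in (assumed for illustration; not tf.distributions) shows the shape of that contract:

```python
import random

# Minimal categorical distribution: sample() draws an action index
# proportionally to the given probabilities.
class Categorical:
    def __init__(self, probs, seed=None):
        self.probs = probs
        self._rng = random.Random(seed)

    def sample(self):
        return self._rng.choices(range(len(self.probs)), weights=self.probs)[0]

# All probability mass on action 1, so sampling is deterministic here.
dist = Categorical([0.0, 1.0, 0.0], seed=0)
sampled_action = dist.sample()
```

A caller that needs a concrete action can sample from the returned distribution, which is essentially what action() does on top of distribution().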


get_initial_state

    get_initial_state(
        batch_size
    )

Returns an initial state usable by the policy.


Args:

  • batch_size: Tensor or constant: size of the batch dimension. Can be None, in which case no batch dimension is added.


Returns:

A nested object of type policy_state containing properly initialized Tensors.
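The batch_size semantics can be sketched with plain Python: zero-fill each entry of a state spec, prepending a batch dimension only when batch_size is given. The dict-of-shapes spec format is an assumption for illustration, not the tf_agents spec types:

```python
# Hypothetical sketch of get_initial_state: build zero-valued nested state.
def get_initial_state(state_spec, batch_size=None):
    def zeros(shape):
        # Recursively build a nested list of zeros with the given shape.
        if not shape:
            return 0.0
        return [zeros(shape[1:]) for _ in range(shape[0])]

    out = {}
    for name, shape in state_spec.items():
        # batch_size=None means no batch dimension is added.
        full_shape = list(shape) if batch_size is None else [batch_size] + list(shape)
        out[name] = zeros(full_shape)
    return out

batched = get_initial_state({'hidden': [2]}, batch_size=3)    # shape [3, 2]
unbatched = get_initial_state({'hidden': [2]}, batch_size=None)  # shape [2]
```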


update

    update(
        policy, tau=1.0, tau_non_trainable=None, sort_variables_by_name=False
    )

Update the current policy with another policy.

This includes copying the variables from the other policy.


Args:

  • policy: Another policy that this policy can update from.
  • tau: A float scalar in [0, 1]. When tau is 1.0 (the default), we do a hard update. This is used for trainable variables.
  • tau_non_trainable: A float scalar in [0, 1] for non_trainable variables. If None, will copy from tau.
  • sort_variables_by_name: A bool; when True, the variables are sorted by name before the update.


Returns:

A TF op that performs the update.
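The tau semantics amount to a Polyak-style blend: tau=1.0 is a hard copy of the source variables, tau<1.0 takes a small step toward them. A stdlib sketch with plain floats standing in for the policies' variables (this is not the tf_agents implementation):

```python
# Illustrative soft update: new_target = (1 - tau) * target + tau * source.
def soft_update(target_vars, source_vars, tau=1.0):
    return [(1.0 - tau) * t + tau * s for t, s in zip(target_vars, source_vars)]

target = [0.0, 10.0]
source = [1.0, 20.0]

hard = soft_update(target, source, tau=1.0)   # exact copy of source
soft = soft_update(target, source, tau=0.1)   # small step toward source
```

The tau_non_trainable argument simply lets non-trainable variables use a different blending factor than the trainable ones.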


variables

    variables()

Returns the list of Variables that belong to the policy.


with_name_scope

    @classmethod
    with_name_scope(
        cls, method
    )

Decorator to automatically enter the module name scope.

class MyModule(tf.Module):
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 64]))
    return tf.matmul(x, self.w)

Using the above module would produce tf.Variables and tf.Tensors whose names include the module name:

mod = MyModule()
mod(tf.ones([8, 32]))
# ==> <tf.Tensor: ...>
# ==> <tf.Variable ...'my_module/w:0'>


Args:

  • method: The method to wrap.


Returns:

The original method, wrapped such that it enters the module's name scope.
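The wrapping pattern itself is ordinary Python: a class-level decorator that enters some scope around every call. The sketch below mimics only the shape of with_name_scope, recording scope entries in a list where tf.Module would open a tf.name_scope; the Scoped class and its with_scope decorator are made up for illustration:

```python
import functools

class Scoped:
    calls = []  # records each scope entry, standing in for tf.name_scope

    @classmethod
    def with_scope(cls, method):
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            Scoped.calls.append(method.__name__)  # "enter" the scope
            return method(self, *args, **kwargs)
        return wrapper

class MyModule(Scoped):
    @Scoped.with_scope
    def __call__(self, x):
        return x * 2

result = MyModule()(21)
# Scoped.calls now records that __call__ ran inside the scope.
```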