
Base class for py_policy instances of TF policies in Eager mode.

Inherits From: Base

    *args, **kwargs

Handles adding and removing batch dimensions from the actions and time_steps. Note: if you already have a tf_policy, use the PyTFEagerPolicy class directly instead of this base class.
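As a rough illustration of what adding and removing a batch dimension means for a nested structure, the sketch below wraps every leaf in an outer dimension of size 1 and then strips it again. It is not the library's implementation: plain Python lists stand in for NumPy arrays, and the helper names add_batch_dim and remove_batch_dim are hypothetical.

```python
# Illustrative only: plain lists stand in for arrays, and these helper
# names are hypothetical, not part of TF-Agents.

def add_batch_dim(nest):
    """Wrap every leaf of a dict/tuple nest in an outer list of length 1."""
    if isinstance(nest, dict):
        return {key: add_batch_dim(value) for key, value in nest.items()}
    if isinstance(nest, tuple):
        return tuple(add_batch_dim(value) for value in nest)
    return [nest]  # leaf: shape [...] becomes [1, ...]


def remove_batch_dim(nest):
    """Inverse of add_batch_dim: take element 0 of every leaf."""
    if isinstance(nest, dict):
        return {key: remove_batch_dim(value) for key, value in nest.items()}
    if isinstance(nest, tuple):
        return tuple(remove_batch_dim(value) for value in nest)
    return nest[0]


# An unbatched time step, as a Python environment might produce it.
unbatched = {"observation": [0.1, 0.2, 0.3], "reward": 1.0}

batched = add_batch_dim(unbatched)    # what a batched policy would consume
restored = remove_batch_dim(batched)  # what the caller gets back
```

The round trip leaves the structure unchanged, so callers working with unbatched data never need to see the batch dimension.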


  • time_step_spec: A TimeStep ArraySpec of the expected time_steps. Usually provided by the user to the subclass.
  • action_spec: A nest of BoundedArraySpec representing the actions. Usually provided by the user to the subclass.
  • policy_state_spec: A nest of ArraySpec representing the policy state. Provided by the subclass, not directly by the user.
  • info_spec: A nest of ArraySpec representing the policy info. Provided by the subclass, not directly by the user.
  • observation_and_action_constraint_splitter: A function used to process observations with action constraints. These constraints can indicate, for example, a mask of valid/invalid actions for a given state of the environment. The function takes in a full observation and returns a tuple consisting of 1) the part of the observation intended as input to the network and 2) the constraint. An example observation_and_action_constraint_splitter could be as simple as:
def observation_and_action_constraint_splitter(observation):
  return observation['network_input'], observation['constraint']

Note: when using observation_and_action_constraint_splitter, make sure the provided q_network is compatible with the network-specific half of the output of the observation_and_action_constraint_splitter. In particular, observation_and_action_constraint_splitter will be called on the observation before passing to the network. If observation_and_action_constraint_splitter is None, action constraints are not applied.
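To make the role of the constraint concrete, here is a minimal sketch of how a mask returned by the splitter could restrict a greedy action choice. The observation keys come from the example above. The q_values list and the masked_greedy_action helper are hypothetical stand-ins for a network's output and for whatever the policy does with the mask internally.

```python
def observation_and_action_constraint_splitter(observation):
    # Same splitter as in the example above.
    return observation['network_input'], observation['constraint']


def masked_greedy_action(q_values, mask):
    """Return the index of the best-valued action that the mask allows.

    Hypothetical helper: mask[i] is truthy when action i is valid.
    """
    allowed = [i for i, ok in enumerate(mask) if ok]
    if not allowed:
        raise ValueError("mask rules out every action")
    return max(allowed, key=lambda i: q_values[i])


observation = {'network_input': [0.5, -0.2], 'constraint': [1, 0, 1]}
network_input, mask = observation_and_action_constraint_splitter(observation)

# Stand-in for the network's output on network_input (assumed values).
q_values = [0.1, 0.9, 0.4]
action = masked_greedy_action(q_values, mask)
```

Here the highest-valued action (index 1) is masked out, so the choice falls to the best remaining valid action.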


  • action_spec: Describes the ArraySpecs of the np.Array returned by action().

    action can be a single np.Array, or a nested dict, list or tuple of np.Array.

  • collect_data_spec: Describes the data collected when using this policy with an environment.

  • info_spec: Describes the Arrays emitted as info by action().

  • observation_and_action_constraint_splitter

  • policy_state_spec: Describes the arrays expected by functions with policy_state as input.

  • policy_step_spec: Describes the output of action().

  • time_step_spec: Describes the TimeStep np.Arrays expected by action(time_step).

  • trajectory_spec: Describes the data collected when using this policy with an environment.




    time_step, policy_state=()

Generates next action given the time_step and policy_state.


  • time_step: A TimeStep tuple corresponding to time_step_spec().
  • policy_state: An optional previous policy_state.


A PolicyStep named tuple containing:
  • action: A nest of action Arrays matching the action_spec().
  • state: A nest of policy states to be fed into the next call to action.
  • info: Optional side information such as action log probabilities.
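The returned state is meant to be passed back in on the next call. The sketch below shows that calling pattern with a toy stand-in. The PolicyStep namedtuple and the CountingPolicy class are hypothetical and only mirror the field and method names documented here. They are not the TF-Agents classes.

```python
from collections import namedtuple

# Hypothetical stand-in with the three fields described above.
PolicyStep = namedtuple("PolicyStep", ["action", "state", "info"])


class CountingPolicy:
    """Toy policy whose state counts how many actions it has produced."""

    def get_initial_state(self, batch_size=None):
        return 0

    def action(self, time_step, policy_state=()):
        # Treat the default empty tuple as "no previous state".
        steps_taken = (policy_state or 0) + 1
        chosen = steps_taken % 2  # alternate between actions 1 and 0
        return PolicyStep(action=chosen, state=steps_taken, info=())


policy = CountingPolicy()
policy_state = policy.get_initial_state()
actions = []
for time_step in ["t0", "t1", "t2"]:  # placeholder time steps
    step = policy.action(time_step, policy_state)
    actions.append(step.action)
    policy_state = step.state  # feed the new state into the next call
```

Forgetting to carry step.state forward would reset a stateful policy, such as a recurrent one, on every call.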




Returns an initial state usable by the policy.


  • batch_size: An optional batch size.


An initial policy state.

