A neural-network-based epsilon-greedy agent.

Inherits From: GreedyRewardPredictionAgent

    *args, **kwargs

This agent receives a neural network that it trains to predict rewards. The action is chosen greedily with respect to the predictions with probability 1 - epsilon, and uniformly at random with probability epsilon.
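The selection rule described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the agent's implementation: the function name `epsilon_greedy_action` is hypothetical, and the predictions are a plain array standing in for the output of `reward_network`.

```python
import numpy as np


def epsilon_greedy_action(predicted_rewards, epsilon, rng):
    """Picks an action index from a 1-D array of per-action reward predictions.

    Hypothetical helper: with probability epsilon an action is drawn uniformly
    at random; otherwise the action with the highest prediction is returned.
    """
    num_actions = predicted_rewards.shape[0]
    if rng.random() < epsilon:
        # Explore: every action is equally likely.
        return int(rng.integers(num_actions))
    # Exploit: greedy choice with respect to the predictions.
    return int(np.argmax(predicted_rewards))


rng = np.random.default_rng(0)
predictions = np.array([0.1, 0.7, 0.3])
# With epsilon=0.0 the choice is always greedy.
print(epsilon_greedy_action(predictions, epsilon=0.0, rng=rng))  # 1
```

Because the uniform draw can also land on the greedy arm, in this sketch the greedy action is chosen with total probability 1 - epsilon + epsilon / K for K actions.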

Args:
  • time_step_spec: A TimeStep spec of the expected time_steps.
  • action_spec: A nest of BoundedTensorSpec representing the actions.
  • reward_network: A tf_agents.networks.network.Network to be used by the agent. The network is called with call(observation, step_type) and is expected to output a reward prediction for every action. Note: when using observation_and_action_constraint_splitter, make sure the reward_network is compatible with the network-specific half of the splitter's output, because observation_and_action_constraint_splitter is applied to the observation before it is passed to the network.
  • optimizer: The optimizer to use for training.
  • epsilon: A float representing the probability of choosing a random action instead of the greedy action.
  • observation_and_action_constraint_splitter: A function used for masking valid/invalid actions with each state of the environment. The function takes in a full observation and returns a tuple consisting of 1) the part of the observation intended as input to the bandit agent and policy, and 2) the boolean mask. This function should also work with a TensorSpec as input, and should output TensorSpec objects for the observation and mask.
  • error_loss_fn: A function for computing the error loss, taking parameters labels, predictions, and weights (any function from tf.losses would work). The default is tf.losses.mean_squared_error.
  • gradient_clipping: A float representing the norm length to clip gradients to (or None for no clipping).
  • debug_summaries: A Python bool, default False. When True, debug summaries are gathered.
  • summarize_grads_and_vars: A Python bool, default False. When True, gradients and network variable summaries are written during training.
  • enable_summaries: A Python bool, default True. When False, no summaries (debug or otherwise) are written.
  • emit_policy_info: A tuple of strings specifying which side information to include in the policy info. Allowed values can be found in policy_utilities.PolicyInfo.
  • train_step_counter: An optional tf.Variable to increment every time the train op is run. Defaults to the global_step.
  • laplacian_matrix: A float Tensor shaped [num_actions, num_actions]. This holds the Laplacian matrix used to regularize the smoothness of the estimated expected reward function. This only applies to problems where the actions have a graph structure. If None, the regularization is not applied.
  • laplacian_smoothing_weight: A float that determines the weight of the regularization term. Note that this has no effect if laplacian_matrix above is None.
  • name: Python str name of this agent. All variables in this module will fall under that name. Defaults to the class name.
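As an illustration of `observation_and_action_constraint_splitter`, the sketch below assumes a hypothetical observation layout (a dict with `features` and `mask` keys) and shows how a mask can exclude invalid actions from the greedy choice. The real splitter must also accept and return TensorSpec objects, which this NumPy sketch does not handle, and `masked_greedy_action` is a hypothetical helper, not part of the agent's API.

```python
import numpy as np


def splitter(observation):
    # Hypothetical layout: a dict holding the network input and a 0/1 action mask.
    return observation["features"], observation["mask"]


def masked_greedy_action(predicted_rewards, mask):
    # Invalid actions get -inf so they can never win the argmax.
    masked = np.where(mask.astype(bool), predicted_rewards, -np.inf)
    return int(np.argmax(masked))


obs = {"features": np.array([0.5, -1.2]), "mask": np.array([1, 0, 1])}
features, mask = splitter(obs)
predicted = np.array([0.2, 0.9, 0.4])  # stand-in for reward_network(features)
print(masked_greedy_action(predicted, mask))  # 2, since action 1 is masked out
```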

Attributes:
  • action_spec: TensorSpec describing the action produced by the agent.

  • collect_data_spec: Returns a Trajectory spec, as expected by the collect_policy.

  • collect_policy: Returns a policy that can be used to collect data from the environment.

  • debug_summaries

  • name: Returns the name of this module as passed or determined in the constructor.

    NOTE: This is not the same as the self.name_scope.name which includes parent module names.

  • name_scope: Returns a tf.name_scope instance for this class.

  • policy: Returns the current policy held by the agent.

  • submodules: Sequence of all sub-modules.

    Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

```python
import tensorflow as tf

a = tf.Module()
b = tf.Module()
c = tf.Module()
a.b = b
b.c = c
assert list(a.submodules) == [b, c]
assert list(b.submodules) == [c]
assert list(c.submodules) == []
```

  • summaries_enabled
  • summarize_grads_and_vars
  • time_step_spec: Describes the TimeStep tensors expected by the agent.

  • train_sequence_length: The number of time steps needed in experience tensors passed to train.

    Train requires experience to be a Trajectory containing tensors shaped [B, T, ...]. This argument describes the value of T required.

    For example, for non-RNN DQN training, T=2 because DQN requires single transitions.

    If this value is None, then train can handle an unknown T (it can be determined at runtime from the data). Most RNN-based agents fall into this category.
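The [B, T, ...] requirement can be checked as in the following sketch. `check_time_axis` is a hypothetical helper that only inspects a shape tuple; it is not how TF-Agents performs the check internally.

```python
def check_time_axis(shape, train_sequence_length):
    """Validates an experience tensor shape of the form [B, T, ...] and returns T.

    Hypothetical helper: raises ValueError if train_sequence_length is set and
    the time dimension T does not match it; a None length accepts any T.
    """
    if len(shape) < 2:
        raise ValueError(f"expected at least [B, T], got shape {shape}")
    time_steps = shape[1]
    if train_sequence_length is not None and time_steps != train_sequence_length:
        raise ValueError(f"expected T={train_sequence_length}, got T={time_steps}")
    return time_steps


print(check_time_axis((32, 2, 4), 2))     # 2, e.g. single DQN transitions
print(check_time_axis((32, 7, 4), None))  # 7, T is determined from the data
```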

  • train_step_counter

  • trainable_variables: Sequence of trainable variables owned by this module and its submodules.

  • variables: Sequence of variables owned by this module and its submodules.

Raises:
  • ValueError: If the action spec contains more than one action, or if it is not a bounded scalar int32 spec with minimum 0.
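The action-spec constraints in the ValueError entry above can be illustrated with a plain-Python check. `FakeBoundedSpec` is a hypothetical stand-in for a bounded tensor spec (not the TF-Agents class), `num_actions_from_spec` is a hypothetical helper, and the sketch assumes inclusive bounds, so a spec with minimum 0 and maximum M describes M + 1 actions. It does not cover the nested, multi-action case.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class FakeBoundedSpec:
    # Hypothetical stand-in for a bounded tensor spec; not the TF-Agents class.
    shape: tuple
    dtype: type
    minimum: int
    maximum: int


def num_actions_from_spec(spec):
    """Checks the scalar / int32 / minimum-0 constraints and returns the action count."""
    if spec.shape != ():
        raise ValueError("action spec must be a scalar")
    if np.dtype(spec.dtype) != np.dtype(np.int32):
        raise ValueError("action spec must be int32")
    if spec.minimum != 0:
        raise ValueError("action spec minimum must be 0")
    # Assuming inclusive bounds, actions are 0, 1, ..., maximum.
    return spec.maximum + 1


print(num_actions_from_spec(FakeBoundedSpec((), np.int32, 0, 4)))  # 5
```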

Methods

Initializes the agent.

Returns:
An operation that can be used to initialize the agent.

Raises:
  • RuntimeError: If the class was not initialized properly (super.__init__ was not called).

    observations, actions, rewards, weights=None, training=False

Computes loss for reward prediction training.

Args:
  • observations: A batch of observations.
  • actions: A batch of actions.
  • rewards: A batch of rewards.
  • weights: Optional scalar or elementwise (per-batch-entry) importance weights. The output batch loss will be scaled by these weights, and the final scalar loss is the mean of these values.
  • training: Whether the loss is being used for training.

Returns:
  • loss: A LossInfo containing the loss for the training step.

Raises:
  • ValueError: If the number of actions is greater than 1.
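A NumPy sketch of a weighted reward-prediction loss in the spirit of the description above: only the prediction for the action actually taken is compared with the observed reward, the per-example errors are scaled by `weights`, and the result is averaged. The optional term adds the graph-smoothness penalty r^T L r built from `laplacian_matrix`, scaled by `laplacian_smoothing_weight`. Squared error stands in for `error_loss_fn`, the function name is hypothetical, and the exact reduction the library applies (in particular to the Laplacian term) is an assumption of this sketch.

```python
import numpy as np


def reward_prediction_loss(predicted, actions, rewards, weights=None,
                           laplacian_matrix=None, laplacian_smoothing_weight=0.0):
    """Hypothetical weighted squared-error loss on the rewards of the played actions.

    predicted: [B, K] reward predictions for every action.
    actions:   [B] integer indices of the actions that were played.
    rewards:   [B] observed rewards.
    weights:   None, a scalar, or a [B] array of per-example weights.
    """
    batch = np.arange(predicted.shape[0])
    chosen = predicted[batch, actions]        # prediction for the played action
    per_example = (chosen - rewards) ** 2     # squared error per batch entry
    if weights is not None:
        per_example = per_example * weights   # scalar or elementwise scaling
    loss = per_example.mean()
    if laplacian_matrix is not None:
        # Smoothness r^T L r of each row of predictions, averaged over the batch
        # (the batch averaging is an assumption of this sketch).
        smoothness = np.einsum("bi,ij,bj->b", predicted, laplacian_matrix, predicted)
        loss = loss + laplacian_smoothing_weight * smoothness.mean()
    return loss


predicted = np.array([[1.0, 2.0], [0.5, 0.0]])
actions = np.array([1, 0])
rewards = np.array([1.0, 0.5])
print(reward_prediction_loss(predicted, actions, rewards))  # 0.5
```

For two connected actions, L = [[1, -1], [-1, 1]] and r^T L r = (r_0 - r_1)^2, so the penalty grows as the predicted rewards of neighbouring actions drift apart.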

    experience, weights=None

Trains the agent.

Args:
  • experience: A batch of experience data in the form of a Trajectory. The structure of experience must match that of self.collect_data_spec. All tensors in experience must be shaped [batch, time, ...], where time must be equal to self.train_sequence_length if that property is not None.
  • weights: (optional) A Tensor, either 0-D or shaped [batch], containing weights to be used when calculating the total train loss. Weights are typically multiplied elementwise against the per-batch loss, but the implementation is up to the Agent.

Returns:
A LossInfo loss tuple containing loss and info tensors.

  • In eager mode, the loss values are first calculated, then a train step is performed before they are returned.
  • In graph mode, executing any or all of the loss tensors will first calculate the loss value(s), then perform a train step, and return the pre-train-step LossInfo.

Raises:
  • TypeError: If experience is not of type Trajectory, or if experience does not match the self.collect_data_spec structure types.
  • ValueError: If the time axes of the experience tensors are not compatible with self.train_sequence_length, or if experience does not match the self.collect_data_spec structure.
  • RuntimeError: If the class was not initialized properly (super.__init__ was not called).
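The eager-mode contract above (compute the loss, apply the update, return the pre-update loss) can be illustrated with a toy linear model trained by SGD. This is a hypothetical sketch, not TF-Agents code; the optional norm-based rescaling mirrors the intent of the `gradient_clipping` constructor argument, but the agent's exact clipping method is an assumption here.

```python
import numpy as np


def train_step(params, x, y, learning_rate=0.1, gradient_clipping=None):
    """One SGD step on a linear model; returns (new_params, pre_step_loss)."""
    residual = x @ params - y
    loss = np.mean(residual ** 2)            # loss is computed BEFORE the update
    grad = 2.0 * x.T @ residual / len(y)     # gradient of the mean squared error
    if gradient_clipping is not None:
        norm = np.linalg.norm(grad)
        if norm > gradient_clipping:
            grad = grad * (gradient_clipping / norm)  # rescale to the max norm
    new_params = params - learning_rate * grad
    return new_params, loss


x = np.array([[1.0], [2.0]])
y = np.array([2.0, 4.0])
params, loss = train_step(np.array([0.0]), x, y)
print(loss)  # 10.0, the loss of the parameters before this step
```

Calling `train_step` again reports the loss of the parameters produced by the first call, which matches the "pre-train-step LossInfo" semantics described for `train`.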


    cls, method

Decorator to automatically enter the module name scope.

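```python
import tensorflow as tf


class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      # name='w' gives the variable the name shown in the output below.
      self.w = tf.Variable(tf.random.normal([x.shape[1], 64]), name='w')
    return tf.matmul(x, self.w)
```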

Using the above module would produce tf.Variables and tf.Tensors whose names include the module name:

```python
mod = MyModule()
mod(tf.ones([8, 32]))
# ==> <tf.Tensor: ...>
mod.w
# ==> <tf.Variable ...'my_module/w:0'>
```

Args:
  • method: The method to wrap.

Returns:
The original method wrapped such that it enters the module's name scope.