time_step_spec: A nest of tf.TypeSpec representing the time_steps. Provided by the user.
action_spec: A nest of BoundedTensorSpec representing the actions. Provided by the user.
policy: An instance of tf_policy.Base representing the Agent's current policy.
collect_policy: An instance of tf_policy.Base representing the Agent's
current data collection policy (used to set self.step_spec).
train_sequence_length: A python integer or None, signifying the number
of time steps required from tensors in experience as passed to
train(). All tensors in experience will be shaped [B, T, ...], but
for certain agents T must be fixed. For example, DQN requires
transitions in the form of 2 time steps, so for a non-RNN DQN agent, set
this value to 2. For agents that don't care, or that can handle an
unknown T at graph build time (i.e., most RNN-based agents), set this
argument to None.
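As an illustrative sketch, the constraint on the time dimension reduces to a plain shape check. The helper below is hypothetical (not part of TF-Agents) and assumes experience tensors shaped [B, T, ...]:

```python
def check_train_sequence_length(shape, train_sequence_length):
    """Hypothetical check that a [B, T, ...] shape matches the required T.

    A train_sequence_length of None means any T is accepted (e.g. most
    RNN-based agents); otherwise the second axis must equal it exactly.
    """
    if train_sequence_length is None:
        return True
    return len(shape) >= 2 and shape[1] == train_sequence_length

# A non-RNN DQN agent needs transitions of exactly 2 time steps:
assert check_train_sequence_length((32, 2, 84, 84), 2)
assert not check_train_sequence_length((32, 3, 84, 84), 2)
```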
num_outer_dims: The number of outer dimensions for the agent. Must be
either 1 or 2. If 2, training will require both a batch_size and a time
dimension on every Tensor; if 1, training will require only a batch_size
outer dimension.
train_argspec: (Optional) Describes additional supported arguments
to the train() call. This must be a dict mapping strings to nests
of specs. Overriding the experience arg is also supported.
Some algorithms require additional arguments to the train() call, and
while TF-Agents encourages most of these to be provided in the
policy_info / info field of experience, sometimes the extra
information doesn't fit well, i.e., when it doesn't come from the
policy.
debug_summaries: A bool; if true, subclasses should gather debug
summaries.
summarize_grads_and_vars: A bool; if true, subclasses should additionally
collect gradient and variable summaries.
enable_summaries: A bool; if false, subclasses should not gather any
summaries (debug or otherwise); subclasses should gate all summaries
using either summaries_enabled, debug_summaries, or
summarize_grads_and_vars properties.
train_step_counter: An optional counter to increment every time the train
op is run. Defaults to the global_step.
TypeError: If train_argspec is not a dict.
ValueError: If train_argspec has the keys experience or weights.
TypeError: If any leaf nodes in train_argspec values are not
subclasses of tf.TypeSpec.
TypeError: If time_step_spec is not an instance of ts.TimeStep.
ValueError: If num_outer_dims is not in [1, 2].
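These constructor-time checks can be sketched in plain Python. The helper name validate_ctor_args is hypothetical; the real checks also verify that spec leaves are tf.TypeSpec subclasses and that time_step_spec is a ts.TimeStep, elided here to keep the sketch library-free:

```python
def validate_ctor_args(train_argspec, num_outer_dims):
    """Hypothetical sketch of the documented constructor validations."""
    if train_argspec is not None:
        if not isinstance(train_argspec, dict):
            raise TypeError(
                "train_argspec must be a dict mapping strings to spec nests")
        # experience and weights are reserved train() arguments.
        overridden = {"experience", "weights"} & set(train_argspec)
        if overridden:
            raise ValueError(
                "train_argspec may not contain keys: %s" % sorted(overridden))
    if num_outer_dims not in (1, 2):
        raise ValueError("num_outer_dims must be 1 or 2, got %r" % num_outer_dims)
```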
action_spec: TensorSpec describing the action produced by the agent.
collect_data_spec: Returns a Trajectory spec, as expected by the collect_policy.
collect_policy: Return a policy that can be used to collect data from the environment.
name: Returns the name of this module as passed or determined in the ctor.
experience: A batch of experience data in the form of a Trajectory. The
structure of experience must match that of self.collect_data_spec.
All tensors in experience must be shaped [batch, time, ...], where
time must be equal to self.train_sequence_length if that
property is not None.
weights: (Optional) A Tensor, either 0-D or shaped [batch],
containing weights to be used when calculating the total train loss.
Weights are typically multiplied elementwise against the per-batch loss,
but the implementation is up to the Agent.
**kwargs: Any additional data as declared by self.train_argspec.
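One common interpretation of weights, elementwise multiplication followed by a mean reduction, can be sketched as follows. This helper is hypothetical, and the reduction choice is an assumption; as noted above, the actual behavior is up to the Agent:

```python
def weighted_mean_loss(per_example_loss, weights=None):
    """Hypothetical sketch: weights scale each example's loss before reduction.

    A mean reduction is assumed here; the real Agent may reduce differently.
    """
    if weights is None:
        weights = [1.0] * len(per_example_loss)
    scaled = [l * w for l, w in zip(per_example_loss, weights)]
    return sum(scaled) / len(scaled)
```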
A LossInfo loss tuple containing loss and info tensors.
In eager mode, the loss values are first calculated, then a train step
is performed before they are returned.
In graph mode, executing any or all of the loss tensors
will first calculate the loss value(s), then perform a train step,
and return the pre-train-step LossInfo.
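For reference, TF-Agents defines LossInfo as a namedtuple of (loss, extra); the dependency-free stand-in below illustrates the shape of the returned value (the field contents here are made up for illustration):

```python
import collections

# Mirrors the (loss, extra) field layout of TF-Agents' LossInfo namedtuple.
LossInfo = collections.namedtuple("LossInfo", ("loss", "extra"))

# Example instance with illustrative values:
info = LossInfo(loss=0.42, extra={"td_error": [0.1, 0.2]})
```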
TypeError: If experience is not type Trajectory, or if experience
does not match self.collect_data_spec structure types.
ValueError: If experience tensors' time axes are not compatible with
self.train_sequence_length, or if experience does not match
self.collect_data_spec structure.
ValueError: If the user does not pass **kwargs matching
self.train_argspec.
RuntimeError: If the class was not initialized properly (super().__init__
was not called).
with_name_scope: Decorator to automatically enter the module name scope.
  class MyModule(tf.Module):

    @tf.Module.with_name_scope
    def __call__(self, x):
      if not hasattr(self, 'w'):
        self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
      return tf.matmul(x, self.w)
Using the above module would produce tf.Variables and tf.Tensors whose
names included the module name:
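For example (the exact variable name depends on the TensorFlow version; this sketch assumes the default snake_case module name 'my_module' derived from the class name):

```python
import tensorflow as tf

class MyModule(tf.Module):

  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

mod = MyModule()
mod(tf.ones([1, 2]))
# mod.w is created under the module's name scope,
# e.g. 'my_module/Variable:0', with shape (2, 3).
```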