This agent receives a neural network that it trains to predict rewards. The
action is chosen greedily with respect to the prediction.
Args:
  time_step_spec: A TimeStep spec of the expected time_steps.
  action_spec: A nest of BoundedTensorSpec representing the actions.
  reward_network: A tf_agents.network.Network to be used by the agent. The
    network will be called with call(observation, step_type) and it is
    expected to provide a reward prediction for all actions.
  optimizer: The optimizer to use for training.
  observation_and_action_constraint_splitter: A function used for masking
    valid/invalid actions with each state of the environment. The function
    takes in a full observation and returns a tuple consisting of 1) the
    part of the observation intended as input to the bandit agent and
    policy, and 2) the boolean mask. This function should also work with a
    TensorSpec as input, and should output TensorSpec objects for the
    observation and mask.
  error_loss_fn: A function for computing the error loss, taking parameters
    labels, predictions, and weights (any function from tf.losses would
    work). The default is tf.losses.mean_squared_error.
  gradient_clipping: A float representing the norm length to clip gradients
    (or None for no clipping).
  debug_summaries: A Python bool, default False. When True, debug summaries
    are gathered.
  summarize_grads_and_vars: A Python bool, default False. When True,
    gradients and network variable summaries are written during training.
  enable_summaries: A Python bool, default True. When False, no summaries
    (debug or otherwise) are written.
  emit_policy_info: (tuple of strings) What side information we want to get
    as part of the policy info. Allowed values can be found in
    policy_utilities.PolicyInfo.
  train_step_counter: An optional tf.Variable to increment every time the
    train op is run. Defaults to the global_step.
  laplacian_matrix: A float Tensor or a numpy array shaped
    [num_actions, num_actions]. This holds the Laplacian matrix used to
    regularize the smoothness of the estimated expected reward function.
    This only applies to problems where the actions have a graph structure.
    If None, the regularization is not applied.
  laplacian_smoothing_weight: A float that determines the weight of the
    regularization term. Note that this has no effect if laplacian_matrix
    above is None.
  name: Python str name of this agent. All variables in this module will
    fall under that name. Defaults to the class name.
Raises:
  ValueError: If the action spec contains more than one action, or if it is
    not a bounded scalar int32 spec with minimum 0.
  InvalidArgumentError: If the Laplacian provided is not None and not valid.
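As an illustrative sketch of the observation_and_action_constraint_splitter contract: for an environment whose full observation is a dict carrying the action mask alongside the bandit observation (the dict keys here are hypothetical, not required by the API), the splitter can simply index into the dict:

```python
def observation_and_action_constraint_splitter(obs):
    # Split a full observation into (bandit_observation, action_mask).
    # Because this only indexes into the dict, the same function also
    # works when `obs` is a nest of TensorSpec objects: it then returns
    # the observation spec and the mask spec, as the API requires.
    return obs['observation'], obs['mask']

# Example with a concrete observation (plain Python stand-ins for tensors):
full_obs = {'observation': [0.1, 0.5], 'mask': [1, 0, 1]}
bandit_obs, mask = observation_and_action_constraint_splitter(full_obs)
```

Only the split observation reaches the reward network; the mask is used by the policy to restrict the greedy argmax to valid actions.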
action_spec: TensorSpec describing the action produced by the agent.
collect_data_spec: Returns a Trajectory spec, as expected by the collect_policy.
collect_policy: Return a policy that can be used to collect data from the environment.
name: Returns the name of this module as passed or determined in the ctor.
train(experience, weights=None, **kwargs): Trains the agent.

Args:
  experience: A batch of experience data in the form of a Trajectory. The
    structure of experience must match that of self.collect_data_spec.
    All tensors in experience must be shaped [batch, time, ...] where
    time must be equal to self.train_sequence_length if that
    property is not None.
  weights: (optional) A Tensor, either 0-D or shaped [batch],
    containing weights to be used when calculating the total train loss.
    Weights are typically multiplied elementwise against the per-batch loss,
    but the implementation is up to the Agent.
  **kwargs: Any additional data as declared by self.train_argspec.

Returns:
  A LossInfo loss tuple containing loss and info tensors.
  In eager mode, the loss values are first calculated, then a train step
  is performed before they are returned.
  In graph mode, executing any or all of the loss tensors
  will first calculate the loss value(s), then perform a train step,
  and return the pre-train-step LossInfo.

Raises:
  TypeError: If experience is not type Trajectory, or if experience
    does not match self.collect_data_spec structure types.
  ValueError: If experience tensors' time axes are not compatible with
    self.train_sequence_length, or if experience does not match
    self.collect_data_spec structure.
  KeyError: If the user does not pass **kwargs matching
    self.train_argspec.
  RuntimeError: If the class was not initialized properly (super.__init__
    was not called).
Decorator to automatically enter the module name scope.
    class MyModule(tf.Module):
      @tf.Module.with_name_scope
      def __call__(self, x):
        if not hasattr(self, 'w'):
          self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
        return tf.matmul(x, self.w)
Using the above module would produce tf.Variables and tf.Tensors whose
names included the module name:
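A runnable sketch of that behavior (the class definition is repeated so the snippet stands alone; the default name scope derived from MyModule is 'my_module'):

```python
import tensorflow as tf

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      # Created inside the decorator's name scope, so the variable
      # name is prefixed with the module name.
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

mod = MyModule()
out = mod(tf.ones([1, 2]))   # out has shape (1, 3)
print(mod.w.name)            # e.g. 'my_module/Variable:0'
```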