A Behavioral Cloning agent.
Inherits From: TFAgent
tf_agents.agents.BehavioralCloningAgent(
    time_step_spec: tf_agents.trajectories.TimeStep,
    action_spec: tf_agents.typing.types.NestedTensorSpec,
    cloning_network: tf_agents.networks.Network,
    optimizer: tf_agents.typing.types.Optimizer,
    num_outer_dims: Literal[1, 2] = 1,
    epsilon_greedy: tf_agents.typing.types.Float = 0.1,
    loss_fn: Optional[Callable[[types.NestedTensor, bool], types.Tensor]] = None,
    gradient_clipping: Optional[types.Float] = None,
    debug_summaries: bool = False,
    summarize_grads_and_vars: bool = False,
    train_step_counter: Optional[tf.Variable] = None,
    name: Optional[Text] = None
)
Implements a generic form of behavioral cloning that can also be used to run supervised learning through TF-Agents. By default, the agent defines two losses, one for discrete actions and one for continuous actions.
For discrete actions, the agent uses:
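import tensorflow as tf
def discrete_loss(agent, experience, training=False):
  # The cloning network returns (output, next_network_state); keep the logits.
  bc_logits, _ = agent._cloning_network(
      experience.observation, step_type=experience.step_type, training=training)
  # Shift labels so that action_spec.minimum maps to class index 0.
  return tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=experience.action - agent.action_spec.minimum, logits=bc_logits)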
This requires a Network that outputs one logit (Q-value) per discrete action, i.e. num_actions values. For continuous actions, a simple MSE loss is used by default:
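import tensorflow as tf
import tensorflow_probability as tfp
def continuous_loss_fn(agent, experience, training=False):
  # Start the (possibly recurrent) network from its initial state.
  batch_size = tf.shape(experience.step_type)[0]
  network_state = agent._cloning_network.get_initial_state(batch_size)
  bc_output, _ = agent._cloning_network(
      experience.observation,
      step_type=experience.step_type,
      training=training,
      network_state=network_state)
  # Distribution outputs are sampled; plain tensor outputs are used as-is.
  if isinstance(bc_output, tfp.distributions.Distribution):
    bc_action = bc_output.sample()
  else:
    bc_action = bc_output
  # Per-example mean squared error over the action dimension.
  return tf.losses.mse(experience.action, bc_action)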
The actual implementations of these loss functions are slightly more complex in order to support nested action_specs.
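To make the per-example math of the two default losses concrete without a TensorFlow dependency, here is a minimal pure-Python sketch. It assumes flat (non-nested) actions, represents a batch as plain lists, and uses hypothetical helper names (discrete_bc_loss, continuous_bc_loss); it is not the TF-Agents implementation.

```python
import math
from typing import List, Sequence


def discrete_bc_loss(logits: Sequence[Sequence[float]],
                     actions: Sequence[int],
                     action_minimum: int) -> List[float]:
  """Per-example sparse softmax cross-entropy (one logits row per example)."""
  losses = []
  for row, action in zip(logits, actions):
    label = action - action_minimum  # shift so the minimum action is class 0
    # Numerically stable log-sum-exp of the logits.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    losses.append(log_z - row[label])  # -log softmax(row)[label]
  return losses


def continuous_bc_loss(predicted: Sequence[Sequence[float]],
                       actions: Sequence[Sequence[float]]) -> List[float]:
  """Per-example mean squared error over the last (action) dimension."""
  return [sum((a - p) ** 2 for a, p in zip(act, pred)) / len(act)
          for act, pred in zip(actions, predicted)]
```

For example, with two equal logits and the demonstrated action equal to action_spec.minimum, the discrete loss is log(2); a continuous prediction of [1.0, 2.0] against a demonstrated [1.0, 4.0] gives an MSE of 2.0.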
Args | |
---|---|
time_step_spec | A TimeStep spec of the expected time_steps. |
action_spec | A nest of BoundedTensorSpec representing the actions. |
cloning_network | A tf_agents.networks.Network to be used by the agent. The network will be called as network(observation, step_type=step_type, network_state=initial_state) and must return a 2-tuple with elements (output, next_network_state). |
optimizer | The optimizer to use for training. |
num_outer_dims | The number of outer dimensions for the agent. Must be either 1 or 2. If 2, training will require both a batch_size and a time dimension on every Tensor; if 1, training will require only a batch_size outer dimension. |
epsilon_greedy | Probability of choosing a random action in the default epsilon-greedy collect policy (used only if actions are discrete). |
loss_fn | A function for computing the error between the output of the cloning network and the action that was taken. If None, the loss depends on the action dtype (cross-entropy for discrete actions, MSE for continuous actions). The loss_fn is called with parameters (experience, training) and must return a loss value for each element of the batch. |
gradient_clipping | Norm length to clip gradients to. If None, gradients are not clipped. |
debug_summaries | A bool to gather debug summaries. |
summarize_grads_and_vars | If True, gradient and network variable summaries will be written during training. |
train_step_counter | An optional counter to increment every time the train op is run. Defaults to the global_step. |
name | The name of this agent. All variables in this module will fall under that name. Defaults to the class name. |
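The epsilon_greedy argument controls how often the default collect policy explores. A minimal plain-Python sketch of that selection rule follows; it assumes a single discrete action whose greedy choice is the argmax of the cloning network's logits, and epsilon_greedy_action is a hypothetical helper, not a TF-Agents API.

```python
import random
from typing import Optional, Sequence


def epsilon_greedy_action(logits: Sequence[float], epsilon: float,
                          action_minimum: int = 0,
                          rng: Optional[random.Random] = None) -> int:
  """Returns a uniformly random action with probability epsilon, else argmax."""
  if rng is None:
    rng = random.Random()
  num_actions = len(logits)
  if rng.random() < epsilon:
    index = rng.randrange(num_actions)  # explore: uniform over all actions
  else:
    index = max(range(num_actions), key=lambda i: logits[i])  # exploit: argmax
  # Map the 0-based index back into the action spec's [minimum, maximum] range.
  return index + action_minimum
```

With epsilon set to 0.0 the sketch is purely greedy; with 1.0 it ignores the logits entirely. The default of 0.1 sits between the two.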
Methods
initialize
initialize() -> Optional[tf.Operation]
Initializes the agent.
Returns | |
---|---|
An operation that can be used to initialize the agent. |
Raises | |
---|---|
RuntimeError | If the class was not initialized properly (super.__init__ was not called). |
loss
loss(
    experience: tf_agents.typing.types.NestedTensor,
    weights: Optional[types.Tensor] = None,
    training: bool = False,
    **kwargs
) -> tf_agents.agents.tf_agent.LossInfo
Gets loss from the agent.
If the user calls this from _train, it must be inside a tf.GradientTape scope in order to apply gradients to trainable variables.
If intermediate gradient steps are needed, _loss and _train will return different values, since _loss only supports updating all gradients at once after all losses have been calculated.
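The split between loss and train can be sketched with a one-parameter model in plain Python. Here an analytic gradient stands in for tf.GradientTape, and TinyLinearCloner is a hypothetical illustrative class, not part of TF-Agents. The sketch shows that calling loss has no side effects, while train reports the loss computed before its single parameter update.

```python
from typing import List, Tuple


class TinyLinearCloner:
  """Hypothetical 1-D cloner: predicted_action = w * observation."""

  def __init__(self, w: float = 0.0, learning_rate: float = 0.1):
    self.w = w
    self.learning_rate = learning_rate

  def loss(self, batch: List[Tuple[float, float]]) -> float:
    # Mean squared error between predicted and demonstrated actions.
    # Computing the loss does not modify any parameters.
    return sum((self.w * obs - act) ** 2 for obs, act in batch) / len(batch)

  def train(self, batch: List[Tuple[float, float]]) -> float:
    # Compute the loss, then the analytic gradient of that same loss w.r.t. w
    # (the role tf.GradientTape plays in TF-Agents), and apply one update.
    loss = self.loss(batch)
    grad = sum(2 * (self.w * obs - act) * obs for obs, act in batch) / len(batch)
    self.w -= self.learning_rate * grad
    return loss  # the pre-update loss, as in a single gradient step
```

On the batch [(1.0, 2.0), (2.0, 4.0)] starting from w = 0.0, loss returns 10.0 as many times as it is called, train also returns 10.0, and a later call to loss returns 2.5.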
Args | |
---|---|
experience | A batch of experience data in the form of a Trajectory. The structure of experience must match that of self.training_data_spec. All tensors in experience must be shaped [batch, time, ...], where time must be equal to self.train_step_length if that property is not None. |
weights | (Optional) A Tensor, either 0-D or shaped [batch], containing weights to be used when calculating the total train loss. Weights are typically multiplied elementwise against the per-example loss, but the implementation is up to the agent. |
training | Explicit argument to pass to loss. This typically affects network computation paths like dropout and batch normalization. |
**kwargs | Any additional data to pass as arguments to loss. |
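How weights might be combined with per-example losses, as a plain-Python sketch. Normalizing by the batch size (rather than by the sum of the weights) is an assumption of this sketch, since the text above leaves the exact reduction to each agent; aggregate_weighted_loss is a hypothetical helper.

```python
from typing import Optional, Sequence, Union


def aggregate_weighted_loss(
    per_example_loss: Sequence[float],
    weights: Optional[Union[float, Sequence[float]]] = None) -> float:
  """Weighted mean of per-example losses, normalized by the batch size."""
  n = len(per_example_loss)
  if weights is None:
    weights = 1.0
  if isinstance(weights, (int, float)):
    # A 0-D weight is broadcast to every example in the batch.
    weights = [float(weights)] * n
  if len(weights) != n:
    raise ValueError("weights must be a scalar or have shape [batch]")
  return sum(w * l for w, l in zip(weights, per_example_loss)) / n
```

A weight of 0.0 masks an example out of the total while still counting it in the denominator, which is the practical consequence of normalizing by batch size.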
Returns | |
---|---|
A LossInfo loss tuple containing loss and info tensors. |
Raises | |
---|---|
RuntimeError | If the class was not initialized properly (super.__init__ was not called). |
post_process_policy
post_process_policy() -> tf_agents.policies.TFPolicy
Post process policies after training.
The policies of some agents require expensive post processing after training before they can be used. For example, a recommender agent might require rebuilding an index of actions. For such agents, this method returns a post-processed version of the policy. The post processing may either update the existing policies in place or create a new policy, depending on the agent. The default implementation, for agents that do not override this method, returns agent.policy.
Returns | |
---|---|
The post processed policy. |
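The default-versus-override pattern described above, sketched with hypothetical plain-Python classes (BaseAgent, IndexedRecommenderAgent) rather than real TF-Agents types. A policy is stood in for by any value or simple callable.

```python
from typing import Dict, List


class BaseAgent:
  """Hypothetical stand-in for an agent that exposes a policy attribute."""

  def __init__(self, policy):
    self.policy = policy

  def post_process_policy(self):
    # Default: no post processing; return the existing policy unchanged.
    return self.policy


class IndexedRecommenderAgent(BaseAgent):
  """Hypothetical agent whose policy needs a rebuilt action index."""

  def __init__(self, item_scores: Dict[str, float]):
    super().__init__(policy=None)
    self.item_scores = item_scores

  def post_process_policy(self):
    # Rebuild the ranking once, after training, and return a new policy
    # that simply recommends the top-ranked item.
    ranking: List[str] = sorted(
        self.item_scores, key=self.item_scores.get, reverse=True)
    return lambda: ranking[0]
```

The expensive step (sorting here) runs once per call to post_process_policy instead of on every action, which is the motivation the paragraph above gives.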
preprocess_sequence
preprocess_sequence(
    experience: tf_agents.typing.types.NestedTensor
) -> tf_agents.typing.types.NestedTensor
Defines the preprocess_sequence function to be fed into replay buffers.
This defines how the collected data is preprocessed before training; for most agents it defaults to a pass-through. The structure of experience must match that of self.collect_data_spec.
Args | |
---|---|
experience | A Trajectory shaped [batch, time, ...] or [time, ...] which represents the collected experience data. |
Returns | |
---|---|
A post-processed Trajectory with the same shape as the input. |
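A plain-Python sketch of the pass-through default and one possible override. Trajectories are modeled as hypothetical dicts mapping field names to per-step lists, and reward clipping is only an illustrative choice of preprocessing; neither function is a TF-Agents API.

```python
from typing import Dict, List

# Hypothetical stand-in for a Trajectory: field name -> per-step values.
Trajectory = Dict[str, List[float]]


def preprocess_sequence_default(experience: Trajectory) -> Trajectory:
  # Default behavior: pass the experience through unchanged.
  return experience


def preprocess_sequence_clip_rewards(experience: Trajectory,
                                     limit: float = 1.0) -> Trajectory:
  # Illustrative override: clip rewards elementwise into [-limit, limit].
  # A shallow copy leaves the input untouched, and every field keeps its length.
  processed = dict(experience)
  processed["reward"] = [max(-limit, min(limit, r)) for r in experience["reward"]]
  return processed
```

Both functions return data with the same structure and per-field lengths as their input, matching the "same shape as the input" requirement above.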
train
train(
    experience: tf_agents.typing.types.NestedTensor,
    weights: Optional[types.Tensor] = None,
    **kwargs
) -> tf_agents.agents.tf_agent.LossInfo
Trains the agent.
Args | |
---|---|
experience | A batch of experience data in the form of a Trajectory. The structure of experience must match that of self.training_data_spec. All tensors in experience must be shaped [batch, time, ...], where time must be equal to self.train_step_length if that property is not None. |
weights | (Optional) A Tensor, either 0-D or shaped [batch], containing weights to be used when calculating the total train loss. Weights are typically multiplied elementwise against the per-example loss, but the implementation is up to the agent. |
**kwargs | Any additional data to pass to the subclass. |
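A small sketch of the shape requirement on experience, written as a hypothetical validation helper over plain shape tuples. It assumes num_outer_dims has the meaning given in the constructor arguments (1 for [batch, ...], 2 for [batch, time, ...]) and checks the time length only when a train_step_length is set.

```python
from typing import Optional, Sequence


def validate_experience_shape(shape: Sequence[int],
                              num_outer_dims: int = 2,
                              train_step_length: Optional[int] = None) -> None:
  """Raises ValueError if shape lacks the outer dims or the time length is off."""
  if num_outer_dims not in (1, 2):
    raise ValueError("num_outer_dims must be 1 or 2")
  if len(shape) < num_outer_dims:
    raise ValueError(
        f"expected at least {num_outer_dims} outer dims, got shape {tuple(shape)}")
  # Only a [batch, time, ...] layout has a time dimension to compare.
  if (num_outer_dims == 2 and train_step_length is not None
      and shape[1] != train_step_length):
    raise ValueError(
        f"time dimension {shape[1]} != train_step_length {train_step_length}")
```

For instance, a [32, 2, 4] batch passes with train_step_length=2, while a [32, 3, 4] batch is rejected.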
Returns | |
---|---|
A LossInfo loss tuple containing loss and info tensors. |
Raises | |
---|---|
RuntimeError | If the class was not initialized properly (super.__init__ was not called). |