A REINFORCE Agent.
Inherits From: TFAgent
tf_agents.agents.ReinforceAgent(
    time_step_spec: tf_agents.trajectories.time_step.TimeStep,
    action_spec: types.TensorSpec,
    actor_network: tf_agents.networks.network.Network,
    optimizer: tf_agents.typing.types.Optimizer,
    value_network: Optional[tf_agents.networks.network.Network] = None,
    value_estimation_loss_coef: tf_agents.typing.types.Float = 0.2,
    advantage_fn: Optional[tf_agents.typing.types.LossFn] = None,
    use_advantage_loss: bool = True,
    gamma: tf_agents.typing.types.Float = 1.0,
    normalize_returns: bool = True,
    gradient_clipping: Optional[types.Float] = None,
    debug_summaries: bool = False,
    summarize_grads_and_vars: bool = False,
    entropy_regularization: Optional[types.Float] = None,
    train_step_counter: Optional[tf.Variable] = None,
    name: Optional[Text] = None
)
Implements:
- The REINFORCE algorithm from "Simple statistical gradient-following algorithms for connectionist reinforcement learning", Williams, R.J., 1992. http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf
- REINFORCE with a state-value baseline, where state values are estimated with function approximation, from "Reinforcement learning: An introduction" (Sec. 13.4), Sutton, R.S. and Barto, A.G., 2018. http://incompleteideas.net/book/the-book-2nd.html
The REINFORCE agent can optionally be provided with:
- `value_network`: A `tf_agents.networks.network.Network` that parameterizes state-value estimation as a neural network. The network is called with `call(observation, step_type)` and returns a floating-point tensor of state values.
- `value_estimation_loss_coef`: Weight on the value prediction loss.
If `value_network` and `value_estimation_loss_coef` are provided, advantages are computed as

advantages = (discounted accumulated rewards) - (estimated state-values)

and the overall learning objective becomes:

(total loss) = (policy gradient loss) + value_estimation_loss_coef * (squared error of estimated state-values)
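The following is a minimal pure-Python sketch of the two formulas above for a single episode, with plain lists standing in for tensors. The helper names (`discounted_returns`, `advantages`, `combined_loss`) are hypothetical and not part of TF-Agents. The discounted return uses `gamma` as in the constructor; summing the squared error over timesteps is an assumption of this sketch, since the reduction is not specified here.

```python
# Hypothetical sketch (not TF-Agents code) of the formulas above for one
# episode; plain lists stand in for tensors.
from typing import List


def discounted_returns(rewards: List[float], gamma: float = 1.0) -> List[float]:
    """G_t = r_t + gamma * G_{t+1}, accumulated backwards to the episode end."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def advantages(returns: List[float], values: List[float]) -> List[float]:
    """(discounted accumulated rewards) - (estimated state-values)."""
    return [g - v for g, v in zip(returns, values)]


def combined_loss(pg_loss: float, returns: List[float], values: List[float],
                  value_estimation_loss_coef: float = 0.2) -> float:
    """pg loss + coef * squared error (summed over timesteps: an assumption)."""
    squared_error = sum((g - v) ** 2 for g, v in zip(returns, values))
    return pg_loss + value_estimation_loss_coef * squared_error


rewards = [1.0, 0.0, 2.0]
values = [2.5, 1.5, 1.0]  # hypothetical value-network outputs
g = discounted_returns(rewards, gamma=0.5)
print(g)                      # [1.5, 1.0, 2.0]
print(advantages(g, values))  # [-1.0, -0.5, 1.0]
```

Accumulating the return backwards lets each step reuse the return of the step after it, so the whole episode is processed in a single pass.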

| Args | |
|---|---|
| `time_step_spec` | A `TimeStep` spec of the expected time_steps. |
| `action_spec` | A nest of `BoundedTensorSpec` representing the actions. |
| `actor_network` | A `tf_agents.networks.network.Network` to be used by the agent. The network will be called with `call(observation, step_type)`. |
| `optimizer` | Optimizer for the actor network. |
| `value_network` | (Optional) A `tf_agents.networks.network.Network` to be used by the agent. The network will be called with `call(observation, step_type)` and returns a floating-point value tensor. |
| `value_estimation_loss_coef` | (Optional) Multiplier for the value prediction loss, to balance it against the policy gradient loss. |
| `advantage_fn` | A function `A(returns, value_preds)` that takes returns and value function predictions as input and returns advantages. The default is `A(returns, value_preds) = returns - value_preds` if a value network is specified and `use_advantage_loss=True`; otherwise `A(returns, value_preds) = returns`. |
| `use_advantage_loss` | Whether to use value function predictions for computing advantages. `use_advantage_loss=False` is equivalent to setting `advantage_fn=lambda returns, value_preds: returns`. |
| `gamma` | A discount factor for future rewards. |
| `normalize_returns` | Whether to normalize returns across episodes when computing the loss. |
| `gradient_clipping` | Norm length to clip gradients to. |
| `debug_summaries` | A bool to gather debug summaries. |
| `summarize_grads_and_vars` | If True, gradient and network variable summaries will be written during training. |
| `entropy_regularization` | Coefficient for the entropy regularization loss term. |
| `train_step_counter` | An optional counter to increment every time the train op is run. Defaults to the `global_step`. |
| `name` | The name of this agent. All variables in this module will fall under that name. Defaults to the class name. |
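
As a rough illustration of how `advantage_fn`, `use_advantage_loss`, and `normalize_returns` relate, here is a sketch with lists in place of tensors. `resolve_advantage_fn` and `standard_normalize` are hypothetical names. Treating "normalize" as zero-mean, unit-variance scaling with a small epsilon is an assumption, not something stated in the table above.

```python
# Hypothetical sketch (not TF-Agents code): lists stand in for tensors.
import math
from typing import Callable, List, Optional

AdvantageFn = Callable[[List[float], List[float]], List[float]]


def resolve_advantage_fn(advantage_fn: Optional[AdvantageFn],
                         has_value_network: bool,
                         use_advantage_loss: bool) -> AdvantageFn:
    """Picks A(returns, value_preds) following the documented defaults."""
    if advantage_fn is not None:
        return advantage_fn  # a user-supplied function is used as-is
    if has_value_network and use_advantage_loss:
        # Default with a baseline: returns - value_preds.
        return lambda returns, value_preds: [
            r - v for r, v in zip(returns, value_preds)]
    # Otherwise the "advantage" is just the return.
    return lambda returns, value_preds: list(returns)


def standard_normalize(values: List[float], eps: float = 1e-8) -> List[float]:
    """Zero-mean, unit-variance scaling (assumed meaning of normalize_returns)."""
    mean = sum(values) / len(values)
    std = math.sqrt(sum((x - mean) ** 2 for x in values) / len(values))
    return [(x - mean) / (std + eps) for x in values]


returns = [1.0, 2.0, 3.0]
value_preds = [0.5, 0.5, 0.5]
print(resolve_advantage_fn(None, True, True)(returns, value_preds))   # [0.5, 1.5, 2.5]
print(resolve_advantage_fn(None, True, False)(returns, value_preds))  # [1.0, 2.0, 3.0]
```

The second call shows the documented equivalence: with `use_advantage_loss=False`, the value predictions are ignored and the returns pass through unchanged.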

| Attributes | |
|---|---|
| `action_spec` | `TensorSpec` describing the action produced by the agent. |
| `collect_data_spec` | Returns a `Trajectory` spec, as expected by the `collect_policy`. |
| `collect_policy` | Return a policy that can be used to collect data from the environment. |
| `data_context` | |
| `debug_summaries` | Whether debug summaries are gathered. |
| `policy` | Return the current policy held by the agent. |
| `summaries_enabled` | |
| `summarize_grads_and_vars` | Whether gradient and network variable summaries are written during training. |
| `time_step_spec` | Describes the `TimeStep` tensors expected by the agent. |
| `train_argspec` | `TensorSpec` describing extra supported kwargs to `train()`. |
| `train_sequence_length` | The number of time steps needed in experience tensors passed to `train`. `train` requires `experience` to be a `Trajectory` whose tensors are shaped `[batch, time, ...]`; this value is the required `time`. If it is `None`, `train` accepts any number of time steps. |
| `train_step_counter` | The counter incremented every time the train op is run. |
| `training_data_spec` | Returns a trajectory spec, as expected by the `train()` function. |
| `validate_args` | Whether `train` and `preprocess_sequence` validate input and output args. |

Methods
entropy_regularization_loss
entropy_regularization_loss(
    actions_distribution: tf_agents.typing.types.NestedDistribution,
    weights: Optional[types.Tensor] = None
) -> tf_agents.typing.types.Tensor
Computes the optional entropy regularization loss.
Extending REINFORCE by entropy regularization was originally proposed in "Function optimization using connectionist reinforcement learning algorithms" (Williams and Peng, 1991).

| Args | |
|---|---|
| `actions_distribution` | A possibly batched tuple of action distributions. |
| `weights` | Optional scalar or element-wise (per-batch-entry) importance weights. May include a mask for invalid timesteps. |

| Returns | |
|---|---|
| `entropy_regularization_loss` | A tensor with the entropy regularization loss. |
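
The sketch below shows one common form of entropy regularization for a batch of categorical action distributions: the loss is the negated, coefficient-scaled mean entropy, so minimizing it pushes the policy toward more exploratory behavior. The categorical distributions, the mean reduction, and the function names are assumptions for illustration; the method above accepts arbitrary, possibly nested, distributions.

```python
# Hypothetical sketch (not TF-Agents code) of an entropy regularization loss
# for a batch of categorical action distributions given as probability lists.
import math
from typing import List, Optional


def categorical_entropy(probs: List[float]) -> float:
    """Shannon entropy in nats; zero-probability actions contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def entropy_regularization_loss(batch_probs: List[List[float]],
                                entropy_regularization: float,
                                weights: Optional[List[float]] = None) -> float:
    """-coef * mean(weights * entropy); the mean reduction is an assumption."""
    entropies = [categorical_entropy(p) for p in batch_probs]
    if weights is not None:
        # Per-entry weights; a weight of 0 masks out an invalid timestep.
        entropies = [w * h for w, h in zip(weights, entropies)]
    # Negated so that minimizing the loss increases policy entropy.
    return -entropy_regularization * sum(entropies) / len(entropies)


batch = [[0.5, 0.5], [1.0, 0.0]]  # one uniform, one deterministic policy
print(entropy_regularization_loss(batch, entropy_regularization=0.1))
```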
initialize
initialize() -> Optional[tf.Operation]
Initializes the agent.

| Returns |
|---|
| An operation that can be used to initialize the agent. |

| Raises | |
|---|---|
| `RuntimeError` | If the class was not initialized properly (`super().__init__` was not called). |
policy_gradient_loss
policy_gradient_loss(
    actions_distribution: tf_agents.typing.types.NestedDistribution,
    actions: tf_agents.typing.types.NestedTensor,
    is_boundary: tf_agents.typing.types.Tensor,
    returns: tf_agents.typing.types.Tensor,
    num_episodes: tf_agents.typing.types.Int,
    weights: Optional[types.Tensor] = None
) -> tf_agents.typing.types.Tensor
Computes the policy gradient loss.

| Args | |
|---|---|
| `actions_distribution` | A possibly batched tuple of action distributions. |
| `actions` | Tensor with a batch of actions. |
| `is_boundary` | Tensor of booleans indicating whether the corresponding action was in a boundary trajectory and should be ignored. |
| `returns` | Tensor with a return from each timestep, aligned on index. Works better when returns are normalized. |
| `num_episodes` | Number of episodes contained in the training data. |
| `weights` | Optional scalar or element-wise (per-batch-entry) importance weights. May include a mask for invalid timesteps. |

| Returns | |
|---|---|
| `policy_gradient_loss` | A tensor containing the policy gradient loss for the on-policy experience. |
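
For intuition, the REINFORCE policy gradient loss can be sketched as the negative sum of `log pi(a_t | s_t) * G_t` over non-boundary steps, divided by `num_episodes`. This sketch takes precomputed action log-probabilities instead of `actions_distribution` and `actions`, and the exact reduction is an assumption; the helper is hypothetical and not the TF-Agents implementation.

```python
# Hypothetical sketch (not TF-Agents code) of a REINFORCE policy gradient loss.
# Precomputed log-probabilities stand in for actions_distribution/actions.
import math
from typing import List, Optional


def policy_gradient_loss(action_log_probs: List[float],
                         returns: List[float],
                         is_boundary: List[bool],
                         num_episodes: int,
                         weights: Optional[List[float]] = None) -> float:
    """-sum_t w_t * log pi(a_t | s_t) * G_t / num_episodes, skipping boundaries."""
    total = 0.0
    for t, (log_prob, g, boundary) in enumerate(
            zip(action_log_probs, returns, is_boundary)):
        if boundary:
            continue  # boundary steps are ignored, as documented above
        w = 1.0 if weights is None else weights[t]
        total += w * log_prob * g
    # Negated so gradient descent on the loss is gradient ascent on return.
    return -total / num_episodes


log_probs = [math.log(0.5), math.log(0.25), math.log(0.9)]
returns = [2.0, 1.0, 5.0]
is_boundary = [False, False, True]  # the last step is a boundary step
print(policy_gradient_loss(log_probs, returns, is_boundary, num_episodes=1))
```

Dividing by `num_episodes` keeps the loss scale comparable between batches that contain different numbers of episodes.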
preprocess_sequence
preprocess_sequence(
    experience: tf_agents.typing.types.NestedTensor
) -> tf_agents.typing.types.NestedTensor
Defines the preprocess_sequence function to be fed into replay buffers.
This defines how we preprocess the collected data before training. Defaults to pass-through for most agents. The structure of `experience` must match that of `self.collect_data_spec`.

| Args | |
|---|---|
| `experience` | A `Trajectory` shaped `[batch, time, ...]` or `[time, ...]` which represents the collected experience data. |

| Returns |
|---|
| A post-processed `Trajectory` with the same shape as the input. |

| Raises | |
|---|---|
| `TypeError` | If `experience` does not match the `self.collect_data_spec` structure types. |
total_loss
total_loss(
    experience: tf_agents.trajectories.trajectory.Trajectory,
    returns: tf_agents.typing.types.Tensor,
    weights: tf_agents.typing.types.Tensor,
    training: bool = False
) -> tf_agents.agents.tf_agent.LossInfo
train
train(
    experience: tf_agents.typing.types.NestedTensor,
    weights: Optional[types.Tensor] = None,
    **kwargs
) -> tf_agents.agents.tf_agent.LossInfo
Trains the agent.

| Args | |
|---|---|
| `experience` | A batch of experience data in the form of a `Trajectory`. The structure of `experience` must match that of `self.training_data_spec`. All tensors in `experience` must be shaped `[batch, time, ...]` where `time` must be equal to `self.train_sequence_length` if that property is not `None`. |
| `weights` | (Optional) A `Tensor`, either 0-D or shaped `[batch]`, containing weights to be used when calculating the total train loss. Weights are typically multiplied elementwise against the per-batch loss, but the implementation is up to the Agent. |
| `**kwargs` | Any additional data as declared by `self.train_argspec`. |

| Returns |
|---|
| A `LossInfo` loss tuple containing loss and info tensors. |

| Raises | |
|---|---|
| `TypeError` | If `validate_args` is `True` and: `experience` is not of type `Trajectory`; or `experience` does not match the `self.training_data_spec` structure types. |
| `ValueError` | If `validate_args` is `True` and: `experience` tensors' time axes are not compatible with `self.train_sequence_length`; or `experience` does not match the `self.training_data_spec` structure. |
| `ValueError` | If `validate_args` is `True` and the user does not pass `**kwargs` matching `self.train_argspec`. |
| `RuntimeError` | If the class was not initialized properly (`super().__init__` was not called). |
value_estimation_loss
value_estimation_loss(
    value_preds: tf_agents.typing.types.Tensor,
    returns: tf_agents.typing.types.Tensor,
    num_episodes: tf_agents.typing.types.Int,
    weights: Optional[types.Tensor] = None
) -> tf_agents.typing.types.Tensor
Computes the value estimation loss.

| Args | |
|---|---|
| `value_preds` | Per-timestep estimated values. |
| `returns` | Per-timestep returns for the value function to predict. |
| `num_episodes` | Number of episodes contained in the training data. |
| `weights` | Optional scalar or element-wise (per-batch-entry) importance weights. May include a mask for invalid timesteps. |

| Returns | |
|---|---|
| `value_estimation_loss` | A scalar value estimation loss. |
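
A matching sketch of the value estimation loss follows. It assumes the loss is the (optionally weighted) squared error between `returns` and `value_preds`, summed over timesteps and divided by `num_episodes`. The `value_estimation_loss_coef` multiplier from the constructor is deliberately left out, since where it is applied is not specified in this section. The function is hypothetical and not the TF-Agents implementation.

```python
# Hypothetical sketch (not TF-Agents code) of the value estimation loss.
from typing import List, Optional


def value_estimation_loss(value_preds: List[float],
                          returns: List[float],
                          num_episodes: int,
                          weights: Optional[List[float]] = None) -> float:
    """Sum of (optionally weighted) squared errors, divided by num_episodes."""
    total = 0.0
    for t, (v, g) in enumerate(zip(value_preds, returns)):
        w = 1.0 if weights is None else weights[t]  # 0 masks a timestep out
        total += w * (g - v) ** 2
    return total / num_episodes


print(value_estimation_loss([2.5, 1.5, 1.0], [1.5, 1.0, 2.0], num_episodes=1))  # 2.25
```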