

Manages all the learning details needed when training an agent.

These include:

  • Using distribution strategies correctly
  • Summaries
  • Checkpoints
  • Minimizing entering/exiting TF context: Especially in the case of TPUs, scheduling a single TPU program to perform multiple train steps is critical for performance.
  • Generalizing the train call so that it runs correctly across CPU, GPU, or TPU executions managed by DistributionStrategies, including a reduce operation over the LossInfo returned by the agent.

root_dir Main directory path where checkpoints, saved_models, and summaries will be written to.
train_step a scalar tf.int64 tf.Variable which will keep track of the number of train steps. It is used to label artifacts such as summaries and other outputs written to the root_dir.
agent tf_agent.TFAgent instance to train with.
experience_dataset_fn a function that will create an instance of a dataset used to sample experience for training. Required for using the Learner as is. Optional for subclass learners that take a new iterator each time run is called.
after_train_strategy_step_fn (Optional) callable of the form fn(sample, loss), which can be used, for example, to update priorities in a replay buffer. Here sample is pulled from the experience_iterator and loss is a LossInfo named tuple returned from the agent. This is called after every train step. It runs using
triggers List of callables of the form trigger(train_step). After every run call every trigger is called with the current train_step value as an np scalar.
checkpoint_interval Number of train steps in between checkpoints. Note these are placed into triggers, so the check for whether to write a checkpoint only occurs after every run call. Set to -1 to disable (this is not recommended, because if the pipeline gets preempted, all previous progress is lost). This only takes care of checkpointing the training process. Policies must be explicitly exported through triggers.
summary_interval Number of train steps in between summaries. Note these are placed into triggers, so the check for whether to write a summary only occurs after every run call.
max_checkpoints_to_keep Maximum number of checkpoints to keep around. These are used to recover from pre-emptions when training.
use_kwargs_in_agent_train If True the experience from the replay buffer is passed into the agent as kwargs. This requires samples from the RB to be of the form dict(experience=experience, kwarg1=kwarg1, ...). This is useful if you have an agent with a custom argspec.
strategy (Optional) tf.distribute.Strategy to use during training.
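
The effect of use_kwargs_in_agent_train can be pictured with a small pure-Python sketch. It is not the library's implementation: train_with_sample and toy_train are hypothetical names introduced only for illustration, and the toy train function stands in for an agent whose train method has a custom argspec.

```python
def train_with_sample(train_fn, sample, use_kwargs_in_agent_train=False):
    """Hypothetical dispatch helper, not part of TF-Agents."""
    if use_kwargs_in_agent_train:
        # The sample must be a dict such as
        # dict(experience=..., kwarg1=...), which is unpacked into kwargs.
        return train_fn(**sample)
    # Otherwise the sample is passed through as the experience itself.
    return train_fn(sample)


def toy_train(experience, weights=None):
    """Stand-in for an agent train call with a custom argspec (assumption)."""
    if weights is None:
        weights = [1.0] * len(experience)
    # Weighted sum of the experience values plays the role of a loss.
    return sum(e * w for e, w in zip(experience, weights))


# Default behaviour: the sample is the experience.
plain_loss = train_with_sample(toy_train, [1.0, 2.0, 3.0])

# Kwargs behaviour: the sample carries extra named arguments.
kwargs_sample = dict(experience=[1.0, 2.0, 3.0], weights=[0.5, 0.5, 0.0])
kwargs_loss = train_with_sample(
    toy_train, kwargs_sample, use_kwargs_in_agent_train=True)
```

In this sketch the replay buffer would have to yield dicts whose keys match the parameter names of the train function, which is the requirement stated for use_kwargs_in_agent_train above.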

train_step_numpy The current train_step.




Runs the requested number of training iterations.

iterations Number of train iterations to perform per call to run. The iterations will be evaluated in a tf.while loop created by autograph. Final aggregated losses will be returned.
iterator The iterator to the dataset to use for training. If not specified, self._experience_iterator is used.

The total loss computed before running the final step.
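
Because triggers are only called after a run call, the interval settings interact with the number of iterations per call. The following pure-Python sketch illustrates that interaction under stated assumptions: ToyLearner and IntervalTrigger are hypothetical stand-ins written for this example, the "loss" is just the sampled value, and returning the last loss is a simplification, not a statement about the real return value.

```python
import itertools


class IntervalTrigger:
    """Hypothetical trigger: fires once at least `interval` steps have passed."""

    def __init__(self, interval, fn):
        self.interval = interval
        self.fn = fn
        self.last_fired = 0

    def __call__(self, train_step):
        if self.interval <= 0:
            return  # A non-positive interval disables the trigger.
        if train_step - self.last_fired >= self.interval:
            self.fn(train_step)
            self.last_fired = train_step


class ToyLearner:
    """Toy stand-in for the run/trigger flow; not the TF-Agents Learner."""

    def __init__(self, default_iterator, triggers):
        self.train_step = 0
        self._default_iterator = default_iterator
        self._triggers = triggers

    def run(self, iterations=1, iterator=None):
        # Fall back to the learner's own iterator when none is passed in.
        if iterator is None:
            iterator = self._default_iterator
        loss = None
        for _ in range(iterations):
            loss = float(next(iterator))  # Stand-in for one train step.
            self.train_step += 1
        # Triggers are checked once per run call, not once per train step.
        for trigger in self._triggers:
            trigger(self.train_step)
        return loss


checkpoint_steps = []
learner = ToyLearner(
    default_iterator=itertools.count(1),
    triggers=[IntervalTrigger(3, checkpoint_steps.append)],
)
for _ in range(4):
    learner.run(iterations=2)
```

With an interval of 3 and 2 iterations per call, the toy trigger fires at steps 4 and 8 rather than 3 and 6, since the check only happens at run boundaries. Choosing an interval that is a multiple of the iterations per call keeps the two aligned.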

