tf_agents.train.Learner

Manages all the learning details needed when training an agent.

These include:

  • Using distribution strategies correctly
  • Summaries
  • Checkpoints
  • Minimizing entering/exiting the TF context: especially on TPUs, scheduling a single TPU program that performs multiple train steps is critical for performance.
  • Generalizing the train call so it runs correctly across CPU, GPU, or TPU executions managed by DistributionStrategies. This uses strategy.run and then applies a reduce operation over the LossInfo returned by the agent.
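The reduce step described above can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, assuming each replica returns its own LossInfo-like named tuple; the `LossInfo` tuple defined here (with `extra` as a plain dict) and the `reduce_loss_info` helper are hypothetical stand-ins, not the tf_agents types. It reduces every field across replicas with either a sum or a mean, mirroring the two common `tf.distribute.ReduceOp` choices.

```python
from collections import namedtuple

# Hypothetical stand-in for the agent's LossInfo: a total loss plus extra terms.
# (Assumption: `extra` is a plain dict here to keep the sketch simple.)
LossInfo = namedtuple("LossInfo", ["loss", "extra"])


def reduce_loss_info(per_replica, op="mean"):
    """Reduces a list of per-replica LossInfo tuples field by field.

    Mirrors, in plain Python, applying a cross-replica SUM or MEAN reduce
    to every leaf of the LossInfo structure.
    """
    if op not in ("sum", "mean"):
        raise ValueError(f"unsupported op: {op}")
    n = len(per_replica)

    def combine(values):
        total = sum(values)
        return total / n if op == "mean" else total

    loss = combine([info.loss for info in per_replica])
    # Reduce each named extra term across replicas as well.
    extra = {
        key: combine([info.extra[key] for info in per_replica])
        for key in per_replica[0].extra
    }
    return LossInfo(loss=loss, extra=extra)


# Two replicas, each reporting the loss for its own shard of the batch.
replicas = [
    LossInfo(loss=2.0, extra={"td_loss": 1.5}),
    LossInfo(loss=4.0, extra={"td_loss": 2.5}),
]
print(reduce_loss_info(replicas))         # LossInfo(loss=3.0, extra={'td_loss': 2.0})
print(reduce_loss_info(replicas, "sum"))  # LossInfo(loss=6.0, extra={'td_loss': 4.0})
```

Whether a mean or a sum is appropriate depends on how each replica's loss is already normalized; the sketch only shows the mechanics.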

Args
root_dir Main directory path where checkpoints, saved_models, and summaries will be written.
train_step A scalar tf.int64 tf.Variable that tracks the number of train steps. It is used to label artifacts such as summaries and other outputs written under root_dir.
agent tf_agents.agents.TFAgent instance to train.
experience_dataset_fn A function that creates the tf.data.Dataset used to sample experience for training. Required when using the Learner as is; optional for subclasses that take a new iterator each time learner.run is called.
after_train_strategy_step_fn (Optional) Callable of the form fn(sample, loss), where sample is pulled from the experience iterator and loss is the LossInfo named tuple returned by the agent. It can be used, for example, to update priorities in a replay buffer. It is called after every train step and runs via strategy.run(...).
triggers List of callables of the form trigger(train_step). After every run call, each trigger is called with the current train_step value as a NumPy scalar.
checkpoint_interval Number of train steps between checkpoints. Checkpointing is implemented as a trigger, so the check for whether to write a checkpoint only happens after each run call. Set to -1 to disable (not recommended: if the pipeline is preempted, all previous progress is lost). This only checkpoints the training process; policies must be exported explicitly through triggers.
summary_interval Number of train steps between summaries.
max_checkpoints_to_keep Maximum number of checkpoints to keep. These are used to recover from preemptions during training.
use_kwargs_in_agent_train If True, the experience from the replay buffer is passed to the agent as kwargs. This requires samples from the replay buffer to have the form dict(experience=experience, kwarg1=kwarg1, ...). This is useful if your agent's train method has a custom argspec.
strategy (Optional) tf.distribute.Strategy to use during training.
run_optimizer_variable_init Whether the optimizer variables are initialized before checkpointing. This should almost always be True (the default) so that the optimizer state is checkpointed properly. The optimizer variables are initialized by building the TensorFlow graph, which is done by calling get_concrete_function on the agent's train method; this requires passing some input. Since no real data is available at this point, the batched form of training_data_spec is used instead (a standard technique). This fails when the agent expects agent-specific batching of its input, because the learner has no general way to batch the affected specs correctly at this point. To avoid breaking the code in these cases, set this field to False to turn off optimizer variable initialization.
use_reverb_v2 If True, dataset samples are expected to be a named tuple with data and info fields. If False, they are expected to be a tuple(data, info).
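To make use_reverb_v2 and use_kwargs_in_agent_train more concrete, here is a hypothetical pure-Python sketch of how a sample could be unpacked and forwarded to an agent's train call under each setting. The `ReverbSampleV2` named tuple, `unpack_sample`, `call_train`, and `fake_train` are illustrative assumptions, not tf_agents code; real samples are nested tensors rather than Python lists.

```python
from collections import namedtuple

# Hypothetical v2-style sample: a named tuple with `data` and `info` fields.
ReverbSampleV2 = namedtuple("ReverbSampleV2", ["data", "info"])


def unpack_sample(sample, use_reverb_v2):
    """Splits a dataset sample into (data, info) under either convention."""
    if use_reverb_v2:
        return sample.data, sample.info
    data, info = sample  # plain tuple(data, info)
    return data, info


def call_train(train_fn, data, use_kwargs_in_agent_train):
    """Forwards data to train_fn positionally or as keyword arguments."""
    if use_kwargs_in_agent_train:
        # data must be a dict like dict(experience=..., kwarg1=..., ...).
        return train_fn(**data)
    return train_fn(data)


def fake_train(experience, weights=None):
    # Stand-in for an agent train method with a custom argspec.
    return {"experience": experience, "weights": weights}


sample = ReverbSampleV2(
    data={"experience": [1, 2, 3], "weights": [0.5]}, info="key-7")
data, info = unpack_sample(sample, use_reverb_v2=True)
print(call_train(fake_train, data, use_kwargs_in_agent_train=True))
# {'experience': [1, 2, 3], 'weights': [0.5]}
```

With use_kwargs_in_agent_train set to False, the whole data structure would instead be passed as the single positional experience argument.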

Attributes
train_step_numpy The current train_step.

Methods

loss

Computes loss for the experience.

Because this calls agent.loss(), it does not apply gradient updates or increment the train step counter. Networks are called with training=False, so statistics such as batch norm moving averages are not updated.

Args
experience_and_sample_info A batch of experience and sample info. If not specified, next(self._experience_iterator) is used.
reduce_op a tf.distribute.ReduceOp value specifying how loss values should be aggregated across replicas.

Returns
The total loss computed.
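Why does calling networks with training=False leave batch-norm statistics alone? A toy pure-Python sketch, assuming a batch-norm-like layer that keeps a moving average of the batch mean and only updates it in training mode. The `ToyBatchNorm` class and its momentum value are illustrative assumptions, not Keras code, and the layer only centers its inputs (no variance scaling).

```python
class ToyBatchNorm:
    """Tracks a moving mean of its inputs and centers each batch."""

    def __init__(self, momentum=0.9):
        self.momentum = momentum
        self.moving_mean = 0.0

    def __call__(self, batch, training):
        batch_mean = sum(batch) / len(batch)
        if training:
            # Training mode: center with batch statistics and update state.
            self.moving_mean = (self.momentum * self.moving_mean
                                + (1 - self.momentum) * batch_mean)
            center = batch_mean
        else:
            # Inference mode (as in loss()): use stored statistics, no update.
            center = self.moving_mean
        return [x - center for x in batch]


bn = ToyBatchNorm()
bn([1.0, 3.0], training=False)  # like Learner.loss(): state untouched
print(bn.moving_mean)           # 0.0
bn([1.0, 3.0], training=True)   # like a real train step: state updated
print(bn.moving_mean)           # roughly 0.2
```

This is why loss() can be used to monitor the loss on extra batches without perturbing the model's state.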

run

Runs the given number of training iterations.

Args
iterations Number of train iterations to perform per call to run. The iterations are evaluated in a tf.while_loop created by AutoGraph. Final aggregated losses are returned.
iterator The iterator to the dataset to use for training. If not specified, self._experience_iterator is used.
parallel_iterations Maximum number of train iterations allowed to run in parallel. This value is forwarded directly to the training tf.while_loop.

Returns
The total loss computed before running the final step.
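The interaction between run and interval-based triggers (such as the one created from checkpoint_interval) can be modeled in plain Python. This is a simplified sketch under stated assumptions: `ToyLearner`, `IntervalTrigger`, and the `train_step_fn` callback are hypothetical names, the loop is an ordinary Python loop rather than an AutoGraph tf.while_loop, the trigger fires once at least `interval` steps have passed since it last fired (the real tf_agents trigger logic may differ), and the toy returns the loss of the last iteration rather than reproducing the exact return semantics described above. It shows why an interval of 100 combined with run(iterations=30) can only fire at multiples of 30.

```python
class IntervalTrigger:
    """Calls `fn` once at least `interval` steps have passed since it last fired."""

    def __init__(self, interval, fn):
        self.interval = interval
        self.fn = fn
        self.last_fired = 0

    def __call__(self, train_step):
        # A non-positive interval (e.g. -1) disables the trigger.
        if self.interval > 0 and train_step - self.last_fired >= self.interval:
            self.fn(train_step)
            self.last_fired = train_step


class ToyLearner:
    def __init__(self, train_step_fn, triggers):
        self.train_step = 0
        self.train_step_fn = train_step_fn
        self.triggers = triggers

    def run(self, iterations, iterator):
        assert iterations >= 1
        loss = None
        for _ in range(iterations):
            loss = self.train_step_fn(next(iterator))
            self.train_step += 1
        # Triggers only see the step count after all iterations of this call.
        for trigger in self.triggers:
            trigger(self.train_step)
        return loss  # Simplification: loss of the last iteration.


saved_at = []
learner = ToyLearner(
    train_step_fn=lambda x: float(x),  # pretend the loss equals the sample value
    triggers=[IntervalTrigger(100, saved_at.append)],
)
data = iter(range(1000))
for _ in range(5):
    learner.run(iterations=30, iterator=data)
print(saved_at)  # [120]
```

The trigger is checked at steps 30, 60, 90, 120, and 150, so the first "checkpoint" lands at step 120 rather than 100. Choosing iterations so that it divides the interval avoids this drift.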

single_train_step
