Manages all the learning details needed when training an agent.
tf_agents.train.Learner(
    root_dir, train_step, agent, experience_dataset_fn=None,
    after_train_strategy_step_fn=None, triggers=None, checkpoint_interval=100000,
    summary_interval=1000, max_checkpoints_to_keep=3,
    use_kwargs_in_agent_train=False, strategy=None
)
These include:
- Using distribution strategies correctly
- Summaries
- Checkpoints
- Minimizing entering/exiting TF context: especially on TPUs, scheduling a single TPU program to perform multiple train steps is critical for performance.
- Generalizing the train call so it is done correctly across CPU, GPU, or TPU executions managed by DistributionStrategies. This uses strategy.run and then makes sure to do a reduce operation over the LossInfo returned by the agent.
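A minimal construction sketch follows. The agent and replay_buffer objects, the directory path, and the batch parameters are assumptions made for illustration; only the Learner arguments themselves come from this reference.

import tensorflow as tf
from tf_agents.train import Learner

root_dir = '/tmp/learner_example'            # assumed output directory
train_step = tf.Variable(0, dtype=tf.int64)  # counts train steps for summaries/checkpoints

# `agent` (a TFAgent) and `replay_buffer` are assumed to have been created
# elsewhere in the training pipeline.
def experience_dataset_fn():
  # Sample batched experience for training; batch size and num_steps are
  # illustrative values, not API defaults.
  return replay_buffer.as_dataset(
      sample_batch_size=64, num_steps=2).prefetch(3)

agent_learner = Learner(
    root_dir,
    train_step,
    agent,
    experience_dataset_fn=experience_dataset_fn,
    checkpoint_interval=10000,
    summary_interval=1000)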
Args

root_dir: Main directory path where checkpoints, saved_models, and summaries will be written.

train_step: A scalar tf.int64 tf.Variable which keeps track of the number of train steps. It is used for artifacts created in root_dir, such as summaries and other outputs.

agent: A tf_agent.TFAgent instance to train with.

experience_dataset_fn: A function that creates an instance of a tf.data.Dataset used to sample experience for training. Required for using the Learner as is; optional for subclass learners which take a new iterator each time learner.run is called.

after_train_strategy_step_fn: (Optional) Callable of the form fn(sample, loss) which can be used, for example, to update priorities in a replay buffer, where sample is pulled from the experience_iterator and loss is a LossInfo named tuple returned from the agent. It is called after every train step and runs via strategy.run(...).

triggers: List of callables of the form trigger(train_step). After every run call, each trigger is called with the current train_step value as an np scalar.

checkpoint_interval: Number of train steps between checkpoints. Note these are placed into triggers, so the check that generates a checkpoint only occurs after every run call. Set to -1 to disable (not recommended: if the pipeline gets preempted, all previous progress is lost). This only checkpoints the training process; policies must be explicitly exported through triggers.

summary_interval: Number of train steps between summaries. Note these are placed into triggers, so the check that generates a summary only occurs after every run call.

max_checkpoints_to_keep: Maximum number of checkpoints to keep around. These are used to recover from preemptions during training.

use_kwargs_in_agent_train: If True, the experience from the replay buffer is passed into the agent as kwargs. This requires samples from the replay buffer to be of the form dict(experience=experience, kwarg1=kwarg1, ...). Useful if you have an agent with a custom argspec.

strategy: (Optional) tf.distribute.Strategy to use during training.
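As a sketch of how after_train_strategy_step_fn can be used, the callable below updates priorities in a prioritized replay buffer after each train step. The prioritized_buffer object, its update_priorities method, the (trajectory, sample_info) structure of the sample, and the td_error field read from LossInfo.extra are all hypothetical stand-ins for user code, not part of this API; root_dir, train_step, agent, and experience_dataset_fn are as in the earlier sketch.

from tf_agents.train import Learner

def update_priorities_fn(sample, loss_info):
  # `sample` is whatever experience_dataset_fn yields; here it is assumed to
  # be a (trajectory, sample_info) pair carrying the sampled keys.
  trajectory, sample_info = sample
  # `loss_info` is the LossInfo named tuple returned by the agent; the
  # per-example error field below is an assumption about the agent in use.
  per_example_error = loss_info.extra.td_error
  prioritized_buffer.update_priorities(sample_info.key, per_example_error)

agent_learner = Learner(
    root_dir,
    train_step,
    agent,
    experience_dataset_fn=experience_dataset_fn,
    after_train_strategy_step_fn=update_priorities_fn)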
Attributes

train_step_numpy: The current train_step.
Methods
run
run(
    iterations=1, iterator=None
)
Runs the given number of training iterations.
Args

iterations: Number of train iterations to perform per call to run. The iterations will be evaluated in a tf.while loop created by autograph. Final aggregated losses will be returned.

iterator: The iterator to the dataset to use for training. If not specified, self._experience_iterator is used.
Returns

The total loss computed before running the final step.
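For reference, a typical outer loop, sketched below under the same assumptions as the construction example above, simply calls run repeatedly; the loop count and iterations value are illustrative hyperparameters.

num_train_loops = 100   # illustrative value
for _ in range(num_train_loops):
  # Each call runs 200 train steps inside a single tf.while loop.
  loss_info = agent_learner.run(iterations=200)
  # Checkpoint, summary, and any policy-export triggers fire here, after the
  # run call, based on the current train step.
  print('train_step:', agent_learner.train_step_numpy,
        'loss:', loss_info.loss.numpy())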
single_train_step
single_train_step(
    iterator
)