tfm.core.base_trainer.Trainer

Implements the common trainer shared across TensorFlow models.

config An ExperimentConfig instance specifying experiment config.
task A base_task.Task instance.
model The model instance, e.g. a tf.keras.Model instance.
optimizer tf.optimizers.Optimizer instance.
train bool, whether this trainer will be used for training. Defaults to True.
evaluate bool, whether this trainer will be used for evaluation. Defaults to True.
train_dataset a dataset object created for training. With tf.distribute, it needs to be a DistributedDataset.
validation_dataset a dataset object created for evaluation. With tf.distribute, it needs to be a DistributedDataset. The evaluator will create a dataset iterator for each eval round, so the dataset does not need to repeat.
checkpoint_exporter an object that has the maybe_export_checkpoint interface.

checkpoint Accesses the training checkpoint.
checkpoint_exporter Accesses the checkpoint exporter.
config

eval_dataset The current evaluation dataset.
global_step

model

optimizer

strategy

task

train_dataset The current training dataset.
train_loss Accesses the training loss metric object.
train_metrics Accesses all training metric objects.
validation_loss Accesses the validation loss metric object.
validation_metrics Accesses all validation metric objects.

Methods

coordinator_for_async

create_eval_loop_fn

Creates an eval loop from the given step function and options.

create_train_loop_fn

Creates a training loop from the given step function and options.

distribute_dataset

A utility function to help create a tf.distribute.DistributedDataset.

Args
dataset_or_fn An instance of tf.data.Dataset, or a "dataset function" returning a tf.data.Dataset. If it is a function, it may optionally have an argument named input_context, which will be passed a tf.distribute.InputContext instance.
*args Any positional arguments to pass through to dataset_or_fn.
**kwargs Any keyword arguments to pass through to dataset_or_fn.

Returns
A distributed Dataset.
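The input_context pattern accepted by distribute_dataset can be exercised directly with a tf.distribute.Strategy. Below is a minimal sketch, assuming the default (single-replica, no-op) strategy; the global batch size of 8 and the toy range dataset are illustrative.

```python
import tensorflow as tf

def dataset_fn(input_context):
    # Scale the global batch size down to the per-replica batch size and
    # shard the data across input pipelines.
    per_replica_batch = input_context.get_per_replica_batch_size(8)
    ds = tf.data.Dataset.range(64)
    ds = ds.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)
    return ds.batch(per_replica_batch)

strategy = tf.distribute.get_strategy()  # default no-op strategy
dist_ds = strategy.distribute_datasets_from_function(dataset_fn)
```

With a real multi-replica strategy, each input pipeline would receive a distinct shard and a smaller per-replica batch; under the default strategy the sharding is a no-op.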

eval_begin

Sets up metrics.

eval_end

Processes evaluation results.

eval_reduce

A function to perform per-step reduction on the evaluation outputs.

This is useful for passing state throughout evaluation, especially in cases where maintaining or accumulating state is hard to accomplish using tf.metrics.Metric or other tf.Variable-based approaches. For instance, it can be used to easily accumulate all per-example losses from the full evaluation for subsequent processing in eval_end().

Args
state A state being maintained throughout the evaluation.
step_outputs Outputs from the current evaluation step.

Returns
An output which is passed as the state argument to this function for the next step. After evaluation is finished, the output from the last step is passed to eval_end.
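The state-threading contract can be sketched framework-agnostically. The per_example_loss key below is a hypothetical step output, and the list accumulator stands in for whatever state a task maintains.

```python
def eval_reduce(state, step_outputs):
    # On the first step the incoming state is None, so create the accumulator.
    if state is None:
        state = []
    # Accumulate per-example losses for later processing in eval_end().
    state.extend(step_outputs["per_example_loss"])
    return state

# Each step's return value becomes the next step's state argument.
state = None
for step_outputs in ({"per_example_loss": [0.5, 0.2]},
                     {"per_example_loss": [0.1]}):
    state = eval_reduce(state, step_outputs)
# state now holds every per-example loss from the run.
```

This is the kind of accumulation that is awkward to express with tf.metrics.Metric, since the state here grows with the number of examples rather than reducing to a scalar.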

eval_step

See base class.

evaluate

Implements num_steps steps of evaluation.

Args
num_steps The number of evaluation steps to run. When this is -1, evaluation proceeds until a call to eval_step raises a StopIteration or tf.errors.OutOfRangeError.

Returns
The output of self.eval_end().

Raises
ValueError If options.use_tf_while_loop is True and num_steps is unspecified.
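The num_steps contract can be sketched in plain Python; run_eval and its return shape are illustrative, not the real implementation, which additionally threads state through eval_reduce and finishes with eval_end.

```python
def run_eval(eval_step, iterator, num_steps=-1):
    # num_steps == -1: run until the iterator raises StopIteration.
    logs = None
    steps = 0
    while num_steps < 0 or steps < num_steps:
        try:
            batch = next(iterator)
        except StopIteration:
            break
        logs = eval_step(batch)
        steps += 1
    return logs, steps
```

A bounded run stops after exactly num_steps calls to eval_step; an unbounded run (-1) consumes the iterator to exhaustion.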

init_async

Initializes the Async Trainer base class.

initialize

A callback function.

This function is called when no checkpoint is found for the model. If there is a checkpoint, it will be loaded and this function will not be called. Tasks may use this callback to load a pretrained checkpoint saved under a directory other than the model_dir.

join

Joins all async steps. Only useful in async training.

next_eval_inputs

Fetches the next inputs for the model during eval.

This method consumes the input iterator and returns the next inputs for the model together with an additional logs dict. The logs dict remains on the host (it is not sent to GPUs/TPUs) and is merged with the model outputs, which are processed later in aggregate_logs. This is useful for sending extra logs downstream that are not compatible with the accelerators.

Args
iterator Dataset iterator to generate the next inputs from.

Returns
The inputs to the model, and an additional logs dictionary. The logs are not passed to the model; instead, they are merged with the model's output logs.
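A hypothetical override might look like the sketch below; the field names (features, labels, example_id) are illustrative, and the string id stands in for any host-only value that cannot be sent to the accelerator.

```python
def next_eval_inputs(iterator):
    # Split each dataset element into device-bound model inputs and
    # host-side passthrough logs (e.g. string example ids, which a TPU
    # cannot represent).
    element = next(iterator)
    model_inputs = {"features": element["features"],
                    "labels": element["labels"]}
    passthrough_logs = {"example_id": element["example_id"]}
    return model_inputs, passthrough_logs

it = iter([{"features": [1.0, 2.0], "labels": 1, "example_id": "ex-0"}])
inputs, logs = next_eval_inputs(it)
```

The passthrough logs skip the device round-trip entirely and rejoin the model's outputs on the host.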

next_train_inputs

Fetches the next inputs for the model during train.

This method consumes the input iterator and returns the next inputs for the model. It provides a way to control how the next input is fetched and what data is sent to the model.

Args
iterator Dataset iterator to generate the next inputs from.

Returns
The inputs to the model.
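As a sketch, an override could reshape or filter each batch before it reaches the device; the metadata field below is a hypothetical host-only key.

```python
def next_train_inputs(iterator):
    # Strip host-only fields so only device-friendly values reach the model.
    batch = dict(next(iterator))
    batch.pop("metadata", None)  # hypothetical host-only field
    return batch

batch = next_train_inputs(
    iter([{"x": [1, 2], "y": 0, "metadata": "raw text"}]))
```

Unlike next_eval_inputs, nothing is passed through to the logs here; the dropped fields simply never leave the host.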

train

Implements num_steps steps of training.

Args
num_steps The number of training steps to run. This corresponds directly to the number of calls made to train_step.

Returns
The output of train_loop_end.

train_loop_begin

Called once at the beginning of the training loop.

This method is always called in eager mode, and is a good place to reset metrics that accumulate values over multiple steps of training.

Note that this method is called before dataset iterator creation.

train_loop_end

See base class.

train_step

See base class.