Implements the common trainer shared by TensorFlow models.
tfm.core.base_trainer.Trainer(
    config: tfm.core.base_trainer.ExperimentConfig,
    task: tfm.core.base_task.Task,
    model: tf.keras.Model,
    optimizer: tf.optimizers.Optimizer,
    train: bool = True,
    evaluate: bool = True,
    train_dataset: Optional[Union[tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
    validation_dataset: Optional[Union[tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
    checkpoint_exporter=None
)
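For orientation, a minimal construction sketch. It assumes the tensorflow_models package is importable as tfm, that an experiment named 'resnet_imagenet' is registered with tfm.core.exp_factory, and that the task-factory and optimizer helpers shown are available in your install; treat it as a sketch rather than the canonical setup.

```python
import tensorflow_models as tfm

# Assumed: 'resnet_imagenet' is registered with the experiment factory.
config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')

# Build the task, model, and optimizer from the experiment config.
task = tfm.core.task_factory.get_task(config.task, logging_dir='/tmp/model_dir')
model = task.build_model()
optimizer = task.create_optimizer(config.trainer.optimizer_config, config.runtime)

trainer = tfm.core.base_trainer.Trainer(
    config=config,
    task=task,
    model=model,
    optimizer=optimizer,
    train=True,
    evaluate=True,
)
```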
Methods
create_eval_loop_fn
create_eval_loop_fn(
    has_state: bool
)
Creates an eval loop from the given step function and options.
create_train_loop_fn
create_train_loop_fn()
Creates a training loop from the given step function and options.
distribute_dataset
distribute_dataset(
    dataset_or_fn, *args, **kwargs
)
A utility function to help create a tf.distribute.DistributedDataset.

| Args | |
|---|---|
| dataset_or_fn | An instance of tf.data.Dataset, or a "dataset function" returning a tf.data.Dataset. If it is a function, it may optionally have an argument named input_context which will be passed a tf.distribute.InputContext instance. |
| *args | Any positional arguments to pass through to dataset_or_fn. |
| **kwargs | Any keyword arguments to pass through to dataset_or_fn. |

| Returns | |
|---|---|
| A distributed Dataset. |
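As an illustration of the "dataset function" form, a sketch assuming trainer is an already constructed Trainer; the input pipeline itself is hypothetical.

```python
import tensorflow as tf


def dataset_fn(input_context=None):
  # Hypothetical input pipeline; replace with your real one.
  ds = tf.data.Dataset.range(1024).batch(64)
  if input_context is not None:
    # Shard across input pipelines when running under tf.distribute.
    ds = ds.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)
  return ds


distributed_ds = trainer.distribute_dataset(dataset_fn)
```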
eval_begin
eval_begin()
Sets up metrics.
eval_end
eval_end(
    aggregated_logs=None
)
Processes evaluation results.
eval_reduce
eval_reduce(
    state=None, step_outputs=None
)
A function to perform per-step reduction on the evaluation outputs.
This is useful for passing state throughout evaluation, especially in cases where maintaining or accumulating state is hard to accomplish using tf.metrics.Metric or other tf.Variable-based approaches. For instance, it can be used to easily accumulate all per-example losses from the full evaluation for subsequent processing in eval_end().

| Args | |
|---|---|
| state | A state being maintained throughout the evaluation. |
| step_outputs | Outputs from the current evaluation step. |

| Returns | |
|---|---|
| An output which is passed as the state argument to this function for the next step. After evaluation is finished, the output from the last step will be passed to eval_end. |
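A minimal sketch of a Trainer subclass that accumulates per-example losses this way; the per_example_loss key is an assumption about what eval_step emits in your task, and how your Task post-processes the aggregated logs may differ.

```python
import tensorflow as tf
import tensorflow_models as tfm


class LossCollectingTrainer(tfm.core.base_trainer.Trainer):
  """Collects per-example losses across all evaluation steps."""

  def eval_reduce(self, state=None, step_outputs=None):
    # 'per_example_loss' is a hypothetical key produced by eval_step.
    state = state if state is not None else {'per_example_loss': []}
    state['per_example_loss'].append(step_outputs['per_example_loss'])
    return state

  def eval_end(self, aggregated_logs=None):
    # The final state from eval_reduce arrives here as aggregated_logs.
    logs = super().eval_end(aggregated_logs)
    if aggregated_logs:
      logs['per_example_loss'] = tf.concat(
          aggregated_logs['per_example_loss'], axis=0)
    return logs
```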
eval_step
eval_step(
    iterator
)
See base class.
evaluate
evaluate(
    num_steps: tf.Tensor
) -> Optional[runner.Output]
Implements num_steps steps of evaluation.

| Args | |
|---|---|
| num_steps | The number of evaluation steps to run. When this is -1, evaluation proceeds until a call to eval_step raises a StopIteration or tf.errors.OutOfRangeError. |

| Returns | |
|---|---|
| The output of self.eval_end(). |

| Raises | |
|---|---|
| ValueError | If options.use_tf_while_loop is True and num_steps is unspecified. |
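For illustration, assuming trainer is an already constructed Trainer with a validation dataset:

```python
import tensorflow as tf

# Run 100 evaluation steps. Passing -1 evaluates until the dataset is
# exhausted, which requires options.use_tf_while_loop to be False.
eval_logs = trainer.evaluate(tf.convert_to_tensor(100, dtype=tf.int32))
```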
init_async
init_async()
Initializes the Async Trainer base class.
initialize
initialize()
A callback function.
This function will be called when no checkpoint is found for the model. If there is a checkpoint, the checkpoint will be loaded and this function will not be called. Tasks may use this callback function to load a pretrained checkpoint saved under a directory other than the model_dir.
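Since the docstring notes that Tasks may use this callback, here is a sketch of a Task-level initialize override that warm-starts from a pretrained checkpoint; the path and checkpoint layout are hypothetical, and other required Task methods are omitted for brevity.

```python
import tensorflow as tf
import tensorflow_models as tfm


class MyTask(tfm.core.base_task.Task):
  # Other required Task methods (build_model, build_inputs, ...) omitted.

  def initialize(self, model: tf.keras.Model):
    """Warm-starts the model when no checkpoint exists in model_dir."""
    # Hypothetical pretrained checkpoint directory.
    pretrained_dir = '/path/to/pretrained'
    latest = tf.train.latest_checkpoint(pretrained_dir)
    if latest:
      ckpt = tf.train.Checkpoint(model=model)
      ckpt.restore(latest).expect_partial()
```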
join
join()
Join all async steps. Only useful in async training.
train
train(
    num_steps: tf.Tensor
) -> Optional[runner.Output]
Implements num_steps steps of training.

| Args | |
|---|---|
| num_steps | The number of training steps to run. This corresponds directly to the number of calls made to train_step. |

| Returns | |
|---|---|
| The output of train_loop_end. |
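For illustration, driving the trainer manually in an outer loop; in practice an orbit controller usually handles this, the step counts here are arbitrary, and trainer.global_step is assumed to be the step counter the base trainer maintains.

```python
import tensorflow as tf

steps_per_loop = 100
total_steps = 1000

# Each call runs `steps_per_loop` calls to train_step and returns the
# output of train_loop_end (typically a dict of metric results).
while trainer.global_step.numpy() < total_steps:
  train_logs = trainer.train(
      tf.convert_to_tensor(steps_per_loop, dtype=tf.int32))
```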
train_loop_begin
train_loop_begin()
Called once at the beginning of the training loop.
This method is always called in eager mode, and is a good place to reset metrics that accumulate values over multiple steps of training.
Note that this method is called before dataset iterator creation.
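As a sketch, a subclass that resets additional metrics here before each training loop; self.extra_metrics is a hypothetical attribute, not part of the base class.

```python
import tensorflow_models as tfm


class MyTrainer(tfm.core.base_trainer.Trainer):

  def train_loop_begin(self):
    # Let the base class reset its own training metrics first.
    super().train_loop_begin()
    # Then reset any extra metrics accumulated during the previous loop
    # (self.extra_metrics is hypothetical).
    for metric in getattr(self, 'extra_metrics', []):
      metric.reset_state()
```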
train_loop_end
train_loop_end()
See base class.
train_step
train_step(
    iterator
)
See base class.