Implements standard functionality on top of the AbstractTrainer API.
```python
orbit.StandardTrainer(
    train_dataset,
    options: Optional[orbit.StandardTrainerOptions] = None
)
```
This class structures the training "inner loop" roughly as follows:
```python
train_loop_begin()
for _ in range(num_steps):
  train_step(train_iterator)
return train_loop_end()
```
Calls to `train_loop_begin` and `train_loop_end` are always done in eager mode, while the loop/`train_step` may be implemented using `tf.function`, as determined by the `options` passed to `__init__`.
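As an illustrative sketch only (not part of the Orbit reference), a minimal subclass might look like the following; the model, optimizer, loss, and metric used here are assumptions, not Orbit requirements:

```python
import orbit
import tensorflow as tf


class MyTrainer(orbit.StandardTrainer):
  """Hypothetical StandardTrainer subclass; the model/optimizer are stand-ins."""

  def __init__(self, model, optimizer, train_dataset):
    self.model = model
    self.optimizer = optimizer
    self.train_loss = tf.keras.metrics.Mean("train_loss")
    # StandardTrainerOptions controls whether the loop runs under tf.function.
    options = orbit.StandardTrainerOptions(use_tf_function=True)
    super().__init__(train_dataset, options=options)

  def train_loop_begin(self):
    # Always eager: reset metrics that accumulate across steps.
    self.train_loss.reset_states()

  def train_step(self, iterator):
    def step_fn(inputs):
      features, labels = inputs
      with tf.GradientTape() as tape:
        logits = self.model(features, training=True)
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True))
      grads = tape.gradient(loss, self.model.trainable_variables)
      self.optimizer.apply_gradients(
          zip(grads, self.model.trainable_variables))
      self.train_loss.update_state(loss)

    # Called in the cross-replica context; dequeue here, dispatch per replica.
    tf.distribute.get_strategy().run(step_fn, args=(next(iterator),))

  def train_loop_end(self):
    # Always eager: collect metric results into a dictionary of Tensors.
    return {"train_loss": self.train_loss.result()}
```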
Attributes

| Attribute | Description |
| :--- | :--- |
| `name` | Returns the name of this module as passed or determined in the ctor. |
| `non_trainable_variables` | Sequence of non-trainable variables owned by this module and its submodules. |
| `submodules` | Sequence of all sub-modules. Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on). |
| `train_dataset` | The current training dataset. |
| `trainable_variables` | Sequence of trainable variables owned by this module and its submodules. |
| `variables` | Sequence of variables owned by this module and its submodules. |
Methods

```python
create_train_loop_fn()
```

Creates a training loop from the current step function and options.

| Returns |
| :--- |
| The train loop function, i.e. wrapper of multiple train steps. |
```python
train(
    num_steps: tf.Tensor
) -> Optional[runner.Output]
```

Implements `num_steps` steps of training.

| Args | Description |
| :--- | :--- |
| `num_steps` | The number of training steps to run. This corresponds directly to the number of calls made to `train_step`. |

| Returns |
| :--- |
| The output of `train_loop_end`. |
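As a usage sketch (reusing the hypothetical `MyTrainer` from above, and assuming `model`, `optimizer`, and `train_dataset` are already defined), `train` can be called directly, though in practice an `orbit.Controller` usually drives it:

```python
trainer = MyTrainer(model, optimizer, train_dataset)

# Direct call: makes 100 train_step calls, then returns train_loop_end()'s output.
logs = trainer.train(tf.constant(100))

# More typical wiring: let an orbit.Controller drive the loop.
controller = orbit.Controller(
    trainer=trainer,
    global_step=optimizer.iterations,
    steps_per_loop=100)
controller.train(steps=1000)
```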
```python
train_loop_begin()
```

Called once at the beginning of the training loop.
This method is always called in eager mode, and is a good place to reset metrics that accumulate values over multiple steps of training.
Note that this method is called before dataset iterator creation.
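Since the iterator is created only after this hook runs, `train_loop_begin` is also a place where the dataset can be swapped before each loop. A hedged sketch, assuming the `train_dataset` property is reassignable and using a hypothetical `make_dataset_for_epoch` helper:

```python
def train_loop_begin(self):
  self.train_loss.reset_states()
  # The training iterator is created after this hook returns, so a
  # dataset assigned here takes effect for the upcoming loop.
  # make_dataset_for_epoch() is a hypothetical helper, not Orbit API.
  self.train_dataset = make_dataset_for_epoch()
```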
```python
train_loop_end() -> Optional[runner.Output]
```

Called once at the end of the training loop.

This method is always called in eager mode, and is a good place to get metric results. The value returned from this function will be returned as-is from the `train` method implementation provided by `StandardTrainer`.

| Returns |
| :--- |
| The function may return a dictionary of `Tensors`, which will be written to logs and as TensorBoard summaries. |
```python
train_step(
    iterator
)
```

Implements one step of training.

What a "step" consists of is up to the implementer. When using distribution strategies, the call to this method takes place in the "cross-replica context" for generality, to allow e.g. multiple iterator dequeues and calls to `strategy.run`.

Note that if `use_tf_function=True`, all the code inside `train_step` should be compatible with `tf.function` tracing (and in particular, any state modifications involving `self` should be avoided). In some cases, non-`tf.function` compatible code can be moved to `train_loop_begin` or `train_loop_end`, which always execute eagerly.
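For example, because `train_step` runs in the cross-replica context, it may dequeue more than one batch and launch several replica executions per step. A sketch, where `_replica_step` is an assumed per-replica method on the subclass:

```python
def train_step(self, iterator):
  # Cross-replica context: multiple dequeues and strategy.run calls are fine.
  strategy = tf.distribute.get_strategy()
  for _ in range(2):  # e.g. two micro-batches per outer "step"
    strategy.run(self._replica_step, args=(next(iterator),))
```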
```python
with_name_scope(
    method
)
```

Decorator to automatically enter the module name scope.

```python
class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)
```

Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose names include the module name:

```python
mod = MyModule()
mod(tf.ones([1, 2]))
# <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
mod.w
# <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=..., dtype=float32)>
```
| Args | Description |
| :--- | :--- |
| `method` | The method to wrap. |

| Returns |
| :--- |
| The original method wrapped such that it enters the module's name scope. |