Wrapper that provides versions of a function with and without summaries.
```python
orbit.utils.OptionalSummariesFunction(function, **tf_function_kwargs)
```
This is a utility class for implementing optimized summary recording via a two-function approach, which is especially important on TPUs. Two versions of a given function are created: one with soft device placement enabled (for use on steps that require summary writing), and one with summary writing and soft device placement entirely disabled (for use on all other steps). This removes any performance impact of summaries on steps where they aren't recorded (b/148418718).
`OptionalSummariesFunction` can also be used as a base class to implement summary optimizations for a function with a specific signature. For example, to implement efficient TPU summaries for a standard `train()` method (as in …), subclass it and override `__call__`:
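```python
import orbit
import tensorflow as tf


class TrainFunctionWithSummaries(orbit.utils.OptionalSummariesFunction):
    """Implements a two-program approach for summaries on TPU."""

    def __call__(self, num_steps):
        output = None  # Avoids an unbound name if no step runs at all.
        if tf.summary.should_record_summaries():
            # Run a single step with summaries enabled.
            output = self.with_summaries(tf.constant(1))
            num_steps -= 1
        if num_steps >= 1:
            # Run all remaining steps with summaries disabled.
            output = self.without_summaries(num_steps)
        return output
```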
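The step split performed by `__call__` can be checked without a TPU. This sketch replaces the base class with a hypothetical `FakeOptionalSummariesFunction` whose two variants only log their calls, and replaces `tf.summary.should_record_summaries()` with a plain boolean. It illustrates the dispatch logic only and is not Orbit code.

```python
class FakeOptionalSummariesFunction:
    """Hypothetical stand-in for the base class: both variants log and delegate."""

    def __init__(self, function):
        self.function = function
        self.calls = []  # (variant name, num_steps) for every call made.

    def with_summaries(self, num_steps):
        self.calls.append(("with_summaries", num_steps))
        return self.function(num_steps)

    def without_summaries(self, num_steps):
        self.calls.append(("without_summaries", num_steps))
        return self.function(num_steps)


class TrainFunctionSketch(FakeOptionalSummariesFunction):
    """Same dispatch logic as TrainFunctionWithSummaries, minus TensorFlow."""

    def __init__(self, function, should_record):
        super().__init__(function)
        # Plain boolean standing in for tf.summary.should_record_summaries().
        self.should_record = should_record

    def __call__(self, num_steps):
        output = None
        if self.should_record:
            # One step on the summary-enabled path...
            output = self.with_summaries(1)
            num_steps -= 1
        if num_steps >= 1:
            # ...and every remaining step on the summary-free path.
            output = self.without_summaries(num_steps)
        return output


train = TrainFunctionSketch(lambda n: f"ran {n} step(s)", should_record=True)
print(train(100))   # -> ran 99 step(s)
print(train.calls)  # -> [('with_summaries', 1), ('without_summaries', 99)]
```

With recording enabled, a 100-step call becomes one summary step plus a single 99-step summary-free call; with recording disabled, all steps take the summary-free path.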
`TrainFunctionWithSummaries` can be used directly or to implement a decorator:
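```python
import functools


def train_function_with_summaries(function=None, **kwargs):
    # Used bare (@train_function_with_summaries): wrap the function directly.
    if function is not None:
        return TrainFunctionWithSummaries(function, **kwargs)
    # Used with arguments (@train_function_with_summaries(...)): return a
    # decorator with the keyword arguments already bound.
    return functools.partial(TrainFunctionWithSummaries, **kwargs)
```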
The decorator can be applied directly to `train()` methods:
```python
@train_function_with_summaries
def train(self, num_steps):
    ...
```
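The `function=None` plus `functools.partial` pattern is what lets the same decorator be used both bare and with keyword arguments. The sketch below demonstrates this with a hypothetical `RecordingWrapper` stand-in that simply stores the keyword arguments it would forward. It decorates free functions for simplicity, and `jit_compile` is used only as an example of a keyword that `tf.function` accepts.

```python
import functools


class RecordingWrapper:
    """Hypothetical stand-in for the wrapper class: stores its inputs."""

    def __init__(self, function, **tf_function_kwargs):
        self.function = function
        self.tf_function_kwargs = tf_function_kwargs

    def __call__(self, *args, **kwargs):
        return self.function(*args, **kwargs)


def wrap_with_summaries(function=None, **kwargs):
    # Bare use (@wrap_with_summaries): `function` is the decorated callable.
    if function is not None:
        return RecordingWrapper(function, **kwargs)
    # Parameterized use (@wrap_with_summaries(...)): return a decorator that
    # still needs the function, with the keyword arguments pre-bound.
    return functools.partial(RecordingWrapper, **kwargs)


@wrap_with_summaries
def double(n):
    return 2 * n


@wrap_with_summaries(jit_compile=True)
def triple(n):
    return 3 * n


print(double(4), double.tf_function_kwargs)  # -> 8 {}
print(triple(4), triple.tf_function_kwargs)  # -> 12 {'jit_compile': True}
```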
A similar approach can be implemented for functions with different signatures.
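As an illustration, only `__call__` needs to change for a different signature. This hypothetical sketch handles a function that takes a list of batches: the first batch goes through the summary-enabled variant and the rest through the summary-free one. The class, its `should_record` flag, and the list-of-batches signature are all assumptions made for the example.

```python
class BatchesFunctionSketch:
    """Hypothetical two-variant dispatcher for a `fn(batches)` signature."""

    def __init__(self, function, should_record):
        self.function = function
        self.should_record = should_record  # Stand-in for a summary predicate.
        self.calls = []  # (variant name, number of batches) per call.

    def with_summaries(self, batches):
        self.calls.append(("with_summaries", len(batches)))
        return self.function(batches)

    def without_summaries(self, batches):
        self.calls.append(("without_summaries", len(batches)))
        return self.function(batches)

    def __call__(self, batches):
        outputs = []
        if self.should_record and batches:
            # Process only the first batch on the summary-enabled path.
            outputs.append(self.with_summaries(batches[:1]))
            batches = batches[1:]
        if batches:
            # Everything else goes through the summary-free path in one call.
            outputs.append(self.without_summaries(batches))
        return outputs


fn = BatchesFunctionSketch(sum, should_record=True)
print(fn([1, 2, 3, 4]))  # -> [1, 9]
print(fn.calls)          # -> [('with_summaries', 1), ('without_summaries', 3)]
```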
`OptionalSummariesFunction` properly handles instance methods (see …).
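One standard Python mechanism for making a callable wrapper object behave correctly as a method is the descriptor protocol (`__get__`). The sketch below assumes that approach and uses a hypothetical `MethodAwareWrapper` whose two variants are plain Python functions; it is not taken from the Orbit source.

```python
import copy


class MethodAwareWrapper:
    """Hypothetical wrapper holding two variants; binds them like methods."""

    def __init__(self, function):
        self.with_summaries = function
        self.without_summaries = function

    def __call__(self, *args, **kwargs):
        # Hypothetical: calling the wrapper itself uses the summary-free variant.
        return self.without_summaries(*args, **kwargs)

    def __get__(self, instance, owner):
        if instance is None:
            # Accessed on the class itself: return the unbound wrapper.
            return self
        # Accessed on an instance: return a shallow copy whose variants are
        # bound to that instance, so `self` is passed automatically.
        bound = copy.copy(self)
        bound.with_summaries = self.with_summaries.__get__(instance, owner)
        bound.without_summaries = self.without_summaries.__get__(instance, owner)
        return bound


class Trainer:
    def __init__(self):
        self.steps_run = 0

    @MethodAwareWrapper
    def train(self, num_steps):
        self.steps_run += num_steps
        return self.steps_run


trainer = Trainer()
print(trainer.train(3), trainer.train.with_summaries(2))  # -> 3 5
```

Without `__get__`, `trainer.train(3)` would call the wrapped function with `3` as `self`, because a plain callable object stored on a class is not bound to the instance automatically.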
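| Args | |
|---|---|
| `function` | The underlying function to wrap. |
| `**tf_function_kwargs` | Additional arguments to pass to `tf.function`. |

| Attributes | |
|---|---|
| `with_summaries` | A wrapped version of the underlying function with summaries enabled (using whatever the active predicate is for …). |
| `without_summaries` | A wrapped version of the underlying function with all summary recording disabled. |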