orbit.utils.OptionalSummariesFunction

Wrapper that provides versions of a function with and without summaries.

This utility class implements optimized summary recording through a two-function approach, which is especially important on TPUs. It creates two tf.function versions of a given function: one with soft device placement enabled (for steps that write summaries), and one with summary writing and soft device placement entirely disabled (for all other steps). This removes any performance cost of summaries on steps where they are not recorded (b/148418718).

OptionalSummariesFunction can be used as a base class to implement summary optimizations for a function with a specific signature. For example, to implement efficient TPU summaries for a standard train() method (as in orbit.AbstractTrainer):

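import orbit
import tensorflow as tf


class TrainFunctionWithSummaries(orbit.utils.OptionalSummariesFunction):
  """Implements a two-program approach for summaries on TPU."""

  def __call__(self, num_steps):
    output = None  # Returned as-is if there are no steps to run.
    if tf.summary.should_record_summaries():
      # Run a single step with summaries enabled.
      output = self.with_summaries(tf.constant(1))
      num_steps -= 1
    if num_steps >= 1:
      # Run the remaining steps with summaries disabled.
      output = self.without_summaries(num_steps)
    return output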

This can be used directly or to implement a decorator:

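import functools

def train_function_with_summaries(function=None, **kwargs):
  if function is not None:
    # Bare use: wrap the function immediately.
    return TrainFunctionWithSummaries(function, **kwargs)
  # Use with tf.function kwargs: return a decorator still awaiting function.
  return functools.partial(TrainFunctionWithSummaries, **kwargs)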

The decorator can be applied directly to train() methods:

class MyTrainer(orbit.AbstractTrainer):  # Hypothetical trainer class.

  @train_function_with_summaries
  def train(self, num_steps):
    ...
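The decorator above uses a common Python pattern for decorators that take optional keyword arguments. The following sketch shows both call forms and runs without TensorFlow or Orbit. `RecordingWrapper` is a hypothetical stand-in for TrainFunctionWithSummaries, and `recording_decorator`, `add_one`, and `double` are also illustrative names. `jit_compile` is a real tf.function argument, but this sketch only stores it.

```python
import functools


class RecordingWrapper:
  """Hypothetical stand-in for TrainFunctionWithSummaries.

  It stores the function and the keyword arguments that the real class
  would forward to tf.function, and simply calls the function.
  """

  def __init__(self, function, **tf_function_kwargs):
    self.function = function
    self.tf_function_kwargs = tf_function_kwargs

  def __call__(self, *args, **kwargs):
    return self.function(*args, **kwargs)


def recording_decorator(function=None, **kwargs):
  # Bare use (@recording_decorator): Python passes the function directly.
  if function is not None:
    return RecordingWrapper(function, **kwargs)
  # Use with arguments (@recording_decorator(...)): return a partial that
  # Python then calls with the decorated function.
  return functools.partial(RecordingWrapper, **kwargs)


@recording_decorator
def add_one(x):
  return x + 1


@recording_decorator(jit_compile=True)
def double(x):
  return x * 2
```

Both decorated names end up as wrapper instances. Only the second carries extra keyword arguments.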

A similar subclass can be written for functions with other signatures.
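The step-splitting logic of TrainFunctionWithSummaries.__call__ does not depend on the rest of the signature. You can therefore check it without TensorFlow by passing plain Python callables in place of with_summaries and without_summaries. In this hypothetical sketch, `split_call` forwards any leading arguments to both variants. Here the leading argument is a data iterator, as a training loop might take. `output` starts as None, so a call with no steps to run returns None instead of raising UnboundLocalError.

```python
def split_call(with_summaries, without_summaries, should_record,
               leading_args, num_steps):
  """Hypothetical mirror of __call__ for a (*leading_args, num_steps) signature."""
  output = None  # Stays None when there is nothing to run.
  if should_record:
    # One step with summaries enabled.
    output = with_summaries(*leading_args, 1)
    num_steps -= 1
  if num_steps >= 1:
    # All remaining steps with summaries disabled.
    output = without_summaries(*leading_args, num_steps)
  return output


calls = []


def fake_step(label):
  """Builds a stand-in variant that consumes n items from the iterator."""
  def run(iterator, n):
    items = [next(iterator) for _ in range(n)]
    calls.append((label, items))
    return items[-1]
  return run


data = iter(range(10))
last = split_call(fake_step("with"), fake_step("without"), True, (data,), 4)
```

Both variants receive the same iterator. The single summary step therefore consumes the first element, and the remaining steps continue from the next one.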

OptionalSummariesFunction also handles instance methods correctly: its __get__ method binds the wrapped functions to the instance.
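The descriptor behavior can be sketched in plain Python. The sketch assumes, as the reference to __get__ suggests, that binding is forwarded to both underlying callables. `TwoVersions` and `Trainer` are hypothetical names. Ordinary functions stand in for the two tf.function objects, which also implement `__get__`.

```python
import copy


class TwoVersions:
  """Hypothetical stand-in for OptionalSummariesFunction with plain callables."""

  def __init__(self, function):
    self.with_summaries = function
    self.without_summaries = function

  def __get__(self, instance, owner):
    # Descriptor protocol: invoked when the attribute is looked up through a
    # class or instance. Return a copy whose callables are bound to instance.
    new = copy.copy(self)
    new.with_summaries = self.with_summaries.__get__(instance, owner)
    new.without_summaries = self.without_summaries.__get__(instance, owner)
    return new


class Trainer:

  def __init__(self):
    self.steps_run = 0

  @TwoVersions
  def train(self, num_steps):
    self.steps_run += num_steps
    return self.steps_run


trainer = Trainer()
```

Without `__get__`, `trainer.train.with_summaries(1)` would call the raw function with `1` as `self` and fail. With it, both versions receive `trainer` automatically.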

Args:
  function: The underlying function to wrap.
  **tf_function_kwargs: Additional keyword arguments to pass to tf.function.

Attributes:
  with_summaries: A wrapped version of the underlying function with summaries enabled (using whichever predicate is active for tf.summary.record_if), placed inside a "soft device placement" context to enable summary recording on TPU.
  without_summaries: A wrapped version of the underlying function with all summary recording disabled.