


Creates a callable that performs one training step.

    loss_fn, optimizer, trainable_variables, grad_summary_fn=None, tf_function=True,
    xla_compile=True


  • loss_fn: Python callable which returns a pair: the loss (a tf.Tensor) and a second result, other, such that tf.nest.map_structure(tf.convert_to_tensor, other) will succeed.
  • optimizer: tf.optimizers.Optimizer-like instance which has members gradient and apply_gradients.
  • trainable_variables: tf.nest.flatten-able structure of tf.Variable instances.
  • grad_summary_fn: Python callable which takes a trainable_variables-like structure of tf.Tensors holding the gradients of the loss returned by loss_fn with respect to trainable_variables. For example, lambda grads: tf.nest.map_structure(lambda x: 0. if x is None else tf.norm(x), grads). Default value: None (i.e., no summarization is made).
  • tf_function: bool representing whether the resulting function should be decorated with tf.function. Default value: True.
  • xla_compile: bool representing whether XLA compilation should be performed. (This argument is ignored if the function is executed eagerly.) Default value: True.


  • fit_op: A Python callable whose arguments are forwarded to loss_fn. Each call updates trainable_variables per the logic of optimizer.apply_gradients.
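
To make the contract above concrete, here is a minimal sketch of how a fit op of this shape could be built and used. It is an illustrative stand-in, not the library's implementation: the names make_fit_op_sketch and PlainSGD are hypothetical, the return value of fit_op (loss, other, and an optional gradient summary) is an assumption, and XLA compilation is left out to keep the example small.

```python
import tensorflow as tf


class PlainSGD:
  """Hypothetical optimizer-like object; only `apply_gradients` is provided."""

  def __init__(self, learning_rate):
    self.learning_rate = learning_rate

  def apply_gradients(self, grads_and_vars):
    # Plain gradient descent: v <- v - learning_rate * g.
    for grad, var in grads_and_vars:
      if grad is not None:
        var.assign_sub(self.learning_rate * grad)


def make_fit_op_sketch(loss_fn, optimizer, trainable_variables,
                       grad_summary_fn=None, tf_function=True):
  """Illustrative stand-in for the documented function (an assumption)."""

  def fit_op(*args, **kwargs):
    flat_vars = tf.nest.flatten(trainable_variables)
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      for v in flat_vars:
        tape.watch(v)
      # All arguments are forwarded to loss_fn unchanged.
      loss, other = loss_fn(*args, **kwargs)
    flat_grads = tape.gradient(loss, flat_vars)
    optimizer.apply_gradients(zip(flat_grads, flat_vars))
    if grad_summary_fn is None:
      return loss, other
    # Restore the structure of trainable_variables before summarizing.
    grads = tf.nest.pack_sequence_as(trainable_variables, flat_grads)
    return loss, other, grad_summary_fn(grads)

  return tf.function(fit_op) if tf_function else fit_op


# Usage: fit a single slope w so that w * x matches y = 2 * x.
w = tf.Variable(0.0)


def loss_fn(x, y):
  err = w * x - y
  # The second result can be anything convertible to tensors.
  return tf.reduce_mean(err ** 2), err


fit_op = make_fit_op_sketch(
    loss_fn, PlainSGD(learning_rate=0.1), [w],
    grad_summary_fn=lambda grads: tf.nest.map_structure(
        lambda g: 0. if g is None else tf.norm(g), grads))

x = tf.constant([1.0, 2.0, 3.0])
y = 2.0 * x
first_loss, _, _ = fit_op(x, y)
for _ in range(50):
  last_loss, _, grad_norms = fit_op(x, y)
```

Each call to fit_op performs exactly one update, so the training loop stays in user code. With the real function, the optimizer would typically be a tf.optimizers.Optimizer instance rather than the hand-written class used here.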