tfp.experimental.nn.util.make_fit_op

Builds a callable that performs one training step.

loss_fn Python callable which returns the pair (loss, other), where loss is a tf.Tensor and other is any second result for which tf.nest.map_structure(tf.convert_to_tensor, other) succeeds.
optimizer tf.optimizers.Optimizer-like instance which has members gradient and apply_gradients.
trainable_variables tf.nest.flatten-able structure of tf.Variable instances.
grad_summary_fn Python callable which takes a trainable_variables-like structure of tf.Tensors representing the gradient of the loss returned by loss_fn with respect to trainable_variables. For example, lambda grads: tf.nest.map_structure(lambda x: 0. if x is None else tf.norm(x), grads). Default value: None (i.e., no summarization is made).
tf_function bool representing whether the resulting function should be tf.function-decorated. Default value: True.
xla_compile bool representing whether XLA compilation should be performed. (This argument is ignored if the function is executed eagerly.) Default value: True.
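The grad_summary_fn example maps each gradient to its norm and substitutes 0. for any variable that received no gradient (None). The sketch below reproduces that logic without TensorFlow so the None handling is easy to see. It assumes gradients arrive as nested dicts whose leaves are lists of floats (standing in for tensors) or None; map_structure and l2_norm are hypothetical helpers standing in for tf.nest.map_structure and tf.norm, not TFP APIs.

```python
import math


def map_structure(fn, structure):
    """Hypothetical stand-in for tf.nest.map_structure.

    Only dicts are treated as nested structure here; every other value
    (a list of floats standing in for a tensor, or None) is a leaf.
    """
    if isinstance(structure, dict):
        return {key: map_structure(fn, value) for key, value in structure.items()}
    return fn(structure)


def l2_norm(x):
    """Euclidean norm of a flat list of floats (stand-in for tf.norm)."""
    return math.sqrt(sum(v * v for v in x))


def grad_summary_fn(grads):
    # Mirrors the documented example: a variable with no gradient (None) maps to 0.
    return map_structure(lambda x: 0. if x is None else l2_norm(x), grads)


print(grad_summary_fn({"kernel": [3.0, 4.0], "bias": None}))
# {'kernel': 5.0, 'bias': 0.0}
```

The output keeps the same structure as the input, which is what makes a summary of this shape easy to log per variable.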

fit_op A Python callable whose arguments are forwarded to loss_fn; each call updates trainable_variables according to optimizer.apply_gradients.
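As a rough illustration of the control flow of one training step, the framework-free sketch below forwards its arguments to loss_fn, differentiates the loss with respect to each trainable variable, hands (gradient, variable) pairs to optimizer.apply_gradients, and optionally applies grad_summary_fn. Several assumptions apply: automatic differentiation is replaced by central finite differences; Variable, SGD, make_fit_op_sketch and eps are hypothetical names, not TFP or TensorFlow APIs; variables are a dict of scalars; and returning (loss, other), plus the summary when grad_summary_fn is given, is a choice of this sketch rather than documented behavior. tf_function and xla_compile have no counterpart here.

```python
class Variable:
    """Hypothetical stand-in for tf.Variable holding a single float."""

    def __init__(self, value):
        self.value = float(value)

    def assign_sub(self, delta):
        self.value -= delta


class SGD:
    """Hypothetical optimizer exposing apply_gradients(grads_and_vars)."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def apply_gradients(self, grads_and_vars):
        for grad, var in grads_and_vars:
            if grad is not None:
                var.assign_sub(self.learning_rate * grad)


def make_fit_op_sketch(loss_fn, optimizer, trainable_variables,
                       grad_summary_fn=None, eps=1e-6):
    """Returns a callable performing one training step.

    Gradients come from central finite differences (an assumption of this
    sketch), so loss_fn is called 2 * len(trainable_variables) + 1 times.
    """
    names = sorted(trainable_variables)

    def fit_op(*args, **kwargs):
        # Loss and auxiliary output at the current variable values.
        loss, other = loss_fn(*args, **kwargs)
        grads = {}
        for name in names:
            var = trainable_variables[name]
            original = var.value
            var.value = original + eps
            loss_plus, _ = loss_fn(*args, **kwargs)
            var.value = original - eps
            loss_minus, _ = loss_fn(*args, **kwargs)
            var.value = original  # restore before the optimizer update
            grads[name] = (loss_plus - loss_minus) / (2 * eps)
        optimizer.apply_gradients(
            [(grads[name], trainable_variables[name]) for name in names])
        if grad_summary_fn is not None:
            return loss, other, grad_summary_fn(grads)
        return loss, other

    return fit_op


# Usage: fit the one-parameter model pred = w * x to the point (2, 6).
w = Variable(0.0)
variables = {"w": w}


def loss_fn(x, y):
    # Squared error; the second result is an auxiliary dict of predictions.
    pred = w.value * x
    return (pred - y) ** 2, {"pred": pred}


fit_op = make_fit_op_sketch(loss_fn, SGD(learning_rate=0.1), variables)
for _ in range(50):
    loss, other = fit_op(2.0, 6.0)
print(round(w.value, 4))  # 3.0
```

Finite differences keep the sketch dependency-free but cost two extra loss evaluations per variable and assume loss_fn has no side effects; TensorFlow would instead compute these gradients by automatic differentiation (for example, with tf.GradientTape).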