tf_agents.utils.eager_utils.create_train_step

Creates a train_step that evaluates the gradients and returns the loss.

tf_agents.utils.eager_utils.create_train_step(
    loss, optimizer, global_step=_USE_GLOBAL_STEP, total_loss_fn=None,
    update_ops=None, variables_to_train=None, transform_grads_fn=None,
    summarize_gradients=False, gate_gradients=tf.compat.v1.train.Optimizer.GATE_OP,
    aggregation_method=None, check_numerics=True
)

Args:

  • loss: A (possibly nested tuple of) Tensor or function representing the loss.
  • optimizer: A tf.Optimizer to use for computing the gradients.
  • global_step: A Tensor representing the global step variable. If left as _USE_GLOBAL_STEP, then tf.train.get_or_create_global_step() is used.
  • total_loss_fn: Function to call on loss value to access the final item to minimize.
  • update_ops: An optional list of updates to execute. If update_ops is None, then the update ops are set to the contents of the tf.GraphKeys.UPDATE_OPS collection. If update_ops is not None, but it doesn't contain all of the update ops in tf.GraphKeys.UPDATE_OPS, a warning will be displayed.
  • variables_to_train: An optional list of variables to train. If None, it defaults to all tf.trainable_variables().
  • transform_grads_fn: A function that takes a single argument, a list of (gradient, variable) tuples, performs any requested gradient transformations, such as gradient clipping or scaling, and returns the updated list (see the sketch after this list).
  • summarize_gradients: Whether or not to add summaries for each gradient.
  • gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
  • aggregation_method: Specifies the method used to combine gradient terms. Valid values are defined in the class AggregationMethod.
  • check_numerics: Whether or not to apply check_numerics.
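
A minimal sketch of a transform_grads_fn that clips gradients by global norm; the helper name and the clip value of 5.0 are illustrative assumptions, not part of the API:

import tensorflow as tf

def clip_grads_fn(grads_and_vars):
    # grads_and_vars is a list of (gradient, variable) tuples.
    grads, variables = zip(*grads_and_vars)
    # Clip the whole gradient list by its global norm (5.0 is an arbitrary example value).
    clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=5.0)
    return list(zip(clipped_grads, variables))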

Returns:

  • In graph mode: A (possibly nested tuple of) Tensor that, when evaluated, calculates the current loss, computes the gradients, applies the optimizer, and returns the current loss.
  • In eager mode: A lambda function that, when called, calculates the loss, then computes and applies the gradients, and returns the original loss values.

Raises:

  • ValueError: if loss is not callable.
  • RuntimeError: if resource variables are not enabled.
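
Example: a minimal graph-mode sketch. The toy model, data, and learning rate below are illustrative assumptions, not part of the API; in graph mode the returned value is a Tensor that performs one optimization step each time it is evaluated, as described in the Returns section above.

import tensorflow as tf
from tf_agents.utils import eager_utils

tf.compat.v1.disable_eager_execution()

# Toy linear regression, for illustration only.
inputs = tf.random.normal([8, 4])
labels = tf.zeros([8, 1])
predictions = tf.compat.v1.layers.dense(inputs, 1)
loss = tf.reduce_mean(tf.square(predictions - labels))

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)

# Evaluating train_step computes the loss, applies the gradients,
# and returns the current loss value.
train_step = eager_utils.create_train_step(loss, optimizer)

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    loss_value = sess.run(train_step)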