tf_agents.utils.eager_utils.create_train_op


Creates an Operation that evaluates the gradients and returns the loss.
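The contract in the sentence above can be modelled without TensorFlow. This is a minimal sketch and not TF-Agents code. The names make_train_op, loss_fn, grad_fn and learning_rate are hypothetical. Parameters live in a plain dict of floats, and gradients come from a caller-supplied function instead of automatic differentiation. A plain SGD step stands in for the optimizer. The sketch only shows the order of operations: evaluate the loss, build (gradient, variable) pairs for the trainable variables, optionally transform them, apply them, and return the loss.

```python
from typing import Callable, Dict, List, Optional, Tuple

Params = Dict[str, float]               # variable name -> current value
GradsAndVars = List[Tuple[float, str]]  # (gradient, variable name) pairs


def make_train_op(
    loss_fn: Callable[[Params], float],
    grad_fn: Callable[[Params], Dict[str, float]],
    params: Params,
    learning_rate: float = 0.1,
    variables_to_train: Optional[List[str]] = None,
    transform_grads_fn: Optional[Callable[[GradsAndVars], GradsAndVars]] = None,
) -> Callable[[], float]:
    """Hypothetical stand-in for a train op: each call runs one update step."""
    # Mirrors the documented default: train every variable if none are given.
    names = list(params) if variables_to_train is None else list(variables_to_train)

    def train_op() -> float:
        total_loss = loss_fn(params)    # loss at the current parameters
        grads = grad_fn(params)         # gradients at the same parameters
        grads_and_vars = [(grads[n], n) for n in names]
        if transform_grads_fn is not None:
            # Hook for clipping, multipliers, etc.
            grads_and_vars = transform_grads_fn(grads_and_vars)
        for g, n in grads_and_vars:     # plain SGD replaces the real optimizer
            params[n] -= learning_rate * g
        return total_loss               # pre-update loss value

    return train_op
```

In this model each call performs one update and returns the loss measured before that update. For loss (w - 3)^2 starting at w = 0, the first call returns 9.0 and moves w toward 3.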

Args:
  total_loss: A Tensor representing the total loss.
  optimizer: A tf.Optimizer to use for computing the gradients.
  global_step: A Tensor representing the global step variable. If left as _USE_GLOBAL_STEP, then tf.train.get_or_create_global_step() is used.
  update_ops: An optional list of updates to execute. If update_ops is None, the update ops are set to the contents of the tf.GraphKeys.UPDATE_OPS collection. If update_ops is not None but does not contain all of the update ops in tf.GraphKeys.UPDATE_OPS, a warning is displayed.
  variables_to_train: An optional list of variables to train. If None, it defaults to all tf.trainable_variables().
  transform_grads_fn: A function that takes a single argument, a list of (gradient, variable) pairs (tuples), performs any requested gradient updates such as gradient clipping or multipliers, and returns the updated list.
  summarize_gradients: Whether or not to add summaries for each gradient.
  gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
  aggregation_method: Specifies the method used to combine gradient terms. Valid values are defined in the class AggregationMethod.
  check_numerics: Whether or not to apply check_numerics.
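A typical transform_grads_fn performs gradient clipping. The sketch below shows global-norm clipping in plain Python so that it runs without TensorFlow. Gradients are modelled as flat lists of floats and variables as names, whereas real code receives tensors. clip_by_global_norm_fn is a hypothetical helper. It reproduces the scaling rule documented for tf.clip_by_global_norm, g * clip_norm / max(global_norm, clip_norm), which is what one would normally call on the gradient list in TensorFlow code.

```python
import math
from typing import Callable, List, Sequence, Tuple

# (gradient, variable) pair: gradient as a flat list of floats, variable as a name.
GradAndVar = Tuple[List[float], str]


def clip_by_global_norm_fn(
    clip_norm: float,
) -> Callable[[Sequence[GradAndVar]], List[GradAndVar]]:
    """Hypothetical factory for a transform_grads_fn; clip_norm must be > 0."""

    def transform(grads_and_vars: Sequence[GradAndVar]) -> List[GradAndVar]:
        # Global norm: L2 norm over every component of every gradient.
        global_norm = math.sqrt(
            sum(x * x for grad, _ in grads_and_vars for x in grad))
        # Scale is 1.0 when the norm is already within the limit.
        scale = clip_norm / max(global_norm, clip_norm)
        # Rescale all gradients jointly; keep each paired variable unchanged.
        return [([x * scale for x in grad], var) for grad, var in grads_and_vars]

    return transform
```

Because all gradients share a single scale factor, the direction of the combined update is preserved; only its length is capped at clip_norm.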

Returns:
  A Tensor that, when evaluated, computes the gradients and returns the total loss value.