Module: tf_agents.utils.eager_utils

Common utilities for TF-Agents.

Example of usage:

import tensorflow as tf
from tf_agents.utils import eager_utils

@eager_utils.future_in_eager_mode
def loss_fn(x, y):
  v = tf.compat.v1.get_variable('v', initializer=tf.ones_initializer(), shape=())
  return v + x - y

with tfe.graph_mode():
  # loss_op and train_step_op are Tensors/Ops in the graph.
  loss_op = loss_fn(inputs, labels)
  train_step_op = eager_utils.create_train_step(loss_op, optimizer)
  # Compute the loss and apply gradients to the variables using the optimizer.
  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    for _ in range(num_train_steps):
      loss_value = sess.run(train_step_op)

with tfe.eager_mode():
  # loss and train_step are callables; nothing is computed until they are called.
  loss = loss_fn(inputs, labels)
  train_step = eager_utils.create_train_step(loss, optimizer)
  # Compute the loss and apply gradients to the variables using the optimizer.
  for _ in range(num_train_steps):
    loss_value = train_step()

Classes

class Future: Converts a function or class method call into a future callable.
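
To make the idea of a "future callable" concrete, here is a minimal pure-Python sketch. It is not the library's implementation: `FutureSketch` and `future_sketch` are hypothetical names, and the sketch assumes only that a call is recorded when it is made and executed later, when the returned object is invoked.

```python
import functools


class FutureSketch:
    """Hypothetical stand-in: records a call now and runs it when invoked."""

    def __init__(self, func, *args, **kwargs):
        self._func = func
        self._args = args
        self._kwargs = kwargs

    def __call__(self):
        # The recorded call is executed only at this point.
        return self._func(*self._args, **self._kwargs)


def future_sketch(func):
    """Hypothetical decorator: calling func returns a deferred call instead."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return FutureSketch(func, *args, **kwargs)

    return wrapper


@future_sketch
def loss_fn(x, y):
    return x - y


loss = loss_fn(3.0, 1.0)  # nothing is computed yet; loss is a callable
loss_value = loss()       # the recorded call runs here
```

This mirrors the eager-mode branch of the usage example, where `loss` is something that can be called rather than a value.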

Functions

add_gradients_summaries(...): Add summaries to gradients.

add_variables_summaries(...): Add summaries for variables.

clip_gradient_norms(...): Clips the gradient norms to the given maximum value.

clip_gradient_norms_fn(...): Returns a transform_grads_fn function for gradient clipping.

create_train_op(...): Creates an Operation that evaluates the gradients and returns the loss.

create_train_step(...): Creates a train_step that evaluates the gradients and returns the loss.

dataset_iterator(...): Constructs a Dataset iterator.

future_in_eager_mode(...): Decorator that allows a function/method to run in both graph and eager modes.

get_next(...): Returns the next element in a Dataset iterator.

has_self_cls_arg(...): Checks whether it is a method that takes self/cls as its first argument.

is_unbound(...): Checks whether it is an unbound method.

np_function(...): Decorator that allows a NumPy function to be used in eager and graph modes.
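
The two gradient-clipping entries above (`clip_gradient_norms` and `clip_gradient_norms_fn`) pair a clipping routine with a factory that returns it as a gradient transform. The sketch below illustrates that pattern in pure Python. It rests on assumptions not stated in this page: that each gradient is clipped independently by its L2 norm, and that gradients arrive as (gradient, variable) pairs. All function names ending in `_sketch` are hypothetical, and plain lists of floats stand in for tensors.

```python
import math


def clip_by_norm_sketch(grad, max_norm):
    """Scale a gradient (list of floats) so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)
    scale = max_norm / norm
    return [g * scale for g in grad]


def clip_gradient_norms_sketch(grads_and_vars, max_norm):
    """Clip each gradient independently; None gradients pass through."""
    clipped = []
    for grad, var in grads_and_vars:
        if grad is not None:
            grad = clip_by_norm_sketch(grad, max_norm)
        clipped.append((grad, var))
    return clipped


def clip_gradient_norms_fn_sketch(max_norm):
    """Return a transform that applies the clipping above with max_norm fixed."""

    def transform(grads_and_vars):
        return clip_gradient_norms_sketch(grads_and_vars, max_norm)

    return transform


# Hypothetical gradients: the first exceeds the limit, the second does not.
grads_and_vars = [([3.0, 4.0], "w"), ([0.3, 0.4], "b"), (None, "unused")]
transform = clip_gradient_norms_fn_sketch(max_norm=1.0)
clipped = transform(grads_and_vars)
```

Here the first gradient has norm 5 and is rescaled to norm 1, the second is already within the limit and is left unchanged, and the `None` gradient is passed through untouched.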