tf_agents.metrics.py_metrics.AverageReturnMetric

Computes the average undiscounted return (the sum of rewards in an episode) over completed episodes.

Inherits From: StreamingMetric

tf_agents.metrics.py_metrics.AverageReturnMetric(
    *args, **kwargs
)
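Written out, the reported quantity can be sketched as follows. This assumes N is the number of completed episodes currently held in the metric's buffer and T_i is the number of rewards collected in episode i; "undiscounted" means no discount factor γ multiplies the rewards.

```latex
G_i = \sum_{t=1}^{T_i} r_t^{(i)}, \qquad
\mathrm{AverageReturn} = \frac{1}{N} \sum_{i=1}^{N} G_i
```

For example, two finished episodes with rewards (1, 1, 1) and (2, 0) give G_1 = 3 and G_2 = 2, so the average return is 2.5.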

Attributes:

  • name: Returns the name of this module as passed or determined in the constructor.

    NOTE: This is not the same as the self.name_scope.name which includes parent module names.

  • name_scope: Returns a tf.name_scope instance for this class.

  • prefix: Prefix for the metric.

  • submodules: Sequence of all sub-modules.

    Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

import tensorflow as tf

a = tf.Module()
b = tf.Module()
c = tf.Module()
a.b = b
b.c = c
assert list(a.submodules) == [b, c]
assert list(b.submodules) == [c]
assert list(c.submodules) == []

  • summary_op: TF summary op for this metric.
  • summary_placeholder: TF placeholder to be used for the result of this metric.
  • trainable_variables: Sequence of trainable variables owned by this module and its submodules.

  • variables: Sequence of variables owned by this module and its submodules.

Methods

__call__

__call__(
    *args
)

Method to update the metric contents.

To change the behavior of this function, override the call method.

Different subclasses may use this differently. For instance, PyStepMetric takes a trajectory, while CounterMetric takes no parameters.

Args:

  • *args: See call method of subclass for specific arguments.
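The dispatch described above, where `__call__` forwards its arguments to an overridable `call`, can be pictured with a small pure-Python sketch. The classes `MetricSketch`, `CounterSketch`, and `RewardSumSketch` are hypothetical illustrations, not TF-Agents classes; `RewardSumSketch` takes a single reward as a stand-in for a trajectory.

```python
# Hypothetical sketch of the __call__ -> call dispatch pattern;
# these classes illustrate the idea and are not TF-Agents code.


class MetricSketch:
    """Base class: __call__ forwards every positional argument to call()."""

    def __call__(self, *args):
        self.call(*args)

    def call(self, *args):
        raise NotImplementedError("Subclasses override call().")


class CounterSketch(MetricSketch):
    """Takes no arguments, like the CounterMetric described above."""

    def __init__(self):
        self.count = 0

    def call(self):
        self.count += 1


class RewardSumSketch(MetricSketch):
    """Takes one reward per call (a stand-in for a trajectory)."""

    def __init__(self):
        self.total = 0.0

    def call(self, reward):
        self.total += reward


counter = CounterSketch()
counter()
counter()

summer = RewardSumSketch()
summer(1.5)
summer(2.0)

print(counter.count, summer.total)  # 2 3.5
```

Because callers always go through `__call__`, a subclass only needs to override `call` to change what an update does.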

add_to_buffer

add_to_buffer(
    values
)

Appends new values to the buffer.
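A minimal sketch of a bounded buffer, assuming it behaves as a fixed-capacity FIFO that keeps only the most recent values and that the metric's result is their mean. `BufferSketch` and its `capacity` parameter are hypothetical; `collections.deque(maxlen=...)` stands in for whatever storage the library actually uses, and returning 0.0 for an empty buffer is also an assumption.

```python
from collections import deque
from statistics import mean


class BufferSketch:
    """Hypothetical bounded buffer: once full, the oldest values are evicted."""

    def __init__(self, capacity=3):
        self._buffer = deque(maxlen=capacity)

    def add_to_buffer(self, values):
        # Append each new value; deque(maxlen=...) drops from the left when full.
        self._buffer.extend(values)

    def contents(self):
        return list(self._buffer)

    def result(self):
        # Mean of the stored values; 0.0 when nothing is stored (assumption).
        return mean(self._buffer) if self._buffer else 0.0


buf = BufferSketch(capacity=3)
buf.add_to_buffer([10.0, 20.0])
buf.add_to_buffer([30.0, 40.0])  # 10.0 is evicted
print(buf.contents(), buf.result())  # [20.0, 30.0, 40.0] 30.0
```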

aggregate

@staticmethod
aggregate(
    metrics
)

Aggregates a list of metrics.

The default behaviour is to return the average of the metrics.

Args:

  • metrics: A list of metrics of the same class.

Returns:

The result of aggregating this metric.
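The default aggregation can be pictured as averaging each metric's `result()`. This is a minimal sketch: `FixedResult` is a hypothetical stand-in for a metric whose result is already known, and the library may compute the mean differently (for example with NumPy).

```python
from statistics import mean


class FixedResult:
    """Hypothetical stand-in for a metric whose result() is already known."""

    def __init__(self, value):
        self._value = value

    def result(self):
        return self._value


def aggregate(metrics):
    # Default behaviour described above: average the metrics' results.
    return mean(m.result() for m in metrics)


# E.g. the same metric tracked separately by three parallel workers.
workers = [FixedResult(2.0), FixedResult(4.0), FixedResult(9.0)]
print(aggregate(workers))  # 5.0
```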

call

call(
    trajectory
)

Processes a trajectory to update the metric.

Args:

  • trajectory: A trajectory.Trajectory.
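One way a batched `call` could update per-environment running returns is sketched below. It assumes each step reports, per batch entry, a reward plus flags marking the first and last step of an episode. `StepBatch` and `AverageReturnSketch` are hypothetical stand-ins: the real `trajectory.Trajectory` encodes step types rather than plain booleans, the library uses NumPy arrays, and the default of 10 buffered episodes here is an assumption.

```python
from collections import deque
from dataclasses import dataclass
from statistics import mean
from typing import List


@dataclass
class StepBatch:
    """Hypothetical stand-in for one batched trajectory step."""
    reward: List[float]
    is_first: List[bool]
    is_last: List[bool]


class AverageReturnSketch:
    def __init__(self, batch_size, buffer_size=10):
        self.episode_return = [0.0] * batch_size  # running return per environment
        self.buffer = deque(maxlen=buffer_size)  # returns of finished episodes

    def call(self, step):
        for i, reward in enumerate(step.reward):
            if step.is_first[i]:
                self.episode_return[i] = 0.0  # new episode: restart the sum
            self.episode_return[i] += reward  # undiscounted: no gamma factor
            if step.is_last[i]:
                self.buffer.append(self.episode_return[i])

    def result(self):
        return mean(self.buffer) if self.buffer else 0.0


steps = [
    StepBatch(reward=[1.0, 2.0], is_first=[True, True], is_last=[False, False]),
    StepBatch(reward=[1.0, 0.0], is_first=[False, False], is_last=[False, True]),
    StepBatch(reward=[1.0, 5.0], is_first=[False, True], is_last=[True, False]),
]
metric = AverageReturnSketch(batch_size=2)
for step in steps:
    metric.call(step)

print(list(metric.buffer), metric.result())  # [2.0, 3.0] 2.5
```

Each environment keeps its own running sum, so episodes of different lengths that finish at different steps are recorded independently.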

log

log()

reset

reset()

Resets internal stat gathering variables used to compute the metric.

result

result()

Returns the value of this metric.

tf_summaries

tf_summaries(
    train_step=None, step_metrics=()
)

Build TF summary op and placeholder for this metric.

To execute the op, call py_metric.run_summaries.

Args:

  • train_step: Step counter for training iterations. If None, no metric is generated against the global step.
  • step_metrics: Step values to plot on the x-axis, in addition to global_step.

Returns:

The summary op.

Raises:

  • RuntimeError: If this method has already been called (it can only be called once).
  • ValueError: If any item in step_metrics is not of type PyMetric or tf_metric.TFStepMetric.

with_name_scope

@classmethod
with_name_scope(
    cls, method
)

Decorator to automatically enter the module name scope.

import tensorflow as tf

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 64]))
    return tf.matmul(x, self.w)

Using the above module produces tf.Variables and tf.Tensors whose names include the module name:

mod = MyModule()
mod(tf.ones([8, 32]))
# ==> <tf.Tensor: ...>
mod.w
# ==> <tf.Variable ...'my_module/w:0'>

Args:

  • method: The method to wrap.

Returns:

The original method wrapped such that it enters the module's name scope.