tf_agents.utils.tensor_normalizer.EMATensorNormalizer

TensorNormalizer with exponential moving average mean and variance estimates.

Inherits From: TensorNormalizer

tf_agents.utils.tensor_normalizer.EMATensorNormalizer(
    tensor_spec, scope='normalize_tensor', norm_update_rate=0.001
)
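
A minimal usage sketch (the spec shape, batch size, and values below are illustrative assumptions, not part of the API): build the normalizer from a tf.TensorSpec, feed batches through update so the exponential moving averages track the data, then normalize with the running statistics.

import tensorflow as tf
from tf_agents.utils import tensor_normalizer

# Spec for a single (unbatched) observation; the shape is an illustrative choice.
obs_spec = tf.TensorSpec(shape=(4,), dtype=tf.float32)

# norm_update_rate controls how quickly the EMA mean/variance track new batches.
normalizer = tensor_normalizer.EMATensorNormalizer(
    obs_spec, norm_update_rate=0.001)

batch = tf.random.normal([32, 4])          # [batch, obs_dim]
normalizer.update(batch)                    # refresh the EMA mean/variance
normalized = normalizer.normalize(batch)    # scale using the running statistics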

Attributes:

  • name: Returns the name of this module as passed or determined in the constructor.

    NOTE: This is not the same as self.name_scope.name, which includes parent module names.

  • name_scope: Returns a tf.name_scope instance for this class.

  • nested: True if tensor is nested, False otherwise.

  • submodules: Sequence of all sub-modules.

    Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

    a = tf.Module()
    b = tf.Module()
    c = tf.Module()
    a.b = b
    b.c = c
    assert list(a.submodules) == [b, c]
    assert list(b.submodules) == [c]
    assert list(c.submodules) == []

  • trainable_variables: Sequence of trainable variables owned by this module and its submodules.

  • variables: Returns a tuple of tf variables owned by this EMATensorNormalizer.

Methods

copy

copy(
    scope=None
)

Copy constructor for EMATensorNormalizer.
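
A hedged sketch, continuing the constructor example above; the scope name here is an illustrative assumption. copy builds a fresh EMATensorNormalizer over the same tensor_spec, optionally under a different variable scope.

# New normalizer over the same tensor_spec, placed under its own scope.
normalizer_copy = normalizer.copy(scope='normalize_tensor_copy')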

normalize

normalize(
    tensor, clip_value=5.0, center_mean=True, variance_epsilon=0.001
)

Applies normalization to tensor.

Args:

  • tensor: Tensor to normalize.
  • clip_value: Clips normalized observations between +/- this value if clip_value > 0, otherwise does not apply clipping.
  • center_mean: If true, subtracts the mean from the tensor during normalization.
  • variance_epsilon: Epsilon to avoid division by zero in normalization.

Returns:

  • normalized_tensor: Tensor after applying normalization.
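
A hedged sketch of the keyword arguments, continuing the constructor example above (the values are illustrative): scale-only normalization with a wider clip range, and clipping disabled entirely.

# Keep the mean (scale only) and clip to +/- 10 standard deviations.
scaled = normalizer.normalize(batch, clip_value=10.0, center_mean=False)

# A non-positive clip_value disables clipping altogether.
unclipped = normalizer.normalize(batch, clip_value=0.0)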

update

update(
    tensor, outer_dims=(0,)
)

Updates tensor normalizer variables.
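
A hedged sketch of outer_dims, continuing the constructor example above (the trajectory shape is illustrative): when the tensor has both batch and time dimensions, list both axes so the moving statistics are reduced over them.

# [batch, time, obs_dim] trajectory; average over axes 0 and 1
# when updating the EMA mean and variance.
trajectory = tf.random.normal([32, 10, 4])
normalizer.update(trajectory, outer_dims=(0, 1))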

with_name_scope

@classmethod
with_name_scope(
    cls, method
)

Decorator to automatically enter the module name scope.

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 64]))
    return tf.matmul(x, self.w)

Using the above module would produce tf.Variables and tf.Tensors whose names include the module name:

mod = MyModule()
mod(tf.ones([8, 32]))
# ==> <tf.Tensor: ...>
mod.w
# ==> <tf.Variable ...'my_module/w:0'>

Args:

  • method: The method to wrap.

Returns:

The original method wrapped such that it enters the module's name scope.