Module: tf_agents.utils.tensor_normalizer


Tensor normalizer classes.

These classes encapsulate the variables and functions used for tensor normalization.
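Conceptually, normalizing a value means subtracting an estimated mean, dividing by an estimated standard deviation, and clipping the result. The sketch below shows that arithmetic in plain Python for a single scalar. It is an illustration rather than TF-Agents' implementation; the function name `normalize_value` is hypothetical, and the `epsilon` and `clip_value` defaults are assumptions chosen for the example.

```python
import math


def normalize_value(x, mean, variance, epsilon=1e-3, clip_value=5.0):
    """Hypothetical helper: (x - mean) / sqrt(variance + epsilon), clipped.

    epsilon keeps the denominator away from zero when the variance is ~0.
    The default epsilon and clip_value are illustrative assumptions.
    """
    z = (x - mean) / math.sqrt(variance + epsilon)
    # Clip to [-clip_value, clip_value] so outliers cannot produce huge inputs.
    return max(-clip_value, min(clip_value, z))


# A value about two standard deviations above the mean (std ~2.0).
print(round(normalize_value(14.0, mean=10.0, variance=4.0), 3))
# A far outlier is clipped to clip_value.
print(normalize_value(1000.0, mean=0.0, variance=1.0))
```

Clipping bounds the output even when the statistics are still poor, for example early in training when few values have been seen.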

Example usage:

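    import tensorflow as tf
    from tf_agents.specs import tensor_spec
    from tf_agents.utils.tensor_normalizer import StreamingTensorNormalizer

    # Placeholders and sessions require graph mode under TF2.
    tf.compat.v1.disable_eager_execution()

    observation = tf.compat.v1.placeholder(tf.float32, shape=[])
    tensor_normalizer = StreamingTensorNormalizer(
        tensor_spec.TensorSpec([], tf.float32), scope='normalize_observation')
    normalized_observation = tensor_normalizer.normalize(observation)
    update_normalization = tensor_normalizer.update(observation)

    with tf.compat.v1.Session() as sess:
        # Initialize the normalizer's variables before using them.
        sess.run(tf.compat.v1.global_variables_initializer())
        # observation_list: an iterable of scalar observations defined elsewhere.
        for o in observation_list:
            # Compute normalized observation given current observation vars.
            normalized_observation_ = sess.run(
                normalized_observation, feed_dict={observation: o})

            # Update normalization params for next normalization op.
            sess.run(update_normalization, feed_dict={observation: o})

            # Do something with normalized_observation_
            ...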
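The example normalizes each observation with the statistics gathered so far and only then folds that observation into them. A plain-Python sketch of the same ordering for scalars follows. It tracks full-history mean and variance with Welford's online algorithm, which is a swapped-in technique for illustration; StreamingTensorNormalizer may accumulate its statistics differently. The class `StreamingScalarStats` and its methods are hypothetical.

```python
import math


class StreamingScalarStats:
    """Hypothetical full-history mean/variance tracker (Welford's algorithm)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # Sum of squared deviations from the running mean.

    @property
    def variance(self):
        # Population variance of every value seen so far (0.0 before any update).
        return self._m2 / self.count if self.count else 0.0

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self._m2 += delta * (x - self.mean)

    def normalize(self, x, epsilon=1e-3, clip_value=5.0):
        # epsilon and clip_value defaults are illustrative assumptions.
        z = (x - self.mean) / math.sqrt(self.variance + epsilon)
        return max(-clip_value, min(clip_value, z))


stats = StreamingScalarStats()
normalized = []
for o in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
    normalized.append(stats.normalize(o))  # Uses statistics from earlier values only.
    stats.update(o)                         # Then adds o to the full history.

print(stats.mean, stats.variance)
```

Normalizing before updating means an observation is never scaled by statistics that already include it, which matches the order of the two `sess.run` calls in the example.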

Classes

class EMATensorNormalizer: TensorNormalizer with exponential moving average (EMA) estimates of mean and variance.

class StreamingTensorNormalizer: Normalizes using mean and variance estimated from the full history of tensor values.

class TensorNormalizer: Encapsulates tensor normalization and owns normalization variables.
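EMATensorNormalizer differs from the streaming variant in weighting recent values more heavily instead of treating the whole history equally. The scalar sketch below uses a standard incremental EMA mean/variance update to illustrate the idea. The class `EMAScalarStats`, the `update_rate` parameter, its initial mean and variance, and the exact variance formula are all assumptions for illustration, not the library's definitions.

```python
import math


class EMAScalarStats:
    """Hypothetical scalar normalizer with exponential-moving-average statistics.

    update_rate, initial_mean, and initial_variance are illustrative
    assumptions, not TF-Agents defaults.
    """

    def __init__(self, update_rate=0.1, initial_mean=0.0, initial_variance=1.0):
        self.update_rate = update_rate
        self.mean = initial_mean
        self.variance = initial_variance

    def update(self, x):
        a = self.update_rate
        delta = x - self.mean
        # The newest value gets weight a; older history decays by (1 - a) per step.
        self.mean += a * delta
        self.variance = (1.0 - a) * (self.variance + a * delta * delta)

    def normalize(self, x, epsilon=1e-3, clip_value=5.0):
        z = (x - self.mean) / math.sqrt(self.variance + epsilon)
        return max(-clip_value, min(clip_value, z))


ema = EMAScalarStats(update_rate=0.5)
ema.update(4.0)
print(ema.mean, ema.variance)
```

Because old values fade at a rate set by `update_rate`, EMA statistics can follow observations whose distribution drifts over time, at the cost of noisier estimates than a full-history average.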