
tf_agents.networks.utils.BatchSquash


Facilitates flattening and unflattening batch dims of a tensor.

tf_agents.networks.utils.BatchSquash(
    batch_dims
)


Exposes a matched pair of flatten and unflatten methods. After flattening, only one batch dimension is left. This makes it possible to evaluate networks that expect inputs with exactly one batch dimension.
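The flatten/unflatten contract can be illustrated with a minimal, dependency-free sketch. `ToyTensor` and `ToyBatchSquash` are hypothetical names invented for this example; they are not part of TF-Agents, and the real class operates on TensorFlow tensors. The sketch assumes a row-major reshape, so merging the leading batch dimensions changes only the shape, never the element order.

```python
from dataclasses import dataclass
from math import prod
from typing import List, Optional, Tuple


@dataclass
class ToyTensor:
    """Hypothetical stand-in for a tensor: row-major flat data plus a shape."""
    shape: Tuple[int, ...]
    data: List[float]

    def reshape(self, new_shape: Tuple[int, ...]) -> "ToyTensor":
        # A row-major reshape keeps the element order and only changes the shape.
        if prod(new_shape) != len(self.data):
            raise ValueError(f"cannot reshape {self.shape} to {new_shape}")
        return ToyTensor(tuple(new_shape), self.data)


class ToyBatchSquash:
    """Hypothetical sketch of the flatten/unflatten contract (not the TF-Agents code)."""

    def __init__(self, batch_dims: int):
        if batch_dims < 0:
            raise ValueError("batch_dims must be non-negative")
        self._batch_dims = batch_dims
        self._batch_shape: Optional[Tuple[int, ...]] = None

    def flatten(self, tensor: ToyTensor) -> ToyTensor:
        if len(tensor.shape) < self._batch_dims:
            raise ValueError("tensor has fewer dimensions than batch_dims")
        # Cache the leading batch dimensions, then merge them into one.
        self._batch_shape = tensor.shape[: self._batch_dims]
        inner = tensor.shape[self._batch_dims:]
        return tensor.reshape((prod(self._batch_shape),) + inner)

    def unflatten(self, tensor: ToyTensor) -> ToyTensor:
        # Split the single leading dimension back into the cached batch shape;
        # the trailing dimensions come from the tensor being unflattened.
        if self._batch_shape is None:
            raise ValueError("call flatten before unflatten")
        return tensor.reshape(self._batch_shape + tensor.shape[1:])


# Usage: two batch dimensions (4, 3) in front of a 5-element feature axis.
x = ToyTensor((4, 3, 5), list(range(60)))
squash = ToyBatchSquash(batch_dims=2)
flat = squash.flatten(x)
print(flat.shape)  # (12, 5)
# Pretend a network mapped each 5-element row to 2 outputs (hypothetical).
y = ToyTensor((12, 2), [0.0] * 24)
print(squash.unflatten(y).shape)  # (4, 3, 2)
```

With the real class, the corresponding calls are `BatchSquash(batch_dims)`, then `flatten(tensor)` before the network and `unflatten(tensor)` after it, as documented on this page.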

Args:

  • batch_dims: Number of batch dimensions the flatten/unflatten ops should handle.

Raises:

  • ValueError: if batch_dims is negative.

Methods

flatten


flatten(
    tensor
)

Flattens the tensor's leading batch_dims dimensions into a single batch dimension and caches the original batch shape.
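Stated as a shape rule (a sketch assuming k = batch_dims and a row-major reshape, with $b_i$ the batch sizes and $d_j$ the remaining dimensions), flatten maps

$$(b_1, \dots, b_k, d_1, \dots, d_m) \;\longmapsto\; \Big(\prod_{i=1}^{k} b_i,\; d_1, \dots, d_m\Big)$$

For k = 1 the product is just $b_1$, so the shape is unchanged.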

unflatten


unflatten(
    tensor
)

Unflattens the tensor's single leading batch dimension back into batch_dims dimensions, using the shape cached by the most recent call to flatten.
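The inverse shape rule can be sketched the same way, assuming the cached batch shape is $(b_1, \dots, b_k)$ and that whatever ran between the two calls left the leading dimension equal to $N = \prod_{i=1}^{k} b_i$. The trailing dimensions $e_1, \dots, e_p$ may differ from those seen by flatten, for example after a layer changes the feature size:

$$\Big(N,\; e_1, \dots, e_p\Big) \;\longmapsto\; (b_1, \dots, b_k, e_1, \dots, e_p), \qquad N = \prod_{i=1}^{k} b_i$$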