Facilitates flattening and unflattening batch dims of a tensor.
tf_agents.networks.utils.BatchSquash( batch_dims )
Exposes a pair of matched flatten and unflatten methods. After flattening, only one batch dimension remains. This makes it possible to evaluate networks that expect inputs to have exactly one batch dimension.
Args:
batch_dims: Number of batch dimensions the flatten/unflatten ops should handle.
Raises:
ValueError: if batch_dims is negative.
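The flattening described above can be illustrated with a minimal NumPy sketch. It does not use `BatchSquash` itself. The `dense_layer` function, its weights, and the `[batch, time, features]` input are hypothetical stand-ins for a network that only accepts one batch dimension. The leading `batch_dims` dimensions are merged into one, the "network" runs, and the original batch shape is restored while the network's own output dimensions are kept.

```python
import numpy as np


# Hypothetical stand-in for a network that only accepts [batch, features].
def dense_layer(x: np.ndarray, weights: np.ndarray) -> np.ndarray:
    if x.ndim != 2:
        raise ValueError("dense_layer expects exactly one batch dimension")
    return x @ weights


rng = np.random.default_rng(0)
weights = rng.standard_normal((4, 3))     # 4 input features -> 3 outputs

# Input with two batch dimensions: [batch=2, time=5, features=4].
inputs = rng.standard_normal((2, 5, 4))
batch_dims = 2

# Flatten: remember the leading batch dims, then merge them into one (2*5=10).
batch_shape = inputs.shape[:batch_dims]
flat = inputs.reshape((-1,) + inputs.shape[batch_dims:])      # shape (10, 4)

out_flat = dense_layer(flat, weights)                         # shape (10, 3)

# Unflatten: restore the cached batch dims; keep the output's trailing dims.
outputs = out_flat.reshape(batch_shape + out_flat.shape[1:])  # shape (2, 5, 3)
```

Because the reshape is row-major, row `i * 5 + j` of `flat` corresponds to `inputs[i, j]`, so `outputs[i, j]` is the layer applied to that single element.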
flatten( tensor )
Flattens the tensor's leading batch_dims dimensions into a single batch dimension and caches their original shape.
unflatten( tensor )
Unflattens the tensor's single batch dimension back into batch_dims dimensions, using the shape cached by the preceding flatten call.
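The flatten/unflatten contract can be mirrored in a small NumPy class, shown here as a sketch under stated assumptions. `NumpyBatchSquash` is a hypothetical name, and this is not the TF-Agents implementation. It assumes that `unflatten` takes the trailing dimensions from the tensor it receives rather than from the cached input, so a network is free to change the feature size between the two calls.

```python
import numpy as np


class NumpyBatchSquash:
    """Hypothetical NumPy mirror of the flatten/unflatten contract."""

    def __init__(self, batch_dims: int):
        if batch_dims < 0:
            raise ValueError("batch_dims must be non-negative.")
        self._batch_dims = batch_dims
        self._batch_shape = None  # Set by flatten(), read by unflatten().

    def flatten(self, tensor: np.ndarray) -> np.ndarray:
        # Cache the leading batch dims, then merge them into one dimension.
        self._batch_shape = tensor.shape[: self._batch_dims]
        return tensor.reshape((-1,) + tensor.shape[self._batch_dims:])

    def unflatten(self, tensor: np.ndarray) -> np.ndarray:
        if self._batch_shape is None:
            raise ValueError("flatten() must be called before unflatten().")
        # Split the single batch dim back into the cached batch shape; the
        # trailing dims come from `tensor` (assumption: they may differ
        # from the dims that were flattened).
        return tensor.reshape(self._batch_shape + tensor.shape[1:])


# Usage: two batch dims [batch=2, time=5], 4 features in, 3 features out.
squash = NumpyBatchSquash(batch_dims=2)
x = np.zeros((2, 5, 4))
flat = squash.flatten(x)            # shape (10, 4)
y = squash.unflatten(flat[:, :3])   # shape (2, 5, 3)
```

Keeping the cached shape on the instance is what makes the two methods a matched pair. Each `unflatten` call relies on the most recent `flatten` call on the same object.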