Facilitates flattening and unflattening batch dims of a tensor.
tf_agents.networks.utils.BatchSquash( batch_dims )
Exposes a pair of matched flatten and unflatten methods. After flattening, only one batch dimension remains. This makes it possible to evaluate networks that expect their inputs to have exactly one batch dimension.
Args:
batch_dims: Number of batch dimensions the flatten/unflatten ops should handle.
Raises:
ValueError: if batch_dims is negative.
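To illustrate the flatten, evaluate, unflatten pattern described above, here is a minimal NumPy sketch. It is not the TF-Agents implementation, which operates on TensorFlow tensors; the class `NumpyBatchSquash` and the function `toy_network` are hypothetical names used only for this example. It assumes that flatten merges the leading `batch_dims` axes into one and that unflatten restores them from the shape cached during flatten.

```python
import numpy as np


class NumpyBatchSquash:
    """Hypothetical NumPy sketch of the flatten/unflatten pattern (not TF-Agents code)."""

    def __init__(self, batch_dims: int):
        if batch_dims < 0:
            raise ValueError("batch_dims must be non-negative")
        self._batch_dims = batch_dims
        self._original_shape = None  # set by flatten()

    def flatten(self, x: np.ndarray) -> np.ndarray:
        # Cache the full input shape so unflatten() can restore the batch dims.
        self._original_shape = x.shape
        # Merge the leading batch_dims axes into one axis whose size (-1) is inferred.
        return x.reshape((-1,) + x.shape[self._batch_dims:])

    def unflatten(self, y: np.ndarray) -> np.ndarray:
        if self._original_shape is None:
            raise ValueError("call flatten() before unflatten()")
        # Batch dims come from the cache; y keeps its own non-batch dims.
        return y.reshape(self._original_shape[: self._batch_dims] + y.shape[1:])


def toy_network(x: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for a network that accepts only [B, features] inputs."""
    assert x.ndim == 2, "expects exactly one batch dimension"
    return x @ weights


# Input with two batch dims: [time=3, batch=4, features=5].
obs = np.arange(3 * 4 * 5, dtype=np.float32).reshape(3, 4, 5)
weights = np.ones((5, 2), dtype=np.float32)

squash = NumpyBatchSquash(batch_dims=2)
flat_obs = squash.flatten(obs)             # (12, 5): one batch dim for the network
flat_out = toy_network(flat_obs, weights)  # (12, 2)
out = squash.unflatten(flat_out)           # (3, 4, 2): original batch dims restored
print(flat_obs.shape, out.shape)           # prints "(12, 5) (3, 4, 2)"
```

With the real class the call sequence has the same form: construct `tf_agents.networks.utils.BatchSquash(batch_dims)`, call `flatten(tensor)` on the input, run the network, then call `unflatten(tensor)` on its output with the same instance, since unflatten relies on the shape cached by the preceding flatten.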
flatten( tensor )
Flattens and caches the tensor's batch_dims.
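The shape arithmetic of flattening can be worked through with a small pure-Python helper. `flattened_shape` is a hypothetical name introduced here, not part of TF-Agents; it assumes the leading `batch_dims` dimensions are merged into a single dimension equal to their product, with the remaining dimensions left unchanged. The original shape is what flatten has to cache, because the product alone does not say how to split the merged dimension back apart.

```python
from typing import Tuple


def flattened_shape(shape: Tuple[int, ...], batch_dims: int) -> Tuple[int, ...]:
    """Hypothetical helper: the shape produced by flattening the leading batch_dims."""
    if batch_dims < 0:
        raise ValueError("batch_dims must be non-negative")
    merged = 1
    for dim in shape[:batch_dims]:
        merged *= dim  # product of all batch dimensions
    # One merged batch dim followed by the untouched inner dims.
    return (merged,) + tuple(shape[batch_dims:])


print(flattened_shape((2, 3, 4, 5), batch_dims=2))  # (6, 4, 5)
print(flattened_shape((2, 3, 4, 5), batch_dims=1))  # (2, 3, 4, 5)
```

With `batch_dims=1` the shape is unchanged, since there is already exactly one batch dimension.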
unflatten( tensor )
Unflattens the tensor's batch_dims using the cached shape.
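The corresponding shape arithmetic for unflattening is sketched below, assuming (as the description suggests) that only the batch dimensions are taken from the cached shape, while the dimensions after the single batch dimension come from the tensor being unflattened. That is what lets a network change the inner dimensions, for example from 4x5 features to 10 outputs. The helper `unflattened_shape` and its consistency check are hypothetical and not part of TF-Agents.

```python
from typing import Tuple


def unflattened_shape(
    cached_shape: Tuple[int, ...],
    flat_shape: Tuple[int, ...],
    batch_dims: int,
) -> Tuple[int, ...]:
    """Hypothetical helper: the shape unflatten restores from the cached input shape.

    cached_shape: full shape of the tensor that was passed to flatten().
    flat_shape:   shape of the tensor being unflattened (one leading batch dim).
    """
    merged = 1
    for dim in cached_shape[:batch_dims]:
        merged *= dim
    # Sanity check added for this sketch: the merged batch dim must still match.
    if flat_shape[0] != merged:
        raise ValueError("leading dim does not match the cached batch dims")
    # Batch dims come from the cache; everything after the single batch dim is kept.
    return tuple(cached_shape[:batch_dims]) + tuple(flat_shape[1:])


# flatten() saw a [2, 3, 4, 5] input with batch_dims=2; the network returned [6, 10].
print(unflattened_shape((2, 3, 4, 5), (6, 10), batch_dims=2))  # (2, 3, 10)
```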