tf_agents.networks.utils.BatchSquash

Facilitates flattening and unflattening batch dims of a tensor.

Exposes a matched pair of flatten and unflatten methods. After flattening, only one batch dimension is left, which makes it possible to evaluate networks that expect inputs with exactly one batch dimension.
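
To make the shape bookkeeping concrete, here is a minimal pure-Python sketch that tracks shapes only. It is not TF-Agents code: the helpers `flattened_shape` and `unflattened_shape` are hypothetical. It assumes the batch dimensions are the leading dimensions of the tensor, and that the wrapped network keeps the single leading batch dimension while it may change the trailing (feature) dimensions.

```python
import math
from typing import Sequence, Tuple


def flattened_shape(shape: Sequence[int], batch_dims: int) -> Tuple[int, ...]:
    """Hypothetical helper: merge the leading `batch_dims` dimensions into one."""
    if batch_dims < 0:
        raise ValueError("batch_dims must be non-negative")
    batch_shape = tuple(shape[:batch_dims])
    # math.prod(()) == 1, so batch_dims=0 yields a new leading axis of size 1.
    return (math.prod(batch_shape),) + tuple(shape[batch_dims:])


def unflattened_shape(original_shape: Sequence[int], batch_dims: int,
                      flat_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: put the cached batch dims back in front of the
    (possibly changed) trailing dims of the network output."""
    return tuple(original_shape[:batch_dims]) + tuple(flat_shape[1:])


# A [batch=4, time=5, features=3] input with 2 batch dims.
flat = flattened_shape((4, 5, 3), batch_dims=2)      # (20, 3)
# Suppose the network maps 3 features to 8 per example: (20, 3) -> (20, 8).
restored = unflattened_shape((4, 5, 3), 2, (20, 8))  # (4, 5, 8)
```

Only the leading axis is merged and later split, so the network never sees more than one batch dimension, and its output keeps whatever feature shape it produces.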

Args

batch_dims: Number of batch dimensions the flatten/unflatten ops should handle.

Raises

ValueError: If batch_dims is negative.

Methods

flatten

Flattens the tensor's leading batch_dims dimensions into a single batch dimension and caches the original shape for use by unflatten.
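
As an illustration of "flatten and cache", here is a minimal NumPy sketch. It is an assumption-based analogue, not the TF-Agents implementation: the function name `flatten_batch` is hypothetical, and unlike the library method, which caches the shape internally, this function returns the batch shape explicitly so the caller can keep it.

```python
import numpy as np


def flatten_batch(x: np.ndarray, batch_dims: int):
    """Hypothetical NumPy analogue of flatten: returns the flattened array
    and the leading (batch) shape needed later to undo the flattening."""
    if batch_dims < 0:
        raise ValueError("batch_dims must be non-negative")
    cached_batch_shape = x.shape[:batch_dims]
    # -1 lets NumPy infer the size of the merged batch axis.
    flat = x.reshape((-1,) + x.shape[batch_dims:])
    return flat, cached_batch_shape


x = np.zeros((4, 5, 3))
flat, cached = flatten_batch(x, batch_dims=2)
print(flat.shape, cached)  # (20, 3) (4, 5)
```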

unflatten

Unflattens the tensor's single leading batch dimension back into the original batch_dims, using the shape cached by flatten.
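
The reverse step can be sketched in NumPy as follows. This is a hypothetical analogue, not library code: `unflatten_batch` is an illustrative name, the cached batch shape is passed in explicitly rather than read from an object, and a matrix multiply stands in for a network that expects one batch dimension.

```python
import numpy as np


def unflatten_batch(flat: np.ndarray, cached_batch_shape: tuple) -> np.ndarray:
    """Hypothetical NumPy analogue of unflatten: splits the single leading
    batch axis of `flat` back into `cached_batch_shape`."""
    return flat.reshape(tuple(cached_batch_shape) + flat.shape[1:])


# Batch shape cached when a (4, 5, 3) input was flattened with batch_dims=2.
cached_batch_shape = (4, 5)
# Stand-in "network": maps 3 features to 8 for each of the 20 flattened rows.
weights = np.ones((3, 8))
net_out = np.ones((20, 3)) @ weights  # shape (20, 8)
restored = unflatten_batch(net_out, cached_batch_shape)
print(restored.shape)  # (4, 5, 8)
```

Because only the leading axis is split, the network may change the trailing dimensions freely; the size of the leading axis must still equal the product of the cached batch shape.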