Reshape tensors to contain only one batch dimension.
tf_agents.utils.nest_utils.flatten_multi_batched_nested_tensors(
tensors, specs
)
For each tensor, it determines how many outer dimensions the tensor has beyond those described by its spec, and reshapes the tensor so that those outer dimensions are collapsed into a single batch dimension.
NOTE: Each tensor's batch dimensions must be the same.
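The flattening rule can be illustrated at the level of shapes alone. This is a minimal sketch, not the tf_agents implementation. `flatten_shape` and `flatten_shapes` are hypothetical helpers. They operate on plain Python shape tuples, assume every dimension is statically known, and compute the merged batch size with `math.prod` rather than reshaping real tensors. The sketch also enforces the NOTE explicitly by rejecting tensors whose batch dimensions differ.

```python
from math import prod
from typing import List, Optional, Sequence, Tuple


def flatten_shape(tensor_shape: Sequence[int],
                  spec_shape: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: merges a shape's outer dims into one batch dim."""
    tensor_shape, spec_shape = list(tensor_shape), list(spec_shape)
    num_outer = len(tensor_shape) - len(spec_shape)
    # The spec's dimensions must appear as the tensor's trailing dimensions.
    if num_outer < 0 or tensor_shape[num_outer:] != spec_shape:
        raise ValueError(
            f'shape {tensor_shape} is incompatible with spec {spec_shape}')
    batch_dims = tensor_shape[:num_outer]
    # Every outer dim is folded into one; prod([]) == 1 for an unbatched shape.
    return [prod(batch_dims)] + spec_shape, batch_dims


def flatten_shapes(tensor_shapes, spec_shapes):
    """Hypothetical helper: applies flatten_shape to parallel lists of shapes."""
    if len(tensor_shapes) != len(spec_shapes):
        raise ValueError('tensors and specs have different lengths')
    flat_shapes: List[List[int]] = []
    batch_dims: Optional[List[int]] = None
    for tensor_shape, spec_shape in zip(tensor_shapes, spec_shapes):
        flat, dims = flatten_shape(tensor_shape, spec_shape)
        # Mirrors the NOTE: all tensors must share the same batch dimensions.
        if batch_dims is not None and dims != batch_dims:
            raise ValueError(f'batch dims {dims} differ from {batch_dims}')
        batch_dims = dims
        flat_shapes.append(flat)
    return flat_shapes, batch_dims if batch_dims is not None else []


# A [4, 3, 7] tensor with spec [7] and a [4, 3] tensor with a scalar spec
# both have batch dims [4, 3], which merge into a single batch of 12.
print(flatten_shapes([(4, 3, 7), (4, 3)], [(7,), ()]))  # ([[12, 7], [12]], [4, 3])
```

Mismatched trailing dimensions or differing batch dimensions raise `ValueError` in this sketch, in line with the Raises entry.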
Args
tensors: Nested list/tuple or dict of batched Tensors or SparseTensors.
specs: Nested list/tuple or dict of TensorSpecs describing the shape of the non-batched Tensors.
Returns
A tuple of two elements:
A nested structure matching tensors, in which each tensor has a single batch dimension.
A list of the original batch dimensions that were flattened.
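Because the batch dimensions are returned alongside the flattened tensors, a caller can undo the flattening later, for instance after running a network over the merged batch. The helper below, `restore_batch_dims`, is hypothetical and not part of tf_agents. It works on plain shape lists and assumes statically known dimensions; with real tensors, the restored shape would be passed to a reshape instead.

```python
from math import prod
from typing import List, Sequence


def restore_batch_dims(flat_shape: Sequence[int],
                       batch_dims: Sequence[int]) -> List[int]:
    """Hypothetical inverse of the flattening, on plain shape lists.

    flat_shape: a shape whose first entry is the merged batch dimension.
    batch_dims: the list of dimensions that were merged into it.
    """
    flat_shape, batch_dims = list(flat_shape), list(batch_dims)
    # The merged dimension must equal the product of the original batch dims.
    if not flat_shape or flat_shape[0] != prod(batch_dims):
        raise ValueError(
            f'leading dim of {flat_shape} does not match batch dims {batch_dims}')
    return batch_dims + flat_shape[1:]


# A [4, 3, 7] tensor flattened against a [7] spec becomes [12, 7] with batch
# dims [4, 3]; a per-example output of shape [12, 5] maps back to [4, 3, 5].
print(restore_batch_dims([12, 7], [4, 3]))  # [4, 3, 7]
print(restore_batch_dims([12, 5], [4, 3]))  # [4, 3, 5]
```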
Raises
ValueError: If the tensors and specs have incompatible dimensions or shapes.