tf_agents.utils.nest_utils.flatten_multi_batched_nested_tensors

Reshape tensors to contain only one batch dimension.

tf_agents.utils.nest_utils.flatten_multi_batched_nested_tensors(
    tensors, specs
)

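For each tensor, the function counts how many leading dimensions it has beyond those described by the corresponding spec. It treats those extra dimensions as batch dimensions and reshapes the tensor so they are collapsed into a single batch dimension, whose size is their product. NOTE: All tensors must have the same batch dimensions.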
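The following pure-Python model is a minimal sketch of the shape arithmetic described above. It works on shape tuples rather than real tensors and accepts only a flat dict of tensors. This is a simplification: the real function accepts arbitrary nests of lists, tuples and dicts. The helper name flatten_batch_shapes is hypothetical and is not part of TF-Agents. It also checks explicitly that every tensor shares the same batch dimensions, which is the requirement the NOTE states.

```python
from math import prod
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def flatten_batch_shapes(
    tensor_shapes: Dict[str, Shape], spec_shapes: Dict[str, Shape]
) -> Tuple[Dict[str, Shape], List[int]]:
    """Hypothetical shape-only model, not the TF-Agents implementation."""
    # Tensors and specs must describe the same (here: flat dict) structure.
    if tensor_shapes.keys() != spec_shapes.keys():
        raise ValueError("tensors and specs have different structures")
    batch_dims = None
    out = {}
    for key in sorted(tensor_shapes):
        shape, spec = tuple(tensor_shapes[key]), tuple(spec_shapes[key])
        # Leading dimensions beyond the spec's rank are batch dimensions.
        n_batch = len(shape) - len(spec)
        if n_batch < 0 or shape[n_batch:] != spec:
            raise ValueError(f"{key}: shape {shape} incompatible with spec {spec}")
        dims = list(shape[:n_batch])
        # Assumption: every tensor must carry identical batch dimensions.
        if batch_dims is None:
            batch_dims = dims
        elif dims != batch_dims:
            raise ValueError(f"{key}: batch dims {dims} differ from {batch_dims}")
        # Collapse all batch dimensions into one of size prod(dims).
        out[key] = (prod(dims),) + spec
    return out, batch_dims or []
```

For example, observations of shape (4, 3, 84, 84) with spec shape (84, 84) and rewards of shape (4, 3) with a scalar spec both flatten to a single batch dimension of size 12. The recovered batch dimensions are [4, 3].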

Args:

  • tensors: Nested list/tuple or dict of batched Tensors or SparseTensors.
  • specs: Nested list/tuple or dict of TensorSpecs, describing the shape of the non-batched Tensors.

Returns:

A pair of values:

  • A nested structure matching tensors, in which each tensor has a single batch dimension.
  • The batch dimensions that were flattened.
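The returned batch dimensions let a caller undo the flattening after processing. The shape-only sketch below illustrates that round trip. The helper unflatten_shape is hypothetical and not a TF-Agents API. With real tensors, one would typically compute the same target shape and pass it to tf.reshape.

```python
from math import prod
from typing import Sequence, Tuple


def unflatten_shape(
    flat_shape: Sequence[int], batch_dims: Sequence[int]
) -> Tuple[int, ...]:
    """Hypothetical helper: restore the original leading batch dimensions."""
    # The single flattened batch dimension must equal the product of the originals.
    if flat_shape[0] != prod(batch_dims):
        raise ValueError(
            f"batch size {flat_shape[0]} does not match batch dims {list(batch_dims)}"
        )
    # Replace the single batch dimension with the original batch dimensions.
    return tuple(batch_dims) + tuple(flat_shape[1:])
```

For example, unflatten_shape((12, 84, 84), [4, 3]) restores (4, 3, 84, 84).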

Raises:

  • ValueError: if the tensors and specs have incompatible dimensions or shapes.