
tf_agents.utils.nest_utils.is_batched_nested_tensors


Compares tensors to specs to determine if all tensors are batched or not.

tf_agents.utils.nest_utils.is_batched_nested_tensors(
    tensors, specs, num_outer_dims=1
)

For each tensor, it checks the dimensions and dtype against the corresponding spec.
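To compare "each tensor" with "its spec", the two nests have to be lined up leaf by leaf. The sketch below assumes this works like flattening both structures in the same order and zipping the leaves. It is a simplified stand-in, not the library's code: `structure`, `flatten` and `pair_up` are hypothetical helpers, dict values are visited in sorted-key order (the convention `tf.nest.flatten` uses), and, unlike `tf.nest`, lists and tuples are treated as interchangeable.

```python
def structure(nest):
    """Return a hashable description of a nest's layout, ignoring leaf values."""
    if isinstance(nest, dict):
        return ("dict", tuple((key, structure(nest[key])) for key in sorted(nest)))
    if isinstance(nest, (list, tuple)):
        return ("seq", tuple(structure(item) for item in nest))
    return "leaf"


def flatten(nest):
    """Flatten a nested list/tuple/dict; dict values are visited in sorted-key order."""
    if isinstance(nest, dict):
        return [leaf for key in sorted(nest) for leaf in flatten(nest[key])]
    if isinstance(nest, (list, tuple)):
        return [leaf for item in nest for leaf in flatten(item)]
    return [nest]


def pair_up(tensors, specs):
    """Pair each tensor leaf with the spec leaf at the same position in the nest."""
    if structure(tensors) != structure(specs):
        raise ValueError("tensors and specs do not have the same nested structure")
    return list(zip(flatten(tensors), flatten(specs)))


# Strings stand in for real tensors and specs in this illustration.
tensors = {"obs": "obs_tensor", "reward": "reward_tensor"}
specs = {"reward": "reward_spec", "obs": "obs_spec"}
print(pair_up(tensors, specs))
# [('obs_tensor', 'obs_spec'), ('reward_tensor', 'reward_spec')]
```

Because dict keys are sorted before flattening, the insertion order of keys in `tensors` and `specs` does not affect the pairing.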

Returns True if all tensors are batched and False if all tensors are unbatched.

Raises a ValueError if the shapes are incompatible or if a mix of batched and unbatched tensors is provided.

Raises a TypeError if the tensors' dtypes do not match the specs.
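The all-or-nothing rule described above can be sketched as two small steps: a dtype check over every (tensor, spec) pair, then a reduction of per-leaf verdicts to a single bool. This is an illustrative sketch, not the TF-Agents implementation; `check_dtypes` and `combine_verdicts` are hypothetical names, dtypes are plain strings, and each leaf is assumed to have already been classified as `"batched"` or `"unbatched"`.

```python
def check_dtypes(dtype_pairs):
    """Raise TypeError if any tensor dtype differs from its spec dtype.

    Each item is (tensor_dtype, spec_dtype); dtypes are plain strings here.
    """
    for i, (tensor_dtype, spec_dtype) in enumerate(dtype_pairs):
        if tensor_dtype != spec_dtype:
            raise TypeError(
                f"leaf {i}: dtype {tensor_dtype!r} does not match spec dtype {spec_dtype!r}"
            )


def combine_verdicts(verdicts):
    """Reduce per-leaf verdicts ("batched"/"unbatched") to a single bool.

    Returns True only if every leaf is batched and False only if every leaf is
    unbatched; any mixture raises ValueError.
    """
    if all(v == "unbatched" for v in verdicts):
        return False
    if all(v == "batched" for v in verdicts):
        return True
    raise ValueError("a mix of batched and unbatched tensors was provided")


check_dtypes([("float32", "float32"), ("int32", "int32")])  # passes silently
print(combine_verdicts(["batched", "batched"]))  # True
print(combine_verdicts(["unbatched", "unbatched"]))  # False
```

Note that a single disagreeing leaf is enough to turn the whole call into an error; there is no "mostly batched" result.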

Args:

  • tensors: Nested list/tuple/dict of Tensors.
  • specs: Nested list/tuple/dict of Tensors or CompositeTensors describing the shape of unbatched tensors.
  • num_outer_dims: The integer number of leading dimensions that are considered batch dimensions. Defaults to 1.
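One way to picture `num_outer_dims` is as the number of leading dimensions a batched tensor has in front of the spec's shape. The per-leaf check below is a simplified sketch under that assumption, not the library's code: shapes are plain tuples of ints, `None` marks an unknown dimension, a shape of `None` marks unknown rank, and `compatible` and `classify` are hypothetical helpers that only loosely mimic `tf.TensorShape.is_compatible_with`.

```python
def compatible(shape, spec_shape):
    """True if ranks match and each known dimension agrees (None matches anything)."""
    return len(shape) == len(spec_shape) and all(
        a is None or b is None or a == b for a, b in zip(shape, spec_shape)
    )


def classify(tensor_shape, spec_shape, num_outer_dims=1):
    """Classify one tensor shape against its spec shape.

    Returns "unbatched" if the shapes match directly, or "batched" if the tensor
    has exactly num_outer_dims extra leading dims and its remaining dims match
    the spec. Raises ValueError otherwise.
    """
    if tensor_shape is None or spec_shape is None:
        raise ValueError("shapes with unknown rank (ndims == None) are not supported")
    if compatible(tensor_shape, spec_shape):
        return "unbatched"
    extra = len(tensor_shape) - len(spec_shape)
    if extra > 0 and compatible(tensor_shape[extra:], spec_shape):
        if extra != num_outer_dims:
            raise ValueError(f"expected {num_outer_dims} outer dims, got {extra}")
        return "batched"
    raise ValueError(f"shape {tensor_shape} is not compatible with spec {spec_shape}")


print(classify((4,), (4,)))  # unbatched
print(classify((32, 4), (4,)))  # batched: one outer dim of size 32
print(classify((8, 32, 4), (4,), num_outer_dims=2))  # batched: outer dims (8, 32)
```

With the default `num_outer_dims=1`, the shape `(8, 32, 4)` against a `(4,)` spec would instead raise, because it carries two leading dimensions rather than one.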

Returns:

True if all Tensors are batched and False if all Tensors are unbatched.

Raises:

  • ValueError: If
    1. Any of the tensors or specs have shapes with ndims == None, or
    2. The shapes of the tensors are not compatible with the specs, or
    3. A mix of batched and unbatched tensors is provided, or
    4. The tensors are batched but have an incorrect number of outer dims.
  • TypeError: If the dtypes of the tensors and specs are not compatible.