tf_agents.utils.nest_utils.split_nested_tensors

Splits batched nested tensors along the batch (outer) dimension into a list.

tf_agents.utils.nest_utils.split_nested_tensors(
    tensors, specs, num_or_size_splits
)

Args:

  • tensors: Nested list/tuple or dict of batched Tensors.
  • specs: Nested list/tuple or dict of TensorSpecs, describing the shape of the non-batched Tensors.
  • num_or_size_splits: Same as the argument of tf.split. Either a Python integer giving the number of splits along the batch dimension, or a list of integer Tensors giving the size of each output tensor along the batch dimension. If a scalar, it must evenly divide the batch size (value.shape[axis] in tf.split's terms); otherwise, the sizes must sum to the batch size. For SparseTensor inputs, num_or_size_splits must be the scalar num_split (see the documentation of tf.sparse.split for details).
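The num_or_size_splits semantics can be sketched without TensorFlow. The snippet below is a minimal NumPy sketch, not the TF-Agents implementation: NumPy arrays stand in for Tensors, nests are limited to dicts, lists and tuples (dicts are flattened in sorted key order, as tf.nest does), and split_batched_nest, _flatten and _pack are hypothetical helpers. Because np.split expects cut indices rather than piece sizes, a list of sizes is converted with a cumulative sum.

```python
import numpy as np


def _flatten(nest):
    """Returns the leaves of a dict/list/tuple nest; dict keys are visited in sorted order."""
    if isinstance(nest, dict):
        return [leaf for key in sorted(nest) for leaf in _flatten(nest[key])]
    if isinstance(nest, (list, tuple)):
        return [leaf for item in nest for leaf in _flatten(item)]
    return [nest]


def _pack(structure, leaves):
    """Rebuilds `structure`, consuming values from the front of the `leaves` list."""
    if isinstance(structure, dict):
        return {key: _pack(structure[key], leaves) for key in sorted(structure)}
    if isinstance(structure, (list, tuple)):
        return type(structure)(_pack(item, leaves) for item in structure)
    return leaves.pop(0)


def split_batched_nest(nest, num_or_size_splits):
    """Hypothetical helper: splits every leaf along axis 0 with tf.split-like semantics."""
    leaves = _flatten(nest)
    if isinstance(num_or_size_splits, int):
        # A count must divide the batch size evenly; np.split raises ValueError otherwise.
        sections = num_or_size_splits
    else:
        # tf.split takes piece sizes, while np.split takes the indices at which to cut.
        sizes = list(num_or_size_splits)
        for leaf in leaves:
            if sum(sizes) != leaf.shape[0]:
                raise ValueError("split sizes must sum to the batch size")
        sections = np.cumsum(sizes)[:-1]
    per_leaf = [np.split(leaf, sections, axis=0) for leaf in leaves]
    # zip(*per_leaf) regroups the pieces so output item j holds piece j of every leaf.
    return [_pack(nest, list(pieces)) for pieces in zip(*per_leaf)]


batch = {"obs": np.arange(12).reshape(6, 2), "reward": np.arange(6.0)}
halves = split_batched_nest(batch, 2)          # 2 items; each leaf has leading dim 3
uneven = split_batched_nest(batch, [1, 2, 3])  # 3 items with leading dims 1, 2, 3
```

The zip(*per_leaf) step is what turns one list of pieces per leaf into one nest per split, which is the shape of the documented return value.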

Returns:

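A list of nests with the same structure as tensors, where each list item holds the corresponding split of every tensor along the batch dimension. When num_or_size_splits equals the batch size, each list item corresponds to one batch item.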
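To make the shape of each returned item concrete, here is a minimal NumPy sketch. It assumes the function delegates to tf.split (as the num_or_size_splits description suggests) and uses a flat (observation, action) tuple in place of a general nest. Because tf.split keeps the split axis, splitting a batch of size 3 into 3 pieces leaves each leaf with a leading dimension of size 1; the final squeeze is an extra step shown for illustration, not something the documented function is stated to perform.

```python
import numpy as np

# A batch of 3 transitions stored as an (observation, action) tuple of arrays.
batch = (np.zeros((3, 4)), np.array([0, 1, 2]))

# Split each leaf into as many pieces as there are batch items, as tf.split would
# with num_or_size_splits=3; every piece keeps a leading dimension of size 1.
per_leaf = [np.split(leaf, 3, axis=0) for leaf in batch]
items = [tuple(pieces) for pieces in zip(*per_leaf)]

# Optional extra step (an assumption, not documented behavior): drop the size-1 axis.
unbatched = [tuple(np.squeeze(leaf, axis=0) for leaf in item) for item in items]
```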

Raises:

  • ValueError: if the tensors and specs have incompatible dimensions or shapes.
  • ValueError: if num_or_size_splits is not a scalar and the structure contains SparseTensors.
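The two ValueError cases can be illustrated with a minimal sketch of the corresponding checks. It assumes NumPy arrays for dense tensors, a parallel list of booleans standing in for an isinstance(t, tf.SparseTensor) test, and spec shapes given as tuples with None for unknown sizes. check_split_args and _is_compatible are hypothetical; the real function may perform these checks differently.

```python
import numbers

import numpy as np


def _is_compatible(shape, spec_shape):
    """True if `shape` matches `spec_shape`, where None in the spec matches any size."""
    return len(shape) == len(spec_shape) and all(
        s is None or s == d for d, s in zip(shape, spec_shape))


def check_split_args(leaves, spec_shapes, num_or_size_splits, sparse_flags):
    """Hypothetical validation mirroring the two documented ValueError cases."""
    if len(leaves) != len(spec_shapes):
        raise ValueError("tensors and specs have different structures")
    for leaf, spec_shape, is_sparse in zip(leaves, spec_shapes, sparse_flags):
        # Specs describe non-batched shapes, so the leading batch dimension is skipped.
        if not _is_compatible(leaf.shape[1:], spec_shape):
            raise ValueError(f"shape {leaf.shape} is incompatible with spec {spec_shape}")
        # Sparse leaves can only be split into a scalar number of pieces.
        if is_sparse and not isinstance(num_or_size_splits, numbers.Number):
            raise ValueError("SparseTensor leaves require a scalar num_or_size_splits")
```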