Module: tf_agents.utils.nest_utils

Utilities for handling nested tensors.

Functions

assert_matching_dtypes_and_inner_shapes(...): Checks that tensors and specs have matching dtypes and inner shapes, raising an error if they do not.

assert_same_structure(...): Same as tf.nest.assert_same_structure but with cleaner error messages.

batch_nested_array(...)

batch_nested_tensors(...): Add a batch dimension to nested tensors if needed, checking them against their specs.
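The "if needed" logic can be illustrated with a rank comparison. A leaf whose shape equals its spec's shape is unbatched and gets a leading size-1 dimension. A leaf with exactly one extra outer dimension is left alone. The sketch below is a hypothetical NumPy stand-in, not the tf_agents implementation: map_nest and batch_if_needed are illustrative names, nests are assumed to be plain dicts, lists, and tuples, and spec shapes are assumed to be plain tuples. It decides per leaf and omits the library's other validation.

```python
import numpy as np


def map_nest(fn, *nests):
    """Apply fn leaf-wise across nests of dicts/lists/tuples sharing one structure."""
    first = nests[0]
    if isinstance(first, dict):
        return {k: map_nest(fn, *(n[k] for n in nests)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(map_nest(fn, *items) for items in zip(*nests))
    return fn(*nests)


def batch_if_needed(nested_arrays, nested_spec_shapes):
    """Hypothetical sketch: add a leading batch dim to leaves that lack one."""

    def _batch(x, spec_shape):
        x = np.asarray(x)
        spec_shape = tuple(spec_shape)
        if x.shape == spec_shape:  # unbatched: rank matches the spec
            return x[np.newaxis, ...]
        if x.shape[1:] == spec_shape:  # already has one outer (batch) dim
            return x
        raise ValueError(f"shape {x.shape} does not match spec shape {spec_shape}")

    return map_nest(_batch, nested_arrays, nested_spec_shapes)


specs = {"pos": (3,), "reward": ()}
single = batch_if_needed({"pos": np.zeros(3), "reward": np.float32(1.0)}, specs)
print(single["pos"].shape, single["reward"].shape)  # (1, 3) (1,)
```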

fast_map_structure(...)

fast_map_structure_flatten(...)

flatten_and_check_shape_nested_specs(...): Flatten nested specs and check their shape for use in other functions.

flatten_multi_batched_nested_tensors(...): Reshape tensors to contain only one batch dimension.
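Collapsing several outer dimensions into one is a reshape that keeps the spec's inner shape: a [B1, B2, ...inner] leaf becomes [B1 * B2, ...inner], and the original outer dimensions are kept so the result can be reshaped back. The following is a hypothetical NumPy sketch of that idea (flatten_outer_dims and map_nest are illustrative names, and spec shapes are assumed to be plain tuples), not the library's code.

```python
import numpy as np


def map_nest(fn, *nests):
    """Apply fn leaf-wise across nests of dicts/lists/tuples sharing one structure."""
    first = nests[0]
    if isinstance(first, dict):
        return {k: map_nest(fn, *(n[k] for n in nests)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(map_nest(fn, *items) for items in zip(*nests))
    return fn(*nests)


def flatten_outer_dims(nested_arrays, nested_spec_shapes):
    """Hypothetical sketch: merge all outer (batch) dims of every leaf into one."""
    outer_shapes = []

    def _flatten(x, spec_shape):
        x = np.asarray(x)
        split = x.ndim - len(spec_shape)  # number of outer dims on this leaf
        outer_shapes.append(x.shape[:split])
        return x.reshape((-1,) + x.shape[split:])

    flat = map_nest(_flatten, nested_arrays, nested_spec_shapes)
    if len(set(outer_shapes)) != 1:
        raise ValueError(f"leaves disagree on outer dims: {outer_shapes}")
    return flat, outer_shapes[0]


batch = {"image": np.zeros((2, 5, 8, 8)), "step": np.zeros((2, 5))}
flat, outer = flatten_outer_dims(batch, {"image": (8, 8), "step": ()})
print(flat["image"].shape, flat["step"].shape, outer)  # (10, 8, 8) (10,) (2, 5)
# The saved outer dims undo the flattening.
restored = flat["image"].reshape(outer + (8, 8))
```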

flatten_with_joined_paths(...)

get_outer_array_shape(...): Returns the outer (batch) dimensions of a nested numpy array, determined by comparing it against its spec.

get_outer_rank(...): Compares tensors to specs to determine the number of batch dimensions.

get_outer_shape(...): Returns the runtime outer (batch) dimensions of a nested tensor, determined by comparing it against its spec.
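The outer-rank and outer-shape helpers above rest on one comparison: the number of outer dimensions of a leaf is its rank minus its spec's rank, and every leaf in the nest must agree. A hypothetical NumPy sketch of that comparison follows; outer_rank_and_shape and map_nest are illustrative names, and spec shapes are assumed to be plain tuples. In this sketch, a nest with one outer dimension would count as batched.

```python
import numpy as np


def map_nest(fn, *nests):
    """Apply fn leaf-wise across nests of dicts/lists/tuples sharing one structure."""
    first = nests[0]
    if isinstance(first, dict):
        return {k: map_nest(fn, *(n[k] for n in nests)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(map_nest(fn, *items) for items in zip(*nests))
    return fn(*nests)


def outer_rank_and_shape(nested_arrays, nested_spec_shapes):
    """Hypothetical sketch: infer the shared outer (batch) dims of a nest from its specs."""
    outer_shapes = []

    def _collect(x, spec_shape):
        x = np.asarray(x)
        spec_shape = tuple(spec_shape)
        k = x.ndim - len(spec_shape)  # rank of the leaf minus rank of its spec
        if k < 0 or x.shape[k:] != spec_shape:
            raise ValueError(f"shape {x.shape} does not end with spec shape {spec_shape}")
        outer_shapes.append(x.shape[:k])
        return x

    map_nest(_collect, nested_arrays, nested_spec_shapes)
    if len(set(outer_shapes)) != 1:
        raise ValueError(f"leaves disagree on outer dims: {outer_shapes}")
    return len(outer_shapes[0]), outer_shapes[0]


specs = {"pos": (3,), "reward": ()}
print(outer_rank_and_shape({"pos": np.zeros((4, 3)), "reward": np.zeros(4)}, specs))  # (1, (4,))
print(outer_rank_and_shape({"pos": np.zeros(3), "reward": np.float32(0)}, specs))  # (0, ())
```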

has_tensors(...)

is_batched_nested_tensors(...): Compares tensors to specs to determine if all tensors are batched or not.

prune_extra_keys(...): Recursively prunes keys from wide if they don't appear in narrow.
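The pruning rule can be sketched in plain Python for nests made of dicts, lists, and tuples: dict keys of wide that are absent from narrow are dropped, and the rule recurses into the children that remain. This is a hypothetical simplification (prune_extra_keys_sketch is an illustrative name), and the library function may handle more container types than shown here.

```python
def prune_extra_keys_sketch(narrow, wide):
    """Hypothetical sketch: drop dict keys of `wide` that `narrow` does not have."""
    if isinstance(narrow, dict) and isinstance(wide, dict):
        return {
            k: prune_extra_keys_sketch(narrow[k], v)
            for k, v in wide.items()
            if k in narrow
        }
    if isinstance(narrow, (list, tuple)) and isinstance(wide, (list, tuple)):
        # Recurse position by position; zip stops at the shorter sequence.
        return type(wide)(prune_extra_keys_sketch(n, w) for n, w in zip(narrow, wide))
    return wide  # leaves (and mismatched containers) are returned unchanged


narrow = {"obs": {"pos": None}, "reward": None}
wide = {"obs": {"pos": 1, "vel": 2}, "reward": 3, "info": {"steps": 7}}
print(prune_extra_keys_sketch(narrow, wide))  # {'obs': {'pos': 1}, 'reward': 3}
```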

remove_singleton_batch_spec_dim(...): Check that the outer dimension of a spec's shape is 1, and remove that dimension.

spec_shape(...)

split_nested_tensors(...): Split batched nested tensors along the batch (outer) dimension into a list.
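Splitting a batched nest amounts to splitting every leaf along axis 0 at the same cut points and regrouping the pieces into separate nests. A hypothetical NumPy sketch (split_nest and map_nest are illustrative names): an integer asks for that many equal pieces, and a list gives piece sizes, mirroring the num_or_size_splits convention of tf.split. Because np.split expects cut indices rather than sizes, the sizes are converted with a cumulative sum.

```python
import numpy as np


def map_nest(fn, *nests):
    """Apply fn leaf-wise across nests of dicts/lists/tuples sharing one structure."""
    first = nests[0]
    if isinstance(first, dict):
        return {k: map_nest(fn, *(n[k] for n in nests)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(map_nest(fn, *items) for items in zip(*nests))
    return fn(*nests)


def split_nest(nested_arrays, num_or_size_splits):
    """Hypothetical sketch: split every leaf along axis 0 and regroup into a list of nests."""
    if isinstance(num_or_size_splits, int):
        num_pieces = num_or_size_splits
        cuts = num_or_size_splits  # np.split: equal pieces (must divide the batch size)
    else:
        num_pieces = len(num_or_size_splits)
        cuts = np.cumsum(num_or_size_splits)[:-1]  # sizes [2, 3] -> cut index [2]
    return [
        map_nest(lambda x, i=i: np.split(np.asarray(x), cuts, axis=0)[i], nested_arrays)
        for i in range(num_pieces)
    ]


batch = {"obs": np.arange(10).reshape(5, 2), "reward": np.arange(5.0)}
first, second = split_nest(batch, [2, 3])
print(first["obs"].shape, second["reward"])  # (2, 2) [2. 3. 4.]
```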

stack_nested_arrays(...): Stack/batch a list of nested numpy arrays.
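Stacking a list of nests that share one structure means calling np.stack on the corresponding leaves, which yields a single nest whose leaves gain a new leading (batch) dimension. A hypothetical sketch, assuming nests built from plain dicts, lists, and tuples (stack_nests and map_nest are illustrative names, not tf_agents functions):

```python
import numpy as np


def map_nest(fn, *nests):
    """Apply fn leaf-wise across nests of dicts/lists/tuples sharing one structure."""
    first = nests[0]
    if isinstance(first, dict):
        return {k: map_nest(fn, *(n[k] for n in nests)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(map_nest(fn, *items) for items in zip(*nests))
    return fn(*nests)


def stack_nests(list_of_nests):
    """Hypothetical sketch: stack matching leaves of several nests along a new axis 0."""
    return map_nest(lambda *leaves: np.stack(leaves), *list_of_nests)


steps = [
    {"obs": np.array([0.0, 1.0]), "reward": np.float32(1.0)},
    {"obs": np.array([2.0, 3.0]), "reward": np.float32(0.0)},
]
batch = stack_nests(steps)
print(batch["obs"].shape, batch["reward"].shape)  # (2, 2) (2,)
```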

stack_nested_tensors(...): Stacks a list of nested tensors along the dimension specified.

tile_batch(...): Tile the batch dimension of a (possibly nested structure of) tensor(s).
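Assuming tile_batch follows the beam-search convention in which each batch entry is repeated multiplier times in place (t[0], t[0], t[1], t[1], ...), the NumPy equivalent for a single leaf is np.repeat along axis 0; np.tile would instead concatenate whole copies of the batch. A hypothetical sketch over a nest (tile_batch_sketch and map_nest are illustrative names):

```python
import numpy as np


def map_nest(fn, *nests):
    """Apply fn leaf-wise across nests of dicts/lists/tuples sharing one structure."""
    first = nests[0]
    if isinstance(first, dict):
        return {k: map_nest(fn, *(n[k] for n in nests)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(map_nest(fn, *items) for items in zip(*nests))
    return fn(*nests)


def tile_batch_sketch(nested_arrays, multiplier):
    """Hypothetical sketch: repeat each batch entry `multiplier` times along axis 0."""
    return map_nest(lambda x: np.repeat(np.asarray(x), multiplier, axis=0), nested_arrays)


x = {"state": np.array([[1, 2], [3, 4]])}
print(tile_batch_sketch(x, 2)["state"].tolist())  # [[1, 2], [1, 2], [3, 4], [3, 4]]
print(np.tile(x["state"], (2, 1)).tolist())  # [[1, 2], [3, 4], [1, 2], [3, 4]]
```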

unbatch_nested_array(...)

unbatch_nested_tensors(...): Remove the batch dimension if needed from nested tensors using their specs.

unbatch_nested_tensors_to_arrays(...)

unstack_nested_arrays(...): Unstack/unbatch a nest of numpy arrays.

unstack_nested_arrays_into_flat_items(...): Unstack/unbatch a nest of numpy arrays into flat items.
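Unstacking a nest of numpy arrays is the inverse of stacking: indexing every leaf at position i along axis 0 yields the i-th nest. A hypothetical NumPy sketch (unstack_nest and map_nest are illustrative names) that assumes all leaves share the same batch size:

```python
import numpy as np


def map_nest(fn, *nests):
    """Apply fn leaf-wise across nests of dicts/lists/tuples sharing one structure."""
    first = nests[0]
    if isinstance(first, dict):
        return {k: map_nest(fn, *(n[k] for n in nests)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(map_nest(fn, *items) for items in zip(*nests))
    return fn(*nests)


def unstack_nest(nested_arrays):
    """Hypothetical sketch: turn a nest of batched arrays into a list of per-item nests."""
    leaves = []
    map_nest(lambda x: leaves.append(np.asarray(x)), nested_arrays)  # collect leaves only
    batch_size = leaves[0].shape[0]
    return [map_nest(lambda x, i=i: np.asarray(x)[i], nested_arrays) for i in range(batch_size)]


batch = {"obs": np.arange(6).reshape(3, 2), "reward": np.array([1.0, 0.0, 1.0])}
items = unstack_nest(batch)
print(len(items), items[1]["obs"].tolist(), float(items[2]["reward"]))  # 3 [2, 3] 1.0
```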

unstack_nested_tensors(...): Unstack batched nested tensors into a list of nested tensors.

where(...): Generalization of tf.where for nested structures.
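A nested where applies an element-wise select to each pair of matching leaves. When the condition has lower rank than the leaves (for example, one boolean per batch entry), it needs trailing singleton axes before it can broadcast. The sketch below is a hypothetical NumPy version built on np.where (nested_where and map_nest are illustrative names), not the tf_agents implementation.

```python
import numpy as np


def map_nest(fn, *nests):
    """Apply fn leaf-wise across nests of dicts/lists/tuples sharing one structure."""
    first = nests[0]
    if isinstance(first, dict):
        return {k: map_nest(fn, *(n[k] for n in nests)) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(map_nest(fn, *items) for items in zip(*nests))
    return fn(*nests)


def nested_where(condition, true_nest, false_nest):
    """Hypothetical sketch: leaf-wise np.where with a possibly lower-rank condition."""
    condition = np.asarray(condition, dtype=bool)

    def _where(t, f):
        t, f = np.asarray(t), np.asarray(f)
        extra = max(t.ndim - condition.ndim, 0)
        # A [B] condition becomes [B, 1, ...] so it broadcasts over [B, ...] leaves.
        cond = condition.reshape(condition.shape + (1,) * extra)
        return np.where(cond, t, f)

    return map_nest(_where, true_nest, false_nest)


cond = [True, False]
new = {"obs": np.ones((2, 3)), "reward": np.array([1.0, 1.0])}
old = {"obs": np.zeros((2, 3)), "reward": np.array([0.0, 0.0])}
out = nested_where(cond, new, old)
print(out["obs"].tolist(), out["reward"].tolist())  # [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]] [1.0, 0.0]
```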