Compares tensors to specs to determine if all tensors are batched or not.
```python
tf_agents.utils.nest_utils.is_batched_nested_tensors(
    tensors,
    specs,
    num_outer_dims=1,
    allow_extra_fields=False,
    check_dtypes=True
)
```
For each tensor, the function checks its dimensions and dtype against the corresponding spec.

Returns `True` if all tensors are batched and `False` if all tensors are unbatched.

Raises a `ValueError` if the shapes are incompatible or if a mix of batched and unbatched tensors is provided. Raises a `TypeError` (when `check_dtypes` is `True`) if the tensors' dtypes do not match the specs.
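The batched/unbatched decision described above can be sketched in plain Python. This is a hypothetical illustration, not the TF-Agents implementation: each tensor is represented only by its shape tuple, each spec by the shape of an unbatched tensor, the nest is assumed to be already flattened into two parallel lists, and dimensions are compared exactly (real specs may contain unknown `None` dimensions). The name `is_batched_shapes` is introduced here for illustration only.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]


def is_batched_shapes(tensor_shapes: Sequence[Shape],
                      spec_shapes: Sequence[Shape],
                      num_outer_dims: int = 1) -> bool:
    """Hypothetical sketch: True if every shape is batched, False if every
    shape is unbatched, ValueError otherwise. Not the TF-Agents code."""
    if len(tensor_shapes) != len(spec_shapes):
        raise ValueError("tensors and specs must have the same number of leaves")
    batched_flags: List[bool] = []
    for shape, spec_shape in zip(tensor_shapes, spec_shapes):
        if len(shape) == len(spec_shape):
            # Same rank as the spec: the whole shape must match the spec.
            inner, batched = tuple(shape), False
        elif len(shape) == len(spec_shape) + num_outer_dims:
            # Exactly num_outer_dims leading batch dimensions: the remaining
            # (inner) dimensions must match the spec.
            inner, batched = tuple(shape[num_outer_dims:]), True
        else:
            raise ValueError(
                f"shape {shape} has an incompatible rank for spec {spec_shape}")
        if inner != tuple(spec_shape):
            raise ValueError(f"shape {shape} is incompatible with spec {spec_shape}")
        batched_flags.append(batched)
    if all(batched_flags):
        return True
    if not any(batched_flags):
        return False
    raise ValueError("a mix of batched and unbatched tensors was provided")


# Usage: two specs describing unbatched shapes (4,) and (2, 3).
specs = [(4,), (2, 3)]
print(is_batched_shapes([(4,), (2, 3)], specs))       # False (unbatched)
print(is_batched_shapes([(5, 4), (5, 2, 3)], specs))  # True (batch of 5)
```

The rank comparison comes first because, in this sketch, a shape with the same rank as its spec can only be unbatched, while one with exactly `num_outer_dims` extra leading dimensions can only be batched; any other rank is rejected immediately.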
Args | |
---|---|
`tensors` | Nested list/tuple/dict of Tensors. |
`specs` | Nested list/tuple/dict of Tensors or CompositeTensors describing the shape of unbatched tensors. |
`num_outer_dims` | The integer number of dimensions that are considered batch dimensions. Default 1. |
`allow_extra_fields` | If `True`, then `tensors` may have extra subfields which are not in `specs`. In this case, the extra subfields are not checked. See the example below the table. |
`check_dtypes` | If `True`, validates that `tensors` and `specs` have the same dtypes. |

Example for `allow_extra_fields`:

```python
import tensorflow as tf
from tf_agents.utils.nest_utils import is_batched_nested_tensors

tensors = {"a": tf.zeros((3, 4), dtype=tf.float32),
           "b": tf.zeros((5, 6), dtype=tf.float32)}
specs = {"a": tf.TensorSpec(shape=(4,), dtype=tf.float32)}
assert is_batched_nested_tensors(tensors, specs, allow_extra_fields=True)
```

The above example would raise a `ValueError` if `allow_extra_fields` were `False`.
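The interplay of `allow_extra_fields` and `check_dtypes` can be sketched for the simple case of a flat dict. This is a hypothetical illustration, not the TF-Agents implementation: tensors and specs are stand-in `(shape, dtype_name)` pairs rather than real `tf.Tensor`/`tf.TensorSpec` objects, nests are assumed to be one level deep, and the helper name `fields_to_check` is invented here. It returns only the fields that a subsequent batched/unbatched check would inspect.

```python
from typing import Dict, Tuple

# Stand-in for a tensor or spec: (shape, dtype name). An assumption for
# illustration; real nests hold tf.Tensor / tf.TensorSpec objects.
Leaf = Tuple[Tuple[int, ...], str]


def fields_to_check(tensors: Dict[str, Leaf],
                    specs: Dict[str, Leaf],
                    allow_extra_fields: bool = False,
                    check_dtypes: bool = True) -> Dict[str, Leaf]:
    """Hypothetical sketch of the structure and dtype checks on a flat dict."""
    missing = specs.keys() - tensors.keys()
    if missing:
        raise ValueError(f"tensors lack fields present in specs: {sorted(missing)}")
    extra = tensors.keys() - specs.keys()
    if extra and not allow_extra_fields:
        raise ValueError(f"tensors have fields not in specs: {sorted(extra)}")
    # Extra fields are dropped here, so they are never shape- or dtype-checked.
    selected = {key: tensors[key] for key in specs}
    if check_dtypes:
        for key, (_, dtype) in selected.items():
            spec_dtype = specs[key][1]
            if dtype != spec_dtype:
                raise TypeError(
                    f"field {key!r} has dtype {dtype}, spec expects {spec_dtype}")
    return selected


# Field "b" is not in the specs, so it is ignored when extra fields are allowed.
tensors = {"a": ((3, 4), "float32"), "b": ((5, 6), "float32")}
specs = {"a": ((4,), "float32")}
print(fields_to_check(tensors, specs, allow_extra_fields=True))
# {'a': ((3, 4), 'float32')}
```

Dropping the extra fields before the dtype loop is what leaves them unchecked; with `allow_extra_fields=False` the same input fails with a `ValueError` before any dtype comparison happens.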
Returns | |
---|---|
True if all Tensors are batched and False if all Tensors are unbatched. |
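Raises | |
---|---|
`ValueError` | If the shapes of `tensors` and `specs` are incompatible, or if a mix of batched and unbatched tensors is provided. |
`TypeError` | If dtypes between `tensors` and `specs` are not compatible. |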