Generalization of tf.where for nested structures.
tf_agents.utils.nest_utils.where(condition, true_outputs, false_outputs)
This generalization applies tf.where across nested structures and handles the special case in which the rank of condition is smaller than the rank of true_outputs and false_outputs.
condition: A boolean Tensor of shape [B, ...]. The shape of condition must be equal to or a prefix of the shape of true_outputs and false_outputs. If condition's rank is smaller than the rank of true_outputs and false_outputs, dimensions of size 1 are added to condition to make its rank match that of true_outputs and false_outputs in order to satisfy the requirements of tf.where.
true_outputs: Tensor or nested tuple of Tensors of any dtype, each with shape [B, ...], to be split based on condition.
false_outputs: Tensor or nested tuple of Tensors of any dtype, each with shape [B, ...], to be split based on condition.
Returns: Interleaved output from true_outputs and false_outputs based on condition.
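The rank-expansion and nesting behavior described above can be illustrated with a minimal NumPy sketch. This is not the TF-Agents implementation: `expand_condition` and `where_sketch` are hypothetical helper names, NumPy arrays stand in for Tensors, and only nested tuples are handled.

```python
import numpy as np


def expand_condition(condition, target):
    """Append size-1 axes to `condition` until its rank matches `target` (hypothetical helper)."""
    extra = target.ndim - condition.ndim
    return condition.reshape(condition.shape + (1,) * extra)


def where_sketch(condition, true_outputs, false_outputs):
    """Select between matching (possibly nested) tuples of arrays (hypothetical helper)."""
    if isinstance(true_outputs, tuple):
        # Recurse into the structure, pairing up corresponding entries.
        return tuple(
            where_sketch(condition, t, f)
            for t, f in zip(true_outputs, false_outputs)
        )
    t = np.asarray(true_outputs)
    f = np.asarray(false_outputs)
    # The expanded condition broadcasts across the trailing dimensions.
    return np.where(expand_condition(condition, t), t, f)


condition = np.array([True, False])  # shape [B] with B = 2
true_outputs = (np.ones((2, 3)), np.array([1, 2]))
false_outputs = (np.zeros((2, 3)), np.array([-1, -2]))

result = where_sketch(condition, true_outputs, false_outputs)
# Batch entry 0 comes from true_outputs, batch entry 1 from false_outputs.
```

Here a condition of shape [2] is reshaped to [2, 1] before being applied to the [2, 3] arrays, so each batch entry is taken whole from one side or the other.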