tf_agents.utils.nest_utils.where


Generalization of tf.where for nested structures.

This generalization applies tf.where to every leaf of a nested structure. It also handles the special case where the rank of condition is smaller than the rank of true_outputs and false_outputs.

Args:
condition: A boolean Tensor of shape [B, ...]. Its shape must be equal to, or a prefix of, the shape of true_outputs and false_outputs. If condition's rank is smaller than the rank of true_outputs and false_outputs, trailing dimensions of size 1 are appended to condition so that its rank matches theirs, which satisfies the requirements of tf.where.
true_outputs: A Tensor or nested tuple of Tensors of any dtype, each with shape [B, ...], from which elements are selected where condition is True.
false_outputs: A Tensor or nested tuple of Tensors of any dtype, each with shape [B, ...], from which elements are selected where condition is False.
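To see why condition needs trailing size-1 dimensions, here is a small worked example. It uses NumPy arrays as stand-ins for Tensors (an assumption for illustration; this is not TF-Agents code). NumPy's broadcasting aligns shapes starting from the last dimension, so a condition of shape [B] is matched against the wrong axis of a [B, D] output. Reshaping it to [B, 1] fixes that.

```python
import numpy as np

# Batch of B=3 rows; each output row carries a trailing feature dimension of size 2.
condition = np.array([True, False, True])  # shape (3,)
true_outputs = np.ones((3, 2))             # shape (3, 2)
false_outputs = np.zeros((3, 2))           # shape (3, 2)

# Broadcasting aligns shapes from the right, so (3,) is compared with the
# trailing dimension of size 2 and the call fails.
try:
    np.where(condition, true_outputs, false_outputs)
    mismatch_raised = False
except ValueError:
    mismatch_raised = True

# Appending size-1 dimensions turns (3,) into (3, 1). That broadcasts across
# the trailing dimension, so each condition entry selects a whole row.
rank_difference = true_outputs.ndim - condition.ndim
expanded = condition.reshape(condition.shape + (1,) * rank_difference)
result = np.where(expanded, true_outputs, false_outputs)

print(mismatch_raised)  # True
print(expanded.shape)   # (3, 1)
print(result)           # rows 0 and 2 are ones, row 1 is zeros
```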

Returns:
Interleaved output from true_outputs and false_outputs. Each element is taken from true_outputs where condition is True and from false_outputs where it is False.
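The behaviour described above can be sketched end to end in NumPy. This is a minimal illustration under stated assumptions, not the TF-Agents implementation. `nested_where` and `_map_structure` are hypothetical names. `_map_structure` is a small stand-in for TensorFlow's `tf.nest.map_structure` and only traverses plain dicts, tuples, and lists (namedtuples are not handled).

```python
import numpy as np


def _map_structure(fn, a, b):
    """Hypothetical helper: apply fn leaf-wise to two structures of matching shape.

    Only plain dicts, tuples, and lists are traversed; anything else is a leaf.
    """
    if isinstance(a, dict):
        return {k: _map_structure(fn, a[k], b[k]) for k in a}
    if isinstance(a, (tuple, list)):
        return type(a)(_map_structure(fn, x, y) for x, y in zip(a, b))
    return fn(a, b)


def nested_where(condition, true_outputs, false_outputs):
    """Hypothetical sketch: per-leaf where with condition rank expansion."""
    condition = np.asarray(condition, dtype=bool)

    def _where_leaf(t, f):
        t, f = np.asarray(t), np.asarray(f)
        rank_difference = t.ndim - condition.ndim
        if rank_difference < 0:
            raise ValueError("condition rank exceeds the rank of the outputs")
        # Append trailing size-1 dims so condition's shape [B, ...] is a prefix
        # of the leaf's shape and broadcasts over the remaining dimensions.
        expanded = condition.reshape(condition.shape + (1,) * rank_difference)
        return np.where(expanded, t, f)

    return _map_structure(_where_leaf, true_outputs, false_outputs)


# Usage: a batch of 2 with leaves of different ranks.
condition = np.array([True, False])
true_outputs = {"obs": np.array([[1, 2], [3, 4]]), "reward": np.array([10.0, 20.0])}
false_outputs = {"obs": np.array([[-1, -2], [-3, -4]]), "reward": np.array([0.0, 0.0])}
result = nested_where(condition, true_outputs, false_outputs)
# result["obs"]: row 0 comes from true_outputs, row 1 from false_outputs.
# result["reward"]: 10.0 from true_outputs, 0.0 from false_outputs.
```

The rank expansion is done per leaf because leaves in the same structure can have different ranks: here "reward" has rank 1 and "obs" has rank 2, yet one condition of shape [B] serves both.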