tf_agents.utils.nest_utils.where

Generalization of tf.where for nested structures.

tf_agents.utils.nest_utils.where(
    condition, true_outputs, false_outputs
)

This generalization applies tf.where across nested structures. It also handles the special case where the rank of condition is smaller than the rank of true_outputs and false_outputs.

Args:

  • condition: A boolean Tensor of shape [B, ...]. Its shape must either equal the shape of true_outputs and false_outputs or be a prefix of that shape. If condition has a lower rank than true_outputs and false_outputs, trailing dimensions of size 1 are added to condition until its rank matches theirs, which satisfies the requirements of tf.where.
  • true_outputs: Tensor or nested tuple of Tensors of any dtype, each with shape [B, ...]. Elements are taken from here where condition is True.
  • false_outputs: Tensor or nested tuple of Tensors of any dtype, each with shape [B, ...]. Elements are taken from here where condition is False.
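
The rank handling described for condition can be illustrated with a minimal NumPy sketch. This is not the tf_agents implementation: `pad_condition` is a hypothetical helper written for this example, and `np.where` stands in for tf.where on the assumption that both broadcast a size-1 axis the same way.

```python
import numpy as np


def pad_condition(condition, target_rank):
    """Hypothetical helper: append size-1 axes until `condition` has `target_rank` dims."""
    extra = target_rank - condition.ndim
    # (1,) * extra is empty when extra <= 0, so a full-rank condition is unchanged.
    return condition.reshape(condition.shape + (1,) * extra)


condition = np.array([True, False, True])         # shape [3], a prefix of [3, 2]
true_case = np.ones((3, 2), dtype=np.float32)     # shape [3, 2]
false_case = np.zeros((3, 2), dtype=np.float32)   # shape [3, 2]

padded = pad_condition(condition, true_case.ndim)  # shape [3, 1]
# The size-1 axis broadcasts, so each batch row is selected as a whole.
result = np.where(padded, true_case, false_case)
```

Here rows 0 and 2 of `result` come from `true_case` and row 1 comes from `false_case`.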

Returns:

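A Tensor or nested structure with the same structure as true_outputs and false_outputs, in which each element is taken from true_outputs where condition is True and from false_outputs otherwise.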
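
To show how the returned structure mirrors the inputs, here is a small NumPy emulation of the nested behavior. It is a sketch under assumptions, not the library code: `nested_where` is a hypothetical name, only plain tuples are treated as nests, and the condition is padded per leaf, which may differ from how tf_agents computes the padding.

```python
import numpy as np


def nested_where(condition, true_outputs, false_outputs):
    """Hypothetical emulation: apply an elementwise where to every leaf of a nested tuple."""
    if isinstance(true_outputs, tuple):
        # Recurse into matching positions of both nests.
        return tuple(
            nested_where(condition, t, f)
            for t, f in zip(true_outputs, false_outputs)
        )
    t = np.asarray(true_outputs)
    f = np.asarray(false_outputs)
    # Assumption: pad the condition with trailing size-1 axes for this leaf.
    cond = condition.reshape(condition.shape + (1,) * (t.ndim - condition.ndim))
    return np.where(cond, t, f)


condition = np.array([True, False])                       # shape [2]
true_outputs = (np.array([1, 2]),                         # shape [2]
                (np.array([[1.0, 1.0], [2.0, 2.0]]),))    # shape [2, 2]
false_outputs = (np.array([-1, -2]),
                 (np.array([[0.0, 0.0], [0.0, 0.0]]),))

result = nested_where(condition, true_outputs, false_outputs)
```

The result is again a tuple of the form `(array, (array,))`, with batch entry 0 drawn from `true_outputs` and batch entry 1 drawn from `false_outputs` in every leaf.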