tf.where

Return the elements where condition is True (multiplexing x and y).

This operator has two modes: either both x and y are provided, or neither is. condition is always expected to be a tf.Tensor of type bool.

Retrieving indices of True elements

If x and y are not provided (both are None):

tf.where will return the indices of the True elements of condition, in the form of a 2-D tensor with shape (n, d), where n is the number of True elements in condition and d is the number of dimensions of condition.

Indices are output in row-major order.

tf.where([True, False, False, True])
<tf.Tensor: shape=(2, 1), dtype=int64, numpy=
array([[0],
       [3]])>
tf.where([[True, False], [False, True]])
<tf.Tensor: shape=(2, 2), dtype=int64, numpy=
array([[0, 0],
       [1, 1]])>
tf.where([[[True, False], [False, True], [True, True]]])
<tf.Tensor: shape=(4, 3), dtype=int64, numpy=
array([[0, 0, 0],
       [0, 1, 1],
       [0, 2, 0],
       [0, 2, 1]])>
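
The row-major ordering above can be modelled in plain Python. This is only a sketch of the semantics on nested lists, not TensorFlow's implementation; the helper name true_indices is hypothetical and is not part of any TensorFlow API.

```python
def true_indices(condition):
    """Return the row-major index lists of the True entries of a nested list."""
    if isinstance(condition, bool):
        # Base case: a single bool contributes one empty index when True.
        return [[]] if condition else []
    indices = []
    for i, sub in enumerate(condition):
        # Prefix every index found in the sub-list with the current position,
        # so the last dimension varies fastest (row-major order).
        for rest in true_indices(sub):
            indices.append([i] + rest)
    return indices


one_d = true_indices([True, False, False, True])
two_d = true_indices([[True, False], [False, True]])
three_d = true_indices([[[True, False], [False, True], [True, True]]])

print(one_d)    # [[0], [3]]
print(two_d)    # [[0, 0], [1, 1]]
print(three_d)  # [[0, 0, 0], [0, 1, 1], [0, 2, 0], [0, 2, 1]]
```

Each result has n rows (one per True element) and d columns (one per dimension), matching the (n, d) shape described above.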

Multiplexing between x and y

If x and y are provided (both have non-None values):

tf.where will choose an output shape from the shapes of condition, x, and y that all three shapes are broadcastable to.

The condition tensor acts as a mask that chooses whether the corresponding element / row in the output should be taken from x (if the element in condition is True) or y (if it is False).

tf.where([True, False, False, True], [1,2,3,4], [100,200,300,400])
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 200, 300,   4],
dtype=int32)>
tf.where([True, False, False, True], [1,2,3,4], [100])
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4],
dtype=int32)>
tf.where([True, False, False, True], [1,2,3,4], 100)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4],
dtype=int32)>
tf.where([True, False, False, True], 1, 100)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   1],
dtype=int32)>
tf.where(True, [1,2,3,4], 100)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 2, 3, 4],
dtype=int32)>
tf.where(False, [1,2,3,4], 100)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([100, 100, 100, 100],
dtype=int32)>
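
The selection and broadcasting in the examples above can be modelled in plain Python for the 1-D case. This is a sketch under simplifying assumptions (scalars, length-1 lists, and equal-length lists only), not TensorFlow's implementation; broadcast_1d and select_1d are hypothetical helper names.

```python
def broadcast_1d(value, n):
    """Stretch a scalar or a length-1 list to length n; pass length-n lists through."""
    if not isinstance(value, list):
        return [value] * n
    if len(value) == 1:
        return value * n
    if len(value) != n:
        raise ValueError(f"cannot broadcast length {len(value)} to length {n}")
    return value


def select_1d(condition, x, y):
    """Pick from x where condition is True, otherwise from y."""
    # The output length is the longest of the three inputs (scalars count as 1).
    n = max(len(v) if isinstance(v, list) else 1 for v in (condition, x, y))
    cond, xs, ys = (broadcast_1d(v, n) for v in (condition, x, y))
    return [xv if cv else yv for cv, xv, yv in zip(cond, xs, ys)]


mask = [True, False, False, True]
print(select_1d(mask, [1, 2, 3, 4], [100, 200, 300, 400]))  # [1, 200, 300, 4]
print(select_1d(mask, [1, 2, 3, 4], [100]))                 # [1, 100, 100, 4]
print(select_1d(mask, 1, 100))                              # [1, 100, 100, 1]
print(select_1d(False, [1, 2, 3, 4], 100))                  # [100, 100, 100, 100]
```

Unlike tf.where, this model always returns a list and does not handle inputs with more than one dimension.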

Note that if the gradient of either branch of the tf.where generates a NaN, then the gradient of the entire tf.where will be NaN. This is because the gradient calculation for tf.where combines the two branches, for performance reasons.

A workaround is to use an inner tf.where to ensure the function has no asymptote, and to avoid computing a value whose gradient is NaN by replacing dangerous inputs with safe inputs.

Instead of this,

x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.where(x < 1., 0., 1. / x)
print(tape.gradient(y, x))
tf.Tensor(nan, shape=(), dtype=float32)

Although the 1. / x value is never used, its gradient is NaN when x = 0. Instead, we should guard that with another tf.where:

x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
  tape.watch(x)
  safe_x = tf.where(tf.equal(x, 0.), 1., x)
  y = tf.where(x < 1., 0., 1. / safe_x)
print(tape.gradient(y, x))
tf.Tensor(0.0, shape=(), dtype=float32)
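
The arithmetic behind the two results above can be sketched with plain floats. The sketch assumes that the branch not selected by the condition receives an upstream gradient of zero, which is then multiplied by that branch's local derivative; it models the chain rule only and is not TensorFlow's implementation. The helper name reciprocal_branch_gradient is hypothetical.

```python
import math


def reciprocal_branch_gradient(x_for_division, branch_selected):
    """Chain rule for the 1. / x branch of a select, using plain floats."""
    # Assumption: the selected branch gets the upstream gradient (1.0 here),
    # and the branch that was not selected gets 0.0.
    upstream = 1.0 if branch_selected else 0.0
    # The local derivative of 1 / x is -1 / x**2, which is infinite at x == 0.
    if x_for_division == 0.0:
        local = -math.inf
    else:
        local = -1.0 / x_for_division ** 2
    return upstream * local


x = 0.0
selected = not (x < 1.0)  # the 1. / x branch is not chosen at x = 0

unguarded = reciprocal_branch_gradient(x, selected)
safe_x = 1.0 if x == 0.0 else x  # same guard as the inner tf.where
guarded = reciprocal_branch_gradient(safe_x, selected)

print(math.isnan(unguarded))  # True: 0.0 * -inf is NaN
print(guarded == 0.0)         # True: 0.0 * -1.0 is a zero
```

The zero upstream gradient cannot cancel an infinite local derivative, which is why the dangerous input has to be replaced before the division rather than masked afterwards.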

condition A tf.Tensor of type bool
x If provided, a Tensor which is