View source on GitHub 
Returns the indices of nonzero elements, or multiplexes x
and y
.
tf.where(
condition, x=None, y=None, name=None
)
This operation has two modes:
 Return the indices of nonzero elements  When only
condition
is provided the result is anint64
tensor where each row is the index of a nonzero element ofcondition
. The result's shape is[tf.math.count_nonzero(condition), tf.rank(condition)]
.  Multiplex
x
andy
 When bothx
andy
are provided the result has the shape ofx
,y
, andcondition
broadcast together. The result is taken fromx
wherecondition
is nonzero ory
wherecondition
is zero.
1. Return the indices of nonzero elements
If x
and y
are not provided (both are None):
tf.where
will return the indices of condition
that are nonzero,
in the form of a 2D tensor with shape [n, d]
, where n
is the number of
nonzero elements in condition
(tf.count_nonzero(condition)
), and d
is
the number of axes of condition
(tf.rank(condition)
).
Indices are output in rowmajor order. The condition
can have a dtype
of
tf.bool
, or any numeric dtype
.
Here condition
is a 1axis bool
tensor with 2 True
values. The result
has a shape of [2,1]
tf.where([True, False, False, True]).numpy()
array([[0],
[3]])
Here condition
is a 2axis integer tensor, with 3 nonzero values. The
result has a shape of [3, 2]
.
tf.where([[1, 0, 0], [1, 0, 1]]).numpy()
array([[0, 0],
[1, 0],
[1, 2]])
Here condition
is a 3axis float tensor, with 5 nonzero values. The output
shape is [5, 3]
.
float_tensor = [[[0.1, 0], [0, 2.2], [3.5, 1e6]],
[[0, 0], [0, 0], [99, 0]]]
tf.where(float_tensor).numpy()
array([[0, 0, 0],
[0, 1, 1],
[0, 2, 0],
[0, 2, 1],
[1, 2, 0]])
These indices are the same that tf.sparse.SparseTensor
would use to
represent the condition tensor:
sparse = tf.sparse.from_dense(float_tensor)
sparse.indices.numpy()
array([[0, 0, 0],
[0, 1, 1],
[0, 2, 0],
[0, 2, 1],
[1, 2, 0]])
A complex number is considered nonzero if either the real or imaginary component is nonzero:
tf.where([complex(0.), complex(1.), 0+1j, 1+1j]).numpy()
array([[1],
[2],
[3]])
2. Multiplex x
and y
If x
and y
are also provided (both have nonNone values) the condition
tensor acts as a mask that chooses whether the corresponding
element / row in the output should be taken from x
(if the element in
condition
is True
) or y
(if it is False
).
The shape of the result is formed by
broadcasting
together the shapes of condition
, x
, and y
.
When all three inputs have the same size, each is handled elementwise.
tf.where([True, False, False, True],
[1, 2, 3, 4],
[100, 200, 300, 400]).numpy()
array([ 1, 200, 300, 4], dtype=int32)
There are two main rules for broadcasting:
 If a tensor has fewer axes than the others, length1 axes are added to the left of the shape.
 Axes with length1 are streched to match the coresponding axes of the other tensors.
A length1 vector is streched to match the other vectors:
tf.where([True, False, False, True], [1, 2, 3, 4], [100]).numpy()
array([ 1, 100, 100, 4], dtype=int32)
A scalar is expanded to match the other arguments:
tf.where([[True, False], [False, True]], [[1, 2], [3, 4]], 100).numpy()
array([[ 1, 100], [100, 4]], dtype=int32)
tf.where([[True, False], [False, True]], 1, 100).numpy()
array([[ 1, 100], [100, 1]], dtype=int32)
A scalar condition
returns the complete x
or y
tensor, with
broadcasting applied.
tf.where(True, [1, 2, 3, 4], 100).numpy()
array([1, 2, 3, 4], dtype=int32)
tf.where(False, [1, 2, 3, 4], 100).numpy()
array([100, 100, 100, 100], dtype=int32)
For a nontrivial example of broadcasting, here condition
has a shape of
[3]
, x
has a shape of [3,3]
, and y
has a shape of [3,1]
.
Broadcasting first expands the shape of condition
to [1,3]
. The final
broadcast shape is [3,3]
. condition
will select columns from x
and y
.
Since y
only has one column, all columns from y
will be identical.
tf.where([True, False, True],
x=[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]],
y=[[100],
[200],
[300]]
).numpy()
array([[ 1, 100, 3],
[ 4, 200, 6],
[ 7, 300, 9]], dtype=int32)
Note that if the gradient of either branch of the tf.where
generates
a NaN
, then the gradient of the entire tf.where
will be NaN
. This is
because the gradient calculation for tf.where
combines the two branches, for
performance reasons.
A workaround is to use an inner tf.where
to ensure the function has
no asymptote, and to avoid computing a value whose gradient is NaN
by
replacing dangerous inputs with safe inputs.
Instead of this,
x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
tape.watch(x)
y = tf.where(x < 1., 0., 1. / x)
print(tape.gradient(y, x))
tf.Tensor(nan, shape=(), dtype=float32)
Although, the 1. / x
values are never used, its gradient is a NaN
when
x = 0
. Instead, we should guard that with another tf.where
x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
tape.watch(x)
safe_x = tf.where(tf.equal(x, 0.), 1., x)
y = tf.where(x < 1., 0., 1. / safe_x)
print(tape.gradient(y, x))
tf.Tensor(0.0, shape=(), dtype=float32)
See also:
tf.sparse
 The indices returned by the first form oftf.where
can be useful intf.sparse.SparseTensor
objects.tf.gather_nd
,tf.scatter_nd
, and related ops  Given the list of indices returned fromtf.where
thescatter
andgather
family of ops can be used fetch values or insert values at those indices.tf.strings.length
tf.string
is not an allowed dtype for thecondition
. Use the string length instead.
Args  

condition

A tf.Tensor of dtype bool, or any numeric dtype. condition
must have dtype bool when x and y are provided.

x

If provided, a Tensor which is of the same type as y , and has a shape
broadcastable with condition and y .

y

If provided, a Tensor which is of the same type as x , and has a shape
broadcastable with condition and x .

name

A name of the operation (optional). 
Returns  

If x and y are provided:
A Tensor with the same type as x and y , and shape that
is broadcast from condition , x , and y .
Otherwise, a Tensor with shape [tf.math.count_nonzero(condition),
tf.rank(condition)] .

Raises  

ValueError

When exactly one of x or y is nonNone, or the shapes
are not all broadcastable.
