Return true_fn() if the predicate pred is true else false_fn(). (deprecated arguments: fn1 and fn2 are deprecated in favor of true_fn and false_fn)
tf.cond( pred, true_fn=None, false_fn=None, strict=False, name=None, fn1=None, fn2=None )
true_fn and false_fn both return lists of output tensors. true_fn and false_fn must have the same non-zero number and type of outputs.
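The sketch below shows the matching-outputs rule. It assumes TensorFlow 2.x, where tf.cond inside a tf.function traces both branches, so their outputs must agree even though only one branch runs. The branch functions and the wrapper pick are illustrative names, not part of the API. Both branches return an (int32, float32) pair in the same order. If one branch returned a different number or type of outputs, tracing would be expected to fail.

```python
import tensorflow as tf

def true_fn():
    # Two outputs: an int32 scalar and a float32 scalar.
    return tf.constant(1), tf.constant(1.0)

def false_fn():
    # Same number of outputs, same dtypes, same order as true_fn.
    return tf.constant(0), tf.constant(0.0)

@tf.function
def pick(flag):
    # Inside tf.function both branches are traced, so their outputs must match.
    return tf.cond(flag, true_fn, false_fn)

count, score = pick(tf.constant(False))
```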
WARNING: Any Tensors or Operations created outside of true_fn and false_fn will be executed regardless of which branch is selected at runtime.
Although this behavior is consistent with the dataflow model of TensorFlow, it has frequently surprised users who expected a lazier semantics. Consider the following simple program:
z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
If x < y, the tf.add operation will be executed and the tf.square operation will not be executed. Since z is needed for at least one branch of the cond, the tf.multiply operation is always executed, unconditionally.
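You can observe this eager-versus-lazy distinction in a TensorFlow 2.x tf.function by counting how often stateful operations run. This is a sketch under that assumption. The variables outside_runs and branch_runs and the function f are hypothetical names. The increment created outside the branches runs on every call. The increment created inside false_fn runs only when that branch is selected. Moving work such as tf.multiply(a, b) into the branch that needs it is therefore one way to make that work conditional.

```python
import tensorflow as tf

# Hypothetical counters, used only to observe which operations execute.
outside_runs = tf.Variable(0)
branch_runs = tf.Variable(0)

@tf.function
def f(x, y):
    # Created outside both branches: runs on every call, like tf.multiply above.
    outside_runs.assign_add(1)

    def false_fn():
        # Created inside false_fn: runs only when x < y is false.
        branch_runs.assign_add(1)
        return tf.square(y)

    return tf.cond(x < y, lambda: tf.add(x, 1), false_fn)

r1 = f(tf.constant(2), tf.constant(5))  # x < y, so the true branch is taken
r2 = f(tf.constant(6), tf.constant(5))  # x >= y, so false_fn is taken
```

After these two calls, outside_runs holds 2 and branch_runs holds 1.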
Note that cond calls true_fn and false_fn exactly once (inside the call to cond, and not at all during Session.run()). cond stitches together the graph fragments created during the true_fn and false_fn calls with some additional graph nodes to ensure that the right branch gets executed depending on the value of pred.
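You can check the once-only calling behavior in TensorFlow 2.x, where tracing a tf.function plays the role that graph construction plays in TF1. This sketch assumes that setup, and trace_log, choose, and the branch bodies are hypothetical. Both Python branch functions are called once while choose is first traced. Later calls with the same input signature reuse the traced graph without calling either function again, even when the selected branch changes.

```python
import tensorflow as tf

trace_log = []  # hypothetical: records each Python-level call to a branch function

def true_fn():
    trace_log.append("true_fn")
    return tf.constant(1)

def false_fn():
    trace_log.append("false_fn")
    return tf.constant(0)

@tf.function
def choose(flag):
    return tf.cond(flag, true_fn, false_fn)

# Three executions with the same signature (a scalar bool tensor): one trace.
results = [choose(tf.constant(b)).numpy() for b in (True, False, True)]
```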
tf.cond supports nested structures as implemented in tensorflow.python.util.nest. Both true_fn and false_fn must return the same (possibly nested) value structure of lists, tuples, and/or named tuples. Singleton lists and tuples form the only exceptions to this: when returned by true_fn and/or false_fn, they are implicitly unpacked to single values. This behavior is disabled by passing strict=True.
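Here is a small eager-mode sketch of singleton unpacking. It assumes TensorFlow 2.x, with this TF1 API reached through tf.compat.v1.cond, and the branch function names are illustrative. With the default strict=False, a one-element list returned by the selected branch comes back as a bare tensor. With strict=True, the list is preserved. (The TF2 tf.cond always behaves as if strict=True.)

```python
import tensorflow as tf

pred = tf.constant(True)

def true_branch():
    # Returns a singleton list.
    return [tf.constant(1)]

def false_branch():
    return [tf.constant(2)]

# Default strict=False: the singleton list is unpacked to a single tensor.
loose = tf.compat.v1.cond(pred, true_branch, false_branch)

# strict=True: the returned one-element list is kept as-is.
kept = tf.compat.v1.cond(pred, true_branch, false_branch, strict=True)
```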
Args:
pred: A scalar determining whether to return the result of true_fn or false_fn.
true_fn: The callable to be performed if pred is true.
false_fn: The callable to be performed if pred is false.
strict: A boolean that enables/disables 'strict' mode; see above.
name: Optional name prefix for the returned tensors.
Returns:
Tensors returned by the call to either true_fn or false_fn. If the callables return a singleton list, the element is extracted from the list.
Raises:
TypeError: if true_fn or false_fn is not callable.
ValueError: if true_fn and false_fn do not return the same number of tensors, or return tensors of different types.
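The sketch below shows the TypeError case. It assumes TensorFlow 2.x and tf.compat.v1.cond, and the helper name is hypothetical. Passing a tensor where a callable is expected fails before either branch runs. The exact wording of the error message may differ between versions, so the sketch only captures the message.

```python
import tensorflow as tf

def non_callable_error():
    """Return the TypeError message raised by cond, or None if none was raised."""
    try:
        # true_fn is a tensor, not a callable.
        tf.compat.v1.cond(tf.constant(True), tf.constant(1), lambda: tf.constant(0))
    except TypeError as err:
        return str(err)
    return None

message = non_callable_error()
```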
Example:
import tensorflow as tf

x = tf.constant(2)
y = tf.constant(5)

def f1():
    return tf.multiply(x, 17)

def f2():
    return tf.add(y, 23)

r = tf.cond(tf.less(x, y), f1, f2)
# r is set to f1().
# Operations in f2 (e.g., tf.add) are not executed.