tf.case

Create a case operation.

Aliases:

  • tf.compat.v2.case

tf.case(
    pred_fn_pairs,
    default=None,
    exclusive=False,
    strict=False,
    name='case'
)

See also tf.switch_case.

The pred_fn_pairs parameter is a list of N pairs. Each pair contains a boolean scalar tensor and a Python callable that creates the tensors to be returned if the boolean evaluates to True. default is a callable that generates a list of tensors. All the callables in pred_fn_pairs, as well as default (if provided), should return the same number and types of tensors.
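
The following is a minimal sketch, assuming TensorFlow 2.x with eager execution. It builds N = 2 pairs plus a default, where every callable returns the same structure: a tuple of an int32 tensor and a string tensor. The input value and the callables small, medium, and large are illustrative only.

```python
import tensorflow as tf

x = tf.constant(3)  # illustrative input


# Every callable returns the same structure: (int32 tensor, string tensor).
def small():
    return tf.constant(0), tf.constant("small")


def medium():
    return tf.constant(1), tf.constant("medium")


def large():
    return tf.constant(2), tf.constant("large")


# N = 2 (predicate, callable) pairs, checked in order, plus a default callable.
pairs = [
    (tf.less(x, 2), small),
    (tf.less(x, 5), medium),
]
label_id, label = tf.case(pairs, default=large)
print(label_id.numpy(), label.numpy())  # 1 b'medium'
```

Because x < 2 is False and x < 5 is True, the second pair's callable supplies both returned tensors.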

If exclusive==True, all predicates are evaluated, and an exception is raised if more than one of them evaluates to True. If exclusive==False, execution stops at the first predicate that evaluates to True, and the tensors generated by the corresponding function are returned immediately. If none of the predicates evaluates to True, this operation returns the tensors generated by default.
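
To contrast the two modes, here is a sketch assuming TensorFlow 2.x in eager mode, with a sample input for which both predicates are True. As we understand it, in eager mode the exclusivity check is an assertion that raises tf.errors.InvalidArgumentError. The sample values and the fallback callable are illustrative.

```python
import tensorflow as tf

x = tf.constant(1)  # illustrative input: both predicates below are True


def fallback():
    return tf.constant(-1)


pairs = [
    (tf.less(x, 5), lambda: tf.constant(17)),
    (tf.less(x, 10), lambda: tf.constant(23)),
]

# exclusive=False: evaluation stops at the first True predicate.
first = tf.case(pairs, default=fallback)
print(first.numpy())  # 17

# exclusive=True: two True predicates trip the exclusivity assertion.
try:
    tf.case(pairs, default=fallback, exclusive=True)
    exclusive_failed = False
except tf.errors.InvalidArgumentError:
    exclusive_failed = True
print(exclusive_failed)  # True
```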

tf.case supports nested structures as implemented in tf.nest (formerly tf.contrib.framework.nest). All of the callables must return the same (possibly nested) value structure of lists, tuples, and/or named tuples. Singleton lists and tuples are the only exception: when returned by a callable, they are implicitly unpacked to single values. Pass strict=True to disable this unpacking.
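
A short sketch of both behaviors, assuming TensorFlow 2.x eager execution: a nested structure (a hypothetical Stats named tuple) passes through tf.case unchanged, while a singleton list is unpacked unless strict=True.

```python
import collections

import tensorflow as tf

# Hypothetical named tuple used only for this illustration.
Stats = collections.namedtuple("Stats", ["total", "count"])
pred = tf.constant(True)

# Both callables return the same nested structure, so it passes through as-is.
stats = tf.case(
    [(pred, lambda: Stats(tf.constant(10.0), tf.constant(2)))],
    default=lambda: Stats(tf.constant(0.0), tf.constant(0)),
)

# A singleton list is unpacked to a bare tensor by default...
loose_result = tf.case(
    [(pred, lambda: [tf.constant(1)])], default=lambda: [tf.constant(0)]
)
# ...and kept as a one-element list when strict=True.
strict_result = tf.case(
    [(pred, lambda: [tf.constant(1)])], default=lambda: [tf.constant(0)], strict=True
)
```

Using strict=True is useful when downstream code always indexes into the result and should not have to special-case a single output.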

Example 1:

Pseudocode:

if (x < y) return 17;
else return 23;

Expressions:

import tensorflow as tf
x, y = tf.constant(1), tf.constant(2)  # sample inputs
f1 = lambda: tf.constant(17)
f2 = lambda: tf.constant(23)
r = tf.case([(tf.less(x, y), f1)], default=f2)  # x < y, so r is 17

Example 2:

Pseudocode:

if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
if (x < y) return 17;
else if (x > z) return 23;
else return -1;

Expressions:

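import tensorflow as tf

x, y, z = tf.constant(1), tf.constant(2), tf.constant(5)  # sample inputs
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
# Only tf.less(x, y) is True here, so the exclusivity check passes and r is 17.
r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
            default=f3, exclusive=True)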

Args:

  • pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable which returns a list of tensors.
  • default: Optional callable that returns a list of tensors.
  • exclusive: True iff at most one predicate is allowed to evaluate to True.
  • strict: A boolean that enables/disables 'strict' mode; see above.
  • name: A name for this operation (optional).

Returns:

The tensors returned by the first pair whose predicate evaluated to True, or those returned by default if none does.

Raises:

  • TypeError: If pred_fn_pairs is not a list/tuple.
  • TypeError: If pred_fn_pairs is a list but does not contain 2-tuples.
  • TypeError: If the callable in any pair of pred_fn_pairs is not actually callable, or if default is not callable.
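
The first two TypeError cases can be checked with a small sketch, assuming TensorFlow 2.x eager execution. The helper rejected and the callable one are hypothetical, introduced only for this illustration.

```python
import tensorflow as tf

pred = tf.constant(True)


def one():
    return tf.constant(1)


def rejected(pairs):
    """Hypothetical helper: True if tf.case raises TypeError for this pred_fn_pairs value."""
    try:
        tf.case(pairs, default=lambda: tf.constant(0))
    except TypeError:
        return True
    return False


print(rejected(pred))           # True: a bare tensor is not a list/tuple
print(rejected((pred, one)))    # True: the entries must themselves be 2-tuples
print(rejected([(pred, one)]))  # False: a list of (predicate, callable) pairs
```

The second case reflects a common mistake: passing a single pair without wrapping it in a list.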

V2 Compatibility

pred_fn_pairs could be a dictionary in v1. However, tf.Tensor and tf.Variable are no longer hashable in v2, so they cannot be used as dictionary keys. Please use a list or a tuple instead.
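
A minimal migration sketch, assuming TensorFlow 2.x with eager execution enabled (the default): building the v1-style dictionary fails because the predicate tensor cannot be hashed, and the same mapping written as a list of pairs works. The inputs and callables mirror Example 1.

```python
import tensorflow as tf

x, y = tf.constant(1), tf.constant(2)


def f1():
    return tf.constant(17)


def f2():
    return tf.constant(23)


# v1 style: a dict keyed by predicate tensors. In v2, hashing the tensor fails.
try:
    pairs = {tf.less(x, y): f1}
    dict_rejected = False
except TypeError:
    dict_rejected = True

# v2 style: the same mapping written as a list of (predicate, callable) pairs.
r = tf.case([(tf.less(x, y), f1)], default=f2)
print(dict_rejected, r.numpy())  # True 17
```

A list also makes the evaluation order explicit, which matters when exclusive=False.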