Control Flow Operations

TensorFlow provides several operations and classes that you can use to control the execution of operations and add conditional dependencies to your graph.

tf.identity(input, name=None)

Return a tensor with the same shape and contents as the input tensor or value.

  • input: A Tensor.
  • name: A name for the operation (optional).

A Tensor. Has the same type as input.

tf.tuple(tensors, name=None, control_inputs=None)

Group tensors together.

This creates a tuple of tensors with the same values as the tensors argument, except that the value of each tensor is only returned after the values of all tensors have been computed.

control_inputs contains additional ops that have to finish before this op finishes, but whose outputs are not returned.

This can be used as a "join" mechanism for parallel computations: all the argument tensors can be computed in parallel, but the values of any tensor returned by tuple are only available after all the parallel computations are done.
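The join behaviour can be pictured with a pure-Python analogy. This is a sketch under stated assumptions, not TensorFlow code. The helper join is hypothetical, and Python thread-pool futures stand in for graph tensors. Each callable may run concurrently, but no result is handed back until every one of them has finished.

```python
# Pure-Python analogy for tf.tuple's "join" gating (not TensorFlow code).
from concurrent.futures import ThreadPoolExecutor, wait


def join(*thunks):
    """Hypothetical helper: run zero-argument callables concurrently and
    return their results only after all of them have completed."""
    with ThreadPoolExecutor() as pool:
        futures = [pool.submit(t) for t in thunks]
        wait(futures)  # the "join": block until every computation is done
        return [f.result() for f in futures]


results = join(lambda: sum(range(1000)), lambda: 2 ** 10, lambda: "done")
print(results)  # [499500, 1024, 'done']
```

tf.tuple itself does not use threads. It adds control dependencies inside the graph, and the runtime decides what runs in parallel.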

See also group and with_dependencies.

  • tensors: A list of Tensors or IndexedSlices, some entries can be None.
  • name: (optional) A name to use as a name_scope for the operation.
  • control_inputs: List of additional ops to finish before returning.

The same as tensors: one entry per input, each available only after all of them have been computed.

  • ValueError: If tensors does not contain any Tensor or IndexedSlices.
  • TypeError: If control_inputs is not a list of Operation or Tensor objects.

tf.group(*inputs, **kwargs)

Create an op that groups multiple operations.

When this op finishes, all ops in inputs have finished. This op has no output.

See also tuple and with_dependencies.

  • *inputs: Zero or more tensors to group.
  • **kwargs: Optional parameters to pass when constructing the NodeDef.
  • name: A name for this operation (optional).

An Operation that executes all its inputs.

  • ValueError: If an unknown keyword argument is provided.

tf.no_op(name=None)

Does nothing. Only useful as a placeholder for control edges.

  • name: A name for the operation (optional).

The created Operation.

tf.count_up_to(ref, limit, name=None)

Increments 'ref' until it reaches 'limit'.

The op increments ref in place and returns the value ref held before the increment, so successive runs yield successive counter values.

  • ref: A mutable Tensor. Must be one of the following types: int32, int64. Should be from a scalar Variable node.
  • limit: An int. If incrementing ref would bring it above limit, instead generates an 'OutOfRange' error.
  • name: A name for the operation (optional).

A Tensor. Has the same type as ref. A copy of the input before increment. If nothing else modifies the input, the values produced will all be distinct.
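The documented semantics can be checked against a small pure-Python model. This is a sketch, not the TensorFlow kernel. Counter and the local OutOfRangeError class are hypothetical stand-ins for a scalar int Variable and TensorFlow's OutOfRange error. Each call returns the pre-increment value, and the call that would push the counter above limit fails instead.

```python
# Pure-Python model of tf.count_up_to's documented behaviour (not TensorFlow).


class OutOfRangeError(Exception):
    """Hypothetical stand-in for TensorFlow's OutOfRange error."""


class Counter:
    def __init__(self, start=0):
        self.value = start  # plays the role of the scalar Variable `ref`

    def count_up_to(self, limit):
        if self.value >= limit:
            # Incrementing now would bring the counter above `limit`.
            raise OutOfRangeError(f"Reached limit of {limit}")
        before = self.value
        self.value += 1
        return before  # a copy of the value before the increment


c = Counter()
produced = []
try:
    while True:
        produced.append(c.count_up_to(3))
except OutOfRangeError:
    pass
print(produced, c.value)  # [0, 1, 2] 3
```

Because every call returns a distinct pre-increment value, the outputs are all distinct as long as nothing else modifies the counter.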

tf.cond(pred, fn1, fn2, name=None)

Return either fn1() or fn2() based on the boolean predicate pred.

fn1 and fn2 both return lists of output tensors. fn1 and fn2 must have the same non-zero number and type of outputs.

Note that the conditional execution applies only to the operations defined in fn1 and fn2. Consider the following simple program:

  z = tf.mul(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))

If x < y, the tf.add operation is executed and the tf.square operation is not. Because z is needed by at least one branch of the cond, the tf.mul operation is always executed, unconditionally. Although this behavior is consistent with the dataflow model of TensorFlow, it has occasionally surprised users who expected lazier semantics, in which tf.mul would run only when the branch using z is taken.
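The evaluation order above can be mimicked with a pure-Python model. This is a sketch, not TensorFlow code. mul, add, square, and cond are hypothetical stand-ins that log when they run. Python's eager evaluation of z plays the role of the dataflow dependency, so the log shows mul running whichever branch is chosen, and only the selected branch's own op running.

```python
# Pure-Python model of the tf.cond evaluation order described above.
trace = []


def mul(a, b):
    trace.append("mul")
    return a * b


def add(a, b):
    trace.append("add")
    return a + b


def square(a):
    trace.append("square")
    return a * a


def cond(pred, fn1, fn2):
    """Call exactly one of the two branch callables, like tf.cond."""
    return fn1() if pred else fn2()


a, b, x, y = 3, 4, 1, 2
z = mul(a, b)  # defined outside both branches, so it always runs
result = cond(x < y, lambda: add(x, z), lambda: square(y))
print(result, trace)  # 13 ['mul', 'add']
```

To get the lazier behavior, create tf.mul inside the branch callable that uses it, so it belongs to that branch only.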

  • pred: A scalar determining whether to return the result of fn1 or fn2.
  • fn1: The callable to be performed if pred is true.
  • fn2: The callable to be performed if pred is false.
  • name: Optional name prefix for the returned tensors.

Tensors returned by the call to either fn1 or fn2. If the callables return a singleton list, the element is extracted from the list.

  • TypeError: if fn1 or fn2 is not callable.
  • ValueError: if fn1 and fn2 do not return the same number of tensors, or return tensors of different types.

Example:

  x = tf.constant(2)
  y = tf.constant(5)
  def f1(): return tf.mul(x, 17)
  def f2(): return tf.add(y, 23)
  r = tf.cond(tf.less(x, y), f1, f2)
  # r is set to f1().
  # Operations in f2 (e.g., tf.add) are not executed.

tf.case(pred_fn_pairs, default, exclusive=False, name='case')

Create a case operation.

The pred_fn_pairs parameter is a dict or list of pairs of size N. Each pair contains a boolean scalar tensor and a python callable that creates the tensors to be returned if the boolean evaluates to True. default is a callable generating a list of tensors. All the callables in pred_fn_pairs as well as default should return the same number and types of tensors.

If exclusive==True, all predicates are evaluated, and a logging operation with an error is returned if more than one of the predicates evaluates to True. If exclusive==False, execution stops at the first predicate that evaluates to True, and the tensors generated by the corresponding function are returned immediately. If none of the predicates evaluates to True, this operation returns the tensors generated by default.
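These selection rules can be sketched as a pure-Python model. This is not TensorFlow's implementation. Predicates are plain bools, and a ValueError stands in for the runtime error TensorFlow reports. The model accepts only a list of pairs, because a Python dict keyed by bools would merge equal keys, whereas TensorFlow's dict keys are distinct tensor objects.

```python
# Pure-Python model of tf.case's selection rules (not TensorFlow code).
def case(pred_fn_pairs, default, exclusive=False):
    """Select a callable from a list of (bool, callable) pairs."""
    if exclusive and sum(1 for pred, _ in pred_fn_pairs if pred) > 1:
        # exclusive=True: every predicate is checked; more than one True fails.
        raise ValueError("Only one predicate may evaluate true")
    for pred, fn in pred_fn_pairs:
        if pred:  # the first True predicate wins
            return fn()
    return default()  # no predicate was True


x, y, z = 0, 1, 2
print(case([(x < y, lambda: 17), (x > z, lambda: 23)],
           default=lambda: -1, exclusive=True))  # 17
print(case([(x > y, lambda: 17)], default=lambda: 23))  # 23
```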

Example 1:

  Pseudocode:
    if (x < y) return 17;
    else return 23;

  Expressions:
    f1 = lambda: tf.constant(17)
    f2 = lambda: tf.constant(23)
    r = tf.case([(tf.less(x, y), f1)], default=f2)

Example 2:

  Pseudocode:
    if (x < y && x > z) raise OpError("Only one predicate may evaluate true");
    if (x < y) return 17;
    else if (x > z) return 23;
    else return -1;

  Expressions:
    x = tf.constant(0)
    y = tf.constant(1)
    z = tf.constant(2)
    def f1(): return tf.constant(17)
    def f2(): return tf.constant(23)
    def f3(): return tf.constant(-1)
    r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
                default=f3, exclusive=True)

  • pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a callable which returns a list of tensors.
  • default: A callable that returns a list of tensors.
  • exclusive: True iff at most one predicate is allowed to evaluate to True.
  • name: A name for this operation (optional).

The tensors returned by the first pair whose predicate evaluated to True, or those returned by default if none does.

  • TypeError: If pred_fn_pairs is not a list/dictionary.
  • TypeError: If pred_fn_pairs is a list but does not contain 2-tuples.
  • TypeError: If any function in pred_fn_pairs is not callable, or if default is not callable.

tf.while_loop(cond, body, loop_vars, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)

Repeat body while the condition cond is true.

cond is a callable returning a boolean scalar tensor. body is a callable returning a (possibly nested) tuple or list of tensors of the same arity (length and structure) and types as loop_vars. loop_vars is a (possibly nested) tuple or list of tensors that is passed to both cond and body. cond and body both take as many arguments as there are loop_vars.

In addition to regular Tensors or IndexedSlices, the body may accept and return TensorArray objects. The flows of the TensorArray objects are forwarded appropriately between loop iterations and during gradient calculations.

While cond evaluates to true, body is executed.

while_loop implements non-strict semantics, enabling multiple iterations to run in parallel. The maximum number of parallel iterations can be controlled by parallel_iterations, which gives users some control over memory consumption and execution order. For correct programs, while_loop should return the same result for any parallel_iterations > 0.

For training, TensorFlow remembers the tensors that are produced in the forward inference but needed in back propagation. These tensors can be a main source of memory consumption and often cause OOM problems when training on GPUs. When the flag swap_memory is true, these tensors are swapped out from GPU to CPU memory. This makes it possible, for example, to train RNN models with very long sequences and large batches.

  • cond: A callable that represents the termination condition of the loop.
  • body: A callable that represents the loop body.
  • loop_vars: A (possibly nested) tuple or list of numpy array, Tensor, and TensorArray objects.
  • parallel_iterations: The number of iterations allowed to run in parallel.
  • back_prop: Whether backprop is enabled for this while loop.
  • swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
  • name: Optional name prefix for the returned tensors.

The output tensors for the loop variables after the loop. When loop_vars has length 1, this is a single Tensor, TensorArray, or IndexedSlices; when loop_vars has length greater than 1, a list is returned.

  • TypeError: if cond or body is not callable.
  • ValueError: if loop_vars is empty.

Example:

  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: tf.add(i, 1)
  r = tf.while_loop(c, b, [i])

Example with nesting:

  ijk_0 = (tf.constant(0), (tf.constant(1), tf.constant(2)))
  # Python 3 lambdas cannot unpack tuple parameters, so (j, k) arrives as jk.
  c = lambda i, jk: i < 10
  b = lambda i, jk: (i + 1, (jk[0] + jk[1], jk[0] - jk[1]))
  ijk_final = tf.while_loop(c, b, ijk_0)
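To see what these loops compute, the calling convention can be traced with a sequential pure-Python model. This is a sketch, not TensorFlow code. The function while_loop below is a hypothetical stand-in that ignores parallel_iterations, back_prop, and swap_memory. It passes each loop variable to cond and body as a separate argument, and it returns a bare value for one loop variable and a list otherwise.

```python
# Sequential pure-Python model of tf.while_loop's calling convention.
def while_loop(cond, body, loop_vars):
    if not loop_vars:
        raise ValueError("No loop variables provided")
    single = len(loop_vars) == 1
    state = list(loop_vars)
    while cond(*state):  # one argument per loop variable
        out = body(*state)
        # A one-variable body may return a bare value instead of a sequence.
        state = [out] if single and not isinstance(out, (list, tuple)) else list(out)
    return state[0] if single else state


print(while_loop(lambda i: i < 10, lambda i: i + 1, [0]))  # 10
print(while_loop(lambda i, jk: i < 10,
                 lambda i, jk: (i + 1, (jk[0] + jk[1], jk[0] - jk[1])),
                 (0, (1, 2))))  # [10, (32, 64)]
```

In the nested case, every two iterations double both j and k, so ten iterations turn (1, 2) into (32, 64).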