此页面由 Cloud Translation API 翻译。
Switch to English

tf.while_loop

TensorFlow 1版 GitHub上查看源代码

重复body而条件cond为真。 (不建议使用的参数值)

用在笔记本电脑

使用教程

cond是一个可调用的返回一个布尔标张量。 body是一个可调用返回相同的元数(长度和结构)和作为类型的张量的一个(可能是嵌套的)元组,namedtuple或列表loop_varsloop_vars是一个(可能是嵌套)倍速,namedtuple或传递给两个张量的列表condbodycondbody都接受的,因为有很多争论loop_vars

除了常规的张量或IndexedSlices,身体可以接受并返回TensorArray对象。该TensorArray对象将环路之间,并在梯度计算被适当地转发的流动。

需要注意的是while_loop调用condbody 一次 (内部调用while_loop ,而不是在所有期间Session.run() while_loop针一起期间创建的图形片段condbody具有一些附加图形节点调用创建图形流即重复body直到cond返回false。

对于正确性, tf.while_loop()严格执行的循环变量不变的形状。形状不变是(可能局部的)形状,其是在整个循环迭代不变。如果确定迭代之后的循环变量的形状为比多个通用或与其形状不变的不相容的错误将被提高。例如,[11,无]的形状比[11,17]的形状更一般的,和[11,21]是与[11,17]兼容。默认情况下(如果参数shape_invariants未指定)时,假定在各张量的初始形状loop_vars是在每次迭代相同的。所述shape_invariants参数允许调用者指定为每个循环变量,如果形状迭代之间变化所需要较不具体的形状不变的。该tf.Tensor.set_shape功能也可以在所使用的body功能,以指示所述输出循环变量具有特定的形状。对于SparseTensor和IndexedSlices形状不变的特殊处理如下:

a)如果循环变量是一个SparseTensor,形状不变必须TensorShape([R]),其中R是由所述稀疏张量表示的致密张量的维数。这意味着SparseTensor的三个张量的形状([无],[无,R],[R])。注:外形不变的这里是SparseTensor.dense_shape属性的形状。它必须是一个矢量的形状。

b)若循环变量是一个IndexedSlices,形状不变必须是IndexedSlices的值张量的形状不变。这意味着IndexedSlices的三个张量的形状(形状,[形状[0]],[shape.ndims])。

while_loop器具非严格的语义,可以使多个迭代以并行运行。并行迭代的最大数目可通过被控制parallel_iterations ,其为用户提供对存储器消耗和执行顺序的一些控制。对于正确的程序, while_loop应该返回相同的结果对于任何parallel_iterations> 0。

对于训练,TensorFlow存储了在正向推理产生并且需要在反向传播的张量。这些张量内存消耗的主要来源,还常常造成OOM错误在GPU上训练的时候。当标志swap_memory是真的,我们换出这些张量从GPU到CPU。这例如允许我们训练RNN模型很长的序列和大批量。

cond 代表该循环的终止条件A调用。
body 表示循环体调用。
loop_vars 一个(可能嵌套)倍速,namedtuple或numpy的阵列的列表中, Tensor ,和TensorArray对象。
shape_invariants 形状不变的循环变量。
parallel_iterations 迭代次数允许并行运行。它必须是一个正整数。
back_prop (可选)已过时。假禁用反向传播支持。喜欢使用tf.stop_gradient代替。
swap_memory 不论是GPU-CPU内存交换为这个循环启用。
maximum_iterations while循环的迭代可选最大数量的运行。如果提供, cond输出AND-ED与附加条件保证执行的迭代的数量不大于maximum_iterations
name 可选的名称前缀为返回的张量。

输出张量为循环后的循环变量。返回值具有相同的结构loop_vars

TypeError 如果condbody不调用。
ValueError 如果loop_vars是空的。

例:

 i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: (tf.add(i, 1), )
r = tf.while_loop(c, b, [i])
 

实施例与嵌套和一个namedtuple:

 import collections
Pair = collections.namedtuple('Pair', 'j, k')
ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
c = lambda i, p: i < 10
b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
ijk_final = tf.while_loop(c, b, ijk_0)
 

使用实施例shape_invariants:

 i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m: i < 10
b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
tf.while_loop(
    c, b, loop_vars=[i0, m0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
 

例如这表明非严格语义:在下面的例子中,计数器的最终值i不依赖于x 。所以while_loop可以增加平行于更新计数器x 。然而,由于在一个循环迭代循环计数器取决于在先前的迭代中值,循环计数器本身不能并行递增。因此,如果我们只是想计数器(我们打印上线的最终值print(sess.run(i))那么x将永远不会被增加,但计数器将在单个线程进行更新。相反,如果我们希望的输出的值(我们就行打印print(sess.run(out).shape)则计数器可以在它自己的线程递增,而x可以并行地递增一单独的线程。在极端情况下,可以想象的是线增进计数器运行,直到完成前x递增甚至一次。可以永远不会发生的唯一的事情就是线程更新x永远无法提前获得反螺纹的,因为线增进x取决于该计数器的值。

 import tensorflow as tf

n = 10000
x = tf.constant(list(range(n)))
c = lambda i, x: i < n
b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
[i], "x:"))
i, out = tf.while_loop(c, b, (0, x))
with tf.compat.v1.Session() as sess:
    print(sess.run(i))  # prints [0] ... [9999]

    # The following line may increment the counter and x in parallel.
    # The counter thread may get ahead of the other thread, but not the
    # other way around. So you may see things like
    # [9996] x:[9987]
    # meaning that the counter thread is on iteration 9996,
    # while the other thread is on iteration 9987
    print(sess.run(out).shape)