このページは Cloud Translation API によって翻訳されました。
Switch to English

tf.while_loop

TensorFlow 1つのバージョン GitHubの上のソースを表示

繰り返しbody条件ながら、 cond真です。 (引数値を非推奨)

ノートPCで使用されます

チュートリアルで使用されます

condブーリアンスカラーテンソルを返す呼び出し可能です。 body同じアリティ(長さ及び構造)のテンソル及び種類の(おそらくネストした)タプル、namedtuple又はリストを返す呼び出し可能であるloop_varsloop_vars (おそらくネスト)タプル、namedtupleまたは両方に渡されるテンソルのリストであるcondbodycondbodyあるとして多くの引数として取るの両方loop_vars

通常のテンソルまたはIndexedSlicesに加えて、体が受け入れ、TensorArrayオブジェクトを返すことがあります。適切ループ間および勾配計算中に転送されるTensorArrayオブジェクトの流れ。

注意while_loop呼び出してcondしてbody (するための呼び出しの内部で正確に一度 while_loopの間、およびすべてではないにSession.run() while_loop一緒にステッチ中に作成されたグラフの断片をcondし、 body繰り返し、グラフフロー作成するためにいくつかの追加のグラフのノードとの通話bodyするまでcondを返すの虚偽を。

正しさのために、 tf.while_loop()厳密にループ変数の形状不変量を強制します。形状不変量は、ループの反復を横切って変化しない(おそらく部分的な)形状です。反復後のループ変数の形状はより一般的な、またはその形状不変と互換性があると判定された場合は、エラーが発生します。例えば、[11]、[なし]の形状は[11、17]の形状よりも一般的であり、そして[11、21]、[11、17]と互換性がありません。デフォルトでは(引数の場合shape_invariants指定されていない)、各テンソルの初期形状と仮定されるloop_vars全ての反復でも同じです。 shape_invariants引数は、呼び出し側が形状が反復間で変化する場合に必要とされる各ループの変数のためのより少ない特定の形状不変量を指定することを可能にします。 tf.Tensor.set_shape機能はまた、使用することができるbody出力ループ変数が特定の形状を有していることを示すように機能します。次のようにSparseTensorとIndexedSlicesする形状不変量は、特別に処理されます。

ループ変数がSparseTensor場合a)に示すように、形状不変量は、rがスパーステンソルで表される緻密なテンソルの階数であるTensorShape([R])でなければなりません。それはSparseTensorの三のテンソルの形状であることを意味する([なし]、[なし]、[R]、[R])。注:ここでの形状不変はSparseTensor.dense_shapeプロパティの形状です。これは、ベクトルの形でなければなりません。

ループ変数がIndexedSlices場合b)に示すように、形状不変量はIndexedSlicesの値テンソルの形状不変でなければなりません。これはIndexedSlicesの三のテンソルの形状は(形状、[形状[0]]、[shape.ndims])であることを意味します。

while_loop 、並列に実行する複数の反復を可能にする、非厳格なセマンティクスを実装しています。並列反復の最大数がによって制御することができるparallel_iterationsユーザにメモリ消費量と実行順序をある程度制御することができます。正しいプログラムでは、 while_loop > 0は任意のparallel_iterationsために同じ結果を返す必要があります。

トレーニングのために、TensorFlowは前向き推論で生産されており、バックプロパゲーションで必要とされるテンソルを格納します。これらのテンソルは、メインメモリ消費のソースと、多くの場合、OOMエラー原因GPU上で訓練されています。フラグswap_memoryがtrueの場合、我々はGPUからCPUにこれらのテンソルをスワップアウト。これは、例えば、私たちは非常に長いシーケンスと大きなバッチでRNNモデルを訓練することができます。

cond ループの終了条件を表して呼び出し可能。
body ループ本体を表し呼び出し可能。
loop_vars (おそらくネスト)タプル、namedtupleまたはnumpyの配列のリスト、 Tensor 、及びTensorArrayオブジェクト。
shape_invariants ループ変数の形状不変。
parallel_iterations 反復回数は、並列に実行することができました。これは、正の整数でなければなりません。
back_prop (オプション)推奨されていません。 Falseの無効は、バックプロパゲーションのためにサポートされています。使用して優先tf.stop_gradient代わりに。
swap_memory かどうかはGPU-CPUのメモリスワップは、このループのために有効になっています。
maximum_iterations whileループの反復のオプションの最大数を実行します。提供される場合、 cond出力はAND-ED実行反復回数を確保する追加条件とは、以下であるmaximum_iterations
name 返されたテンソルのためのオプションの名前の接頭辞。

ループ後のループ変数の出力テンソル。戻り値は、同じ構造を有するloop_vars

TypeError 場合condまたはbody呼び出すことはできません。
ValueError 場合loop_vars空です。

例:

 i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: (tf.add(i, 1), )
r = tf.while_loop(c, b, [i])
 

ネスティングとnamedtupleとの例:

 import collections
Pair = collections.namedtuple('Pair', 'j, k')
ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
c = lambda i, p: i < 10
b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
ijk_final = tf.while_loop(c, b, ijk_0)
 

shape_invariantsを使用した例:

 i0 = tf.constant(0)
m0 = tf.ones([2, 2])
c = lambda i, m: i < 10
b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
tf.while_loop(
    c, b, loop_vars=[i0, m0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
 

非厳密なセマンティクスを示す例は、次の例では、カウンタの最終値は、 i依存しないx 。だから、 while_loopのアップデートにカウンター並列をインクリメントすることができますx 。 1回のループ反復のループカウンタが前回の反復の値に依存するためしかし、ループカウンタ自体が並列にインクリメントすることができません。私たちは(私たちはラインに印刷カウンタの最終値たい場合は、そのためprint(sess.run(i)) )、そしてxインクリメントされることはありませんが、カウンタは、単一のスレッド上で更新されます。逆に、我々は、出力の値場合は(我々は、ライン上に印刷するprint(sess.run(out).shape)しながら、カウンタは、それ自身のスレッドでインクリメントされてもよいx Aに平行でインクリメントすることができます別のスレッド。極端なケースでは、前にカウンタをインクリメントするスレッドが完了するまで実行されることが考えられるxあっても、単一の時間にインクリメントされます。決して起こらないことができる唯一のことは、スレッドの更新ということであるxインクリメントスレッドので、控えカウンタスレッドの取得することはありませんすることができx 、カウンタの値に依存します。

 import tensorflow as tf

n = 10000
x = tf.constant(list(range(n)))
c = lambda i, x: i < n
b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
[i], "x:"))
i, out = tf.while_loop(c, b, (0, x))
with tf.compat.v1.Session() as sess:
    print(sess.run(i))  # prints [0] ... [9999]

    # The following line may increment the counter and x in parallel.
    # The counter thread may get ahead of the other thread, but not the
    # other way around. So you may see things like
    # [9996] x:[9987]
    # meaning that the counter thread is on iteration 9996,
    # while the other thread is on iteration 9987
    print(sess.run(out).shape)