Missed TensorFlow Dev Summit? Check out the video playlist. Watch recordings

tf.function で性能アップ

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

TensorFlow 2.0 では Eager Execution が既定で有効になっています。ユーザーインターフェイスは直感的で柔軟です(演算を一度だけ行う場合にはずっと簡単に、かつ迅速に実行されます)。しかしながら、それは性能と展開の面での犠牲の上に成り立っています。

最高性能を得ながら、モデルをどこへでも展開できるようにするには、tf.function を使ってプログラムから計算グラフを作成します。 AutoGraph のおかげで、驚くほど多くの Python コードが tf.function でそのまま動作しますが、気をつけなければならない落とし穴も存在します。

ポイントと推奨事項は下記の通りです。

  • オブジェクトの変更やリストへの追加のような Python の副作用に依存しないこと
  • tf.functions は NumPy の演算や Python の組み込み演算よりも、TensorFlow の演算に適していること
  • 迷ったときは、for x in y というイディオムを使うこと
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import contextlib

# 遭遇するかもしれないいくつかのエラーをデモするためのヘルパー関数
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}: {}'.format(error_class, e))
  except Exception as e:
    print('Got unexpected exception \n  {}: {}'.format(type(e), e))
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

あなたが定義した tf.function は TensorFlow Core の演算に似たものです。例えばそれを即時に実行することも、計算グラフで使うこともできますし、勾配を計算することも可能です。

# function は演算のように振る舞う

@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
# function は勾配を計算できる

@tf.function
def add(a, b):
  return a + b

v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
# function 内で function を使うこともできる

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

トレーシングとポリモーフィズム

Python の動的型付けは、関数をさまざまな型の引数で呼び出すことができ、Python がそれぞれのシナリオで異なる動作をするということを意味します。

他方で、TensorFlow の計算グラフでは、dtype と shape の次元が静的であることが必要です。tf.function は、正しい計算グラフを生成するために必要なときには関数を再トレースして、このギャップをつなぐ役割を果たします。

異なる型の引数を使って関数を呼び出し、何が起きるか見てみましょう。

# Function はポリモーフィック

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()

Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

トレースの動作を制御するためには、下記のようなテクニックを使います。

  • 新しい tf.function を作成する。別々の tf.function オブジェクトがトレースを共有することはない。
  • 特定のトレースを得るには get_concrete_function メソッドを使用する。
  • 計算グラフの呼び出し時に1回だけトレースを行うには、 input_signature を指定して tf.function を呼び出す。
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("Using a concrete trace with incompatible types will throw an error")
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Obtaining concrete trace
Tracing with Tensor("a:0", dtype=string)
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
Using a concrete trace with incompatible types will throw an error
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute __inference_double_88 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_88]
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(tf.equal(x % 2, 0), x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# 1次元のテンソルを input signature として指定しているので、これは失敗する
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))

いつ再トレースするのか?

ポリモーフィックな tf.function はトレーシングによって生成された具象関数のキャッシュを保持しています。キャッシュのキーは、実際にはその関数の引数及びキーワード引数から生成されたキーのタプルです。tf.Tensor 引数から生成されるキーは、テンソルの shape と型です。Python の組み込み型引数から生成されるキーはその値です。それ以外の Python の型では、キーはオブジェクトの id() に基づいており、メソッドはクラスのインスタンスひとつずつ独立にトレースされます。将来、TensorFlowには、Python オブジェクトについて安全にテンソルに変換できるような、より洗練されたキャッシングが追加されるかもしれません。

引数は Python か? Tensor か?

しばしば、ハイパーパラメータやグラフ構成を制御するために Python の組み込み型の引数が使われます。例えば、num_layers=10training=True あるいは nonlinearity='relu' のようにです。このため、この Python の組み込み型の引数が変更されると、計算グラフを再びトレースする必要があるということになります。

しかし、グラフの生成を制御するために Python の組み込み型の引数を使用する必要はありません。これらのケースでは、Python引数の値の変更が不必要な再トレースを引き起こす可能性があります。例えば、この訓練ループでは、AutoGraph は動的に展開を行います。複数回トレースを行っていますが、生成される計算グラフは全く変わりません。これは少し非効率です。

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = {}".format(num_steps))
  for _ in tf.range(num_steps):
    train_one_step()

train(num_steps=10)
train(num_steps=20)

Tracing with num_steps = 10
Tracing with num_steps = 20

ここでの簡単な回避方法は、生成されたグラフの shape が変わらないのであれば、引数をテンソルにキャストすることです。

train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32)

tf.function の中の副作用

一般的には、(印字やオブジェクト変更のような)Python の副作用は、トレーシングの最中にだけ発生します。それでは、どうしたら tf.function で安定的に副作用を起こすことができるでしょうか?

一般的な原則は、トレースをデバッグする際にだけ Python の副作用を使用するというものです。あるいは、tf.Variable.assigntf.print、そして tf.summary のような TensorFlow の演算を使うことで、コードがトレースされるときにも、TensorFlowランタイムによって都度呼び出される際にも、確実に実行されるようにできます。一般には、関数型のスタイルを使用することで最も良い結果を得られます。

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)

Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

tf.function が呼び出されるたびに Python のコードを実行したいのであれば、tf.py_function がぴったりです。tf.py_function の欠点は、ポータブルでないこと、それほど性能が高くないこと、(マルチGPU、TPUの)分散環境ではうまく動作しないことなどです。また、tf.py_function は計算グラフに組み込まれるため、入出力すべてをテンソルにキャストします。

external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
assert len(external_list) == 3
# .numpy() call required because py_function casts 1 to tf.constant(1)
assert external_list[0].numpy() == 1

Python side effect
Python side effect
Python side effect

Python の状態に注意

ジェネレーターやイテレーターなど Python の機能の多くは、状態を追跡するために Python のランタイムに依存しています。これらの仕組みは、一般的には Eager モードでも期待通りに動作しますが、トレーシングの振る舞いにより、tf.function の中では予期しないことが起きることがあります。

1例として、イテレーターの状態が進むのは Python の副作用であり、トレーシングの中だけで発生します。

external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# 次のコードは、イテレーターの次の値を使うのではなく、最初の値を再利用する
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value of external_var: 0
Value of external_var: 0
Value of external_var: 0

イテレーターが tf.function の中で生成されすべて使われる場合には、正しく動作するはずです。しかし、イテレーター全体がトレースされることとなり、巨大な計算グラフの生成をまねく可能性があります。これは、望みどおりの動作かもしれません。しかし、もし Python のリストとして表されたメモリー上の巨大なデータセットを使って訓練を行うとすると、これは非常に大きな計算グラフを生成することになり、tf.function がスピードアップにはつながらないと考えられます。

Python データを繰り返し使用する場合、もっとも安全な方法は tf.data.Dataset でラップして、for x in y というイディオムを使用することです。AutoGraph には、y がテンソルあるいは tf.data.Dataset である場合、for ループを安全に変換する特別な機能があります。

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # ダミー計算
  return loss

small_data = [(1, 1)] * 2
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1)]) contains 8 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 9 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 9 nodes in its graph

Python/Numpy のデータを Dataset でラップする際には、tf.data.Dataset.from_generatortf.data.Dataset.from_tensors の違いに留意しましょう。前者はデータを Python のまま保持し tf.py_function を通じて取得するため、性能に影響する場合があります。これに対して後者はデータのコピーを計算グラフの中の、ひとつの大きな tf.constant() に結びつけるため、メモリー消費に影響する可能性があります。

TFRecordDataset/CsvDataset/などを通じてデータをファイルから読み込むことが、データを使用する最も効率的な方法です。TensorFlow 自身が Python とは関係なく非同期のデータ読み込みとプリフェッチを管理することができるからです。

自動的な依存関係の制御

プログラミングモデルとしての関数が一般的なデータフローグラフに対して非常に優位である点は、意図したコードの振る舞いがどのようなものであるかということについて、より多くの情報をランタイムに与えられるということにあります。

例えば、同じ変数を何度も読んだり書いたりするコードを書く場合、データフローグラフではもともと意図されていた演算の順番を自然に組み込むわけではありません。tf.function の中では、もともとの Python コードの文の実行順序を参照することで、実行順序の曖昧さを解消します。これにより、tf.function の中のステートフルな演算の順序が、先行実行モードのセマンティクスを模していることになります。

これは、手動で制御の依存関係を加える必要がないことを意味しています。tf.function は十分賢いので、あなたのコードが正しく動作するために必要十分な最小限の制御の依存関係を追加してくれます。

# 自動的な依存関係の制御

a = tf.Variable(1.0)
b = tf.Variable(2.0)

@tf.function
def f(x, y):
  a.assign(y * b)
  b.assign_add(x * a)
  return a + b

f(1.0, 2.0)  # 10.0
<tf.Tensor: shape=(), dtype=float32, numpy=10.0>

変数

tf.function の中では、意図したコードの実行順序を活用するという同じアイデアを使って、変数の作成と活用を簡単に行うことができます。しかし、ひとつだけ非常に重要な欠点があります。それは、変数を使った場合、先行実行モードとグラフモードでは動作が変わるコードを書いてしまう可能性があるということです。

特に、呼び出しの都度新しい変数を作成する場合にこれが発生します。トレーシングの意味では、tf.function は呼び出しのたびに同じ変数を再利用しますが、Eager モードでは呼び出しごとに新しい変数を生成します。この間違いを防止するため、tf.function は危険な変数の生成動作を見つけるとエラーを発生させます。

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1786: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Caught expected exception 
  <class 'ValueError'>: in converted code:

    <ipython-input-17-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py:260 __call__
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py:254 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/variables.py:65 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/eager/def_function.py:502 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

# しかし、曖昧さの無いコードは大丈夫

v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

print(f(1.0))  # 2.0
print(f(2.0))  # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
# 初めて関数が実行されるときだけ変数が生成されることを保証できれば
# tf.function 内で変数を作成できる

class C: pass
obj = C(); obj.v = None

@tf.function
def g(x):
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

print(g(1.0))  # 2.0
print(g(2.0))  # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
# 変数の初期化は、関数の引数や他の変数の値に依存可能
# 制御の依存関係を生成するのと同じ手法で、正しい初期化の順序を発見可能

state = []
@tf.function
def fn(x):
  if not state:
    state.append(tf.Variable(2.0 * x))
    state.append(tf.Variable(state[0] * 3.0))
  return state[0] * x * state[1]

print(fn(tf.constant(1.0)))
print(fn(tf.constant(3.0)))
tf.Tensor(12.0, shape=(), dtype=float32)
tf.Tensor(36.0, shape=(), dtype=float32)

AutoGraph の使用

autograph ライブラリは tf.function に完全に統合されており、計算グラフの中で動的に実行される条件文や繰り返しを書くことができます。

tf.condtf.while_looptf.function でも使えますが、制御フローを含むコードは、命令形式で書いたほうが書きやすいし理解しやすいです。

# 単純な繰り返し

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.741314411 0.410949 0.918613672 0.676402211 0.609279156]
[0.629938602 0.389278173 0.725240946 0.589175463 0.543619514]
[0.55801 0.370737791 0.620145202 0.529302359 0.49572292]
[0.50649941 0.354636878 0.551229119 0.484847605 0.458746821]
[0.467213213 0.340481311 0.501440823 0.450117499 0.429062307]
[0.435944915 0.327907 0.463249534 0.42199561 0.40453738]
[0.410277575 0.316638857 0.432728976 0.398610294 0.383824617]
[0.388708353 0.306464583 0.407599449 0.378759265 0.366024315]
[0.370246172 0.297217369 0.386432499 0.361629486 0.350509346]
[0.354207 0.28876406 0.368280649 0.346648514 0.336827159]
[0.340101242 0.280996859 0.352486908 0.333399922 0.324641854]
[0.327567756 0.273827434 0.338579208 0.321572423 0.313698053]
[0.316333592 0.267182648 0.326208383 0.310928017 0.303797603]
[0.306187958 0.261001348 0.31510973 0.301281095 0.294784069]
[0.296965182 0.255231887 0.305078447 0.292484552 0.286532104]
[0.288532883 0.249830365 0.295953184 0.284419984 0.278939843]
[0.280783921 0.244759187 0.287604868 0.27699092 0.271923602]
[0.27363044 0.239985928 0.279928833 0.270117819 0.265413821]
[0.266999722 0.235482454 0.272839218 0.26373446 0.259352237]
[0.260830849 0.231224224 0.266264737 0.257785171 0.253689557]
[0.255072474 0.227189705 0.260145754 0.252222747 0.248383746]
[0.249680907 0.223359942 0.254431844 0.247006938 0.243398756]
[0.244618684 0.219718158 0.249080107 0.242103085 0.23870343]
[0.239853516 0.216249421 0.244053751 0.237481207 0.234270707]
[0.235357389 0.212940425 0.239321008 0.233115241 0.230076939]
[0.231105849 0.209779248 0.234854311 0.228982344 0.226101354]
[0.227077439 0.206755191 0.230629578 0.22506246 0.222325638]
[0.22325328 0.203858614 0.226625681 0.22133787 0.218733549]
[0.219616637 0.201080784 0.222823992 0.217792839 0.215310648]
[0.216152638 0.198413789 0.219208017 0.214413375 0.212044045]
[0.212848023 0.195850432 0.215763077 0.211186945 0.208922163]
[0.209690914 0.193384171 0.21247609 0.208102331 0.205934599]
[0.206670642 0.191009015 0.209335312 0.205149412 0.203072]
[0.203777567 0.188719481 0.206330195 0.202319056 0.200325847]

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.201003  , 0.18651053, 0.20345125, 0.199603  , 0.19768846],
      dtype=float32)>
# 興味があれば AutoGraph が生成するコードを調べることができる
# ただし、アセンブリ言語を読むような感じがする

def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

print(tf.autograph.to_code(f))
def tf__f(x):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:

    def get_state():
      return ()

    def set_state(_):
      pass

    def loop_body(x):
      ag__.converted_call(tf.print, (x,), None, fscope)
      x = ag__.converted_call(tf.tanh, (x,), None, fscope)
      return x,

    def loop_test(x):
      return ag__.converted_call(tf.reduce_sum, (x,), None, fscope) > 1
    x, = ag__.while_stmt(loop_test, loop_body, get_state, set_state, (x,), ('x',), ())
    do_return = True
    retval_ = fscope.mark_return_value(x)
  do_return,
  return ag__.retval(retval_)

AutoGraph: 条件分岐

AutoGraph は if 文を等価である tf.cond の呼び出しに変換します。

この置換は条件がテンソルである場合に行われます。そうでない場合には、条件分岐はトレーシングの中で実行されます。

def test_tf_cond(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'cond' for node in g.as_graph_def().node):
    print("{}({}) uses tf.cond.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) executes normally.".format(
        f.__name__, ', '.join(map(str, args))))

@tf.function
def hyperparam_cond(x, training=True):
  if training:
    x = tf.nn.dropout(x, rate=0.5)
  return x

@tf.function
def maybe_tensor_cond(x):
  if x < 0:
    x = -x
  return x

test_tf_cond(hyperparam_cond, tf.ones([1], dtype=tf.float32))
test_tf_cond(maybe_tensor_cond, tf.constant(-1))
test_tf_cond(maybe_tensor_cond, -1)
hyperparam_cond(tf.Tensor([1.], shape=(1,), dtype=float32)) executes normally.
maybe_tensor_cond(tf.Tensor(-1, shape=(), dtype=int32)) uses tf.cond.
maybe_tensor_cond(-1) executes normally.

tf.cond には、色々と注意すべき細かな点があります。

  • tf.cond は条件分岐の両方をトレーシングし、条件に従って実行時に適切な分岐を選択することで機能します。分岐の両方をトレースすることで、Python プログラムを予期せず実行する可能性があります。
  • tf.cond では、分岐の一方が後ほど使用されるテンソルを作成する場合、もう一方の分岐もそのテンソルを作成することが必要です。
@tf.function
def f():
  x = tf.constant(0)
  if tf.constant(True):
    x = x + 1
    print("Tracing `then` branch")
  else:
    x = x - 1
    print("Tracing `else` branch")
  return x

f()
Tracing `then` branch
Tracing `else` branch

<tf.Tensor: shape=(), dtype=int32, numpy=1>
@tf.function
def f():
  if tf.constant(True):
    x = tf.ones([3, 3])
  return x

# 分岐のどちらの枝でも `x` を定義する必要があるためエラーが発生
with assert_raises(ValueError):
  f()
Caught expected exception 
  <class 'ValueError'>: in converted code:

    <ipython-input-26-3c92b2df645c>:3 f  *
        if tf.constant(True):
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:918 if_stmt
        basic_symbol_names, composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:956 tf_if_stmt
        error_checking_orelse)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507 new_func
        return func(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/control_flow_ops.py:1174 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/cond_v2.py:90 cond_v2
        op_return_value=pred)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:978 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:949 error_checking_orelse
        result[orelse_branch] = orelse()
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:987 wrapper
        new_vars = func()
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:1013 wrapper
        tuple(s.symbol_name for s in undefined)))

    ValueError: The following symbols must also be initialized in the else branch: ('x',). Alternatively, you may initialize them before the if statement.

AutoGraph と繰り返し

AutoGraph には繰り返しの変換にいくつかの単純なルールがあります。

  • for: イテラブルがテンソルである場合に変換する
  • while: while 条件がテンソルに依存している場合に変換する

繰り返しが変換される場合、tf.while_loop によって動的に展開されます。あるいは、 for x in tf.data.Dataset という特別なケースの場合には、 tf.data.Dataset.reduce に変換されます。

繰り返しが変換されない場合、それは静的に展開されます。

def test_dynamically_unrolled(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'while' for node in g.as_graph_def().node):
    print("{}({}) uses tf.while_loop.".format(
        f.__name__, ', '.join(map(str, args))))
  elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):
    print("{}({}) uses tf.data.Dataset.reduce.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) gets unrolled.".format(
        f.__name__, ', '.join(map(str, args))))
@tf.function
def for_in_range():
  x = 0
  for i in range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_range)
for_in_range() gets unrolled.
@tf.function
def for_in_tfrange():
  x = tf.constant(0, dtype=tf.int32)
  for i in tf.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfrange)
for_in_tfrange() uses tf.while_loop.
@tf.function
def for_in_tfdataset():
  x = tf.constant(0, dtype=tf.int64)
  for i in tf.data.Dataset.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfdataset)
for_in_tfdataset() uses tf.data.Dataset.reduce.
@tf.function
def while_py_cond():
  x = 5
  while x > 0:
    x -= 1
  return x

test_dynamically_unrolled(while_py_cond)
while_py_cond() gets unrolled.
@tf.function
def while_tf_cond():
  x = tf.constant(5)
  while x > 0:
    x -= 1
  return x

test_dynamically_unrolled(while_tf_cond)
while_tf_cond() uses tf.while_loop.

繰り返しに、テンソルに依存する break や、途中での return がある場合、一番外側の条件あるいはイテラブルはテンソルである必要があります。

比較してみましょう。

@tf.function
def while_py_true_py_break(x):
  while True:  # py true
    if x == 0: # py break
      break
    x -= 1
  return x

test_dynamically_unrolled(while_py_true_py_break, 5)
while_py_true_py_break(5) gets unrolled.
@tf.function
def buggy_while_py_true_tf_break(x):
  while True:   # py true
    if tf.equal(x, 0): # tf break
      break
    x -= 1
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_while_py_true_tf_break, 5)
Caught expected exception 
  <class 'TypeError'>: in converted code:

    <ipython-input-34-453240ea98e6>:3 buggy_while_py_true_tf_break  *
        while True:   # py true
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:755 while_stmt
        return _py_while_stmt(test, body, get_state, set_state, init_vars, opts)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:870 _py_while_stmt
        while test(*loop_vars):
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:757 __bool__
        self._disallow_bool_casting()
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:523 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:510 _disallow_when_autograph_enabled
        " decorating it directly with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.

@tf.function
def while_tf_true_tf_break(x):
  while tf.constant(True): # tf true
    if x == 0:  # py break
      break
    x -= 1
  return x

test_dynamically_unrolled(while_tf_true_tf_break, 5)
while_tf_true_tf_break(5) uses tf.while_loop.
@tf.function
def buggy_py_for_tf_break():
  x = 0
  for i in range(5):  # py for
    if tf.equal(i, 3): # tf break
      break
    x += i
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_py_for_tf_break)
Caught expected exception 
  <class 'TypeError'>: in converted code:

    <ipython-input-36-82742b0a14d0>:4 buggy_py_for_tf_break  *
        for i in range(5):  # py for
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:339 for_stmt
        return _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:348 _py_for_stmt
        if extra_test is not None and not extra_test(*state):
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:757 __bool__
        self._disallow_bool_casting()
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:523 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:510 _disallow_when_autograph_enabled
        " decorating it directly with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.

@tf.function
def tf_for_py_break():
  x = 0
  for i in tf.range(5): # tf for
    if i == 3:  # py break
      break
    x += i
  return x

test_dynamically_unrolled(tf_for_py_break)
tf_for_py_break() uses tf.while_loop.

動的に展開される繰り返しの結果を集計するため、tf.TensorArray を使いたくなるかもしれません。

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.35583818, 0.83193195, 0.4379022 , 0.32066643],
        [1.0667148 , 0.84722435, 0.8679224 , 0.942137  ],
        [1.6984146 , 0.9410734 , 0.9755955 , 1.5888207 ]],

       [[0.35575068, 0.04498446, 0.52700126, 0.1645391 ],
        [0.43994033, 0.21749938, 1.0328307 , 1.0847937 ],
        [0.54859304, 0.43104947, 1.195856  , 1.6513802 ]]], dtype=float32)>

tf.cond と同様に、tf.while_loop にも、色々と注意すべき細かな点があります。

  • 繰り返しの実行回数が 0 である可能性があるため、while_loop の後で使用されるテンソルは、繰り返しの前に初期化されなければならない
  • すべての繰り返しの変数は、各繰り返しを通じてその形状と dtype が変わらないことが必要
@tf.function
def buggy_loop_var_uninitialized():
  for i in tf.range(3):
    x = i
  return x

with assert_raises(ValueError):
  buggy_loop_var_uninitialized()
Caught expected exception 
  <class 'ValueError'>: in converted code:

    <ipython-input-39-815fd6bba8cc>:3 buggy_loop_var_uninitialized  *
        for i in tf.range(3):
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:315 for_stmt
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:419 _tf_range_for_stmt
        _disallow_undefs_into_loop(*init_vars)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:97 _disallow_undefs_into_loop
        ' before the loop: {}'.format(tuple(s.symbol_name for s in undefined)))

    ValueError: TensorFlow requires that the following symbols must be defined before the loop: ('x',)

@tf.function
def f():
  x = tf.constant(0)
  for i in tf.range(3):
    x = i
  return x

f()
<tf.Tensor: shape=(), dtype=int32, numpy=2>
@tf.function
def buggy_loop_type_changes():
  x = tf.constant(0, dtype=tf.float32)
  for i in tf.range(3): # tf.int32 型のテンソルを1つづつ取り出して…
    x = i
  return x

with assert_raises(tf.errors.InvalidArgumentError):
  buggy_loop_type_changes()
Got unexpected exception 
  <class 'TypeError'>: in converted code:

    <ipython-input-41-e04d86588cb2>:4 buggy_loop_type_changes  *
        for i in tf.range(3): # tf.int32 型のテンソルを1つづつ取り出して…
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:315 for_stmt
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:478 _tf_range_for_stmt
        opts=opts,
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:794 _tf_while_stmt
        aug_init_vars, **opts)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/control_flow_ops.py:2675 while_loop
        back_prop=back_prop)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/while_v2.py:194 while_loop
        add_control_dependencies=add_control_dependencies)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:978 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/while_v2.py:172 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:784 aug_body
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:195 _verify_tf_loop_vars
        first_iter_var)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/util/nest.py:568 map_structure
        structure[0], [func(*x) for x in entries],
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/util/nest.py:568 <listcomp>
        structure[0], [func(*x) for x in entries],
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:179 _check_same_type
        first_iter_var.dtype.name,

    TypeError: "x" has dtype float32 before the loop, but dtype int32 after one iteration. TensorFlow control flow requires it stays the same.

@tf.function
def buggy_concat():
  x = tf.ones([0, 10])
  for i in tf.range(5):
    x = tf.concat([x, tf.ones([1, 10])], axis=0)
  return x

with assert_raises(ValueError):
  buggy_concat()
Caught expected exception 
  <class 'ValueError'>: in converted code:

    <ipython-input-42-bae298a1ce41>:4 buggy_concat  *
        for i in tf.range(5):
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:315 for_stmt
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:478 _tf_range_for_stmt
        opts=opts,
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:794 _tf_while_stmt
        aug_init_vars, **opts)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/control_flow_ops.py:2675 while_loop
        back_prop=back_prop)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/while_v2.py:194 while_loop
        add_control_dependencies=add_control_dependencies)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/framework/func_graph.py:978 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/ops/while_v2.py:172 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:784 aug_body
        composite_symbol_names)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:195 _verify_tf_loop_vars
        first_iter_var)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/util/nest.py:568 map_structure
        structure[0], [func(*x) for x in entries],
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/util/nest.py:568 <listcomp>
        structure[0], [func(*x) for x in entries],
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:191 _check_same_type
        first_iter_shape))

    ValueError: "x" has shape (0, 10) before the loop, but shape (1, 10) after one iteration. TensorFlow control flow requires it stays the same or be more specific.

@tf.function
def concat_with_padding():
  x = tf.zeros([5, 10])
  for i in tf.range(5):
    x = tf.concat([x[:i], tf.ones([1, 10]), tf.zeros([4-i, 10])], axis=0)
    x.set_shape([5, 10])
  return x

concat_with_padding()
<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>