TensorFlow 2.0 での tf.function と AutoGraph

View on TensorFlow.org Run in Google Colab View source on GitHub

TensorFlow 2.0 では Eager Execution の使いやすさとTensorFlow 1.0 のパワーとを同時に提供します。この統合の中核となるのは tf.function です。これは Python の構文のサブセットを移植可能でハイパフォーマンスな TensorFlow のグラフに変換します。

tf.functionの魅力的な特徴に AutoGraph があります。これはグラフを Python の構文そのものを用いて記述できるようにします。 AutoGraph で利用可能な Python の機能の一覧は、AutoGraph Capabilities and Limitations (Autograph の性能と制限事項) で確認できます。また、tf.functionの詳細については RFC TF 2.0: Functions, not Sessions を参照してください。AutoGraph の詳細については tf.autograph を参照してください。

このチュートリアルでは tf.function と AutoGraph の基本的な特徴についてひととおり確認します。

セットアップ

TensorFlow 2.0 をインポートします。

import numpy as np
import tensorflow as tf

tf.function デコレータ

tf.functionを用いてある関数にアノテーションを付けたとしても、一般の関数と変わらずに呼び出せます。一方、実行時にはその関数はグラフへとコンパイルされます。これにより、より高速な実行や、 GPU や TPU での実行、SavedModel へのエクスポートといった利点が得られます。

@tf.function
def simple_nn_layer(x, y):
  return tf.nn.relu(tf.matmul(x, y))


x = tf.random.uniform((3, 3))
y = tf.random.uniform((3, 3))

simple_nn_layer(x, y)
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0.4944646 , 0.8567986 , 1.0001004 ],
       [1.3690405 , 1.2988085 , 0.88672316],
       [0.808781  , 1.109717  , 0.8911805 ]], dtype=float32)>

アノテーションの結果を調べてみると、 TensorFlow ランタイムとのやり取りのすべてを処理する特別な呼び出し可能オブジェクトを確認できます。

simple_nn_layer
<tensorflow.python.eager.def_function.Function at 0x7f70e8548cc0>

記述したコードで複数の関数を利用していたとしても、すべての関数にアノテーションを付ける必要はありません。アノテーションをつけた関数から呼び出されるすべての関数は、グラフモードで実行されます。

def linear_layer(x):
  return 2 * x + 1


@tf.function
def deep_net(x):
  return tf.nn.relu(linear_layer(x))


deep_net(tf.constant((1, 2, 3)))
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([3, 5, 7], dtype=int32)>

グラフが大量の軽量な演算から構成される場合、関数は Eager Execution で実行するコードよりも高速になる場合があります。しかし、 graph が少量の (畳み込み演算のような) 計算に時間のかかる演算からなる場合、高速化はそれほど見込めないでしょう。

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

Eager conv: 0.004799893000040356
Function conv: 0.003768158999946536
Note how there's not much difference in performance for convolutions

lstm_cell = tf.keras.layers.LSTMCell(10)

@tf.function
def lstm_fn(input, state):
  return lstm_cell(input, state)

input = tf.zeros([10, 10])
state = [tf.zeros([10, 10])] * 2
# warm up
lstm_cell(input, state); lstm_fn(input, state)
print("eager lstm:", timeit.timeit(lambda: lstm_cell(input, state), number=10))
print("function lstm:", timeit.timeit(lambda: lstm_fn(input, state), number=10))

eager lstm: 0.006962984000097094
function lstm: 0.0036654850000559236

Python の制御フローの利用

tf.functionの内部でデータに依存した制御フローを用いる場合、Pythonの制御フロー構文を用いることができます。AutoGraph はそれらの構文を TensorFlow の Ops に書き換えます。たとえば、 Tensor に依存する if 文は、tf.cond() に変換されます。

次の例では xTensor です。ですが、 if 文は期待するどおりに動作しています。

@tf.function
def square_if_positive(x):
  if x > 0:
    x = x * x
  else:
    x = 0
  return x


print('square_if_positive(2) = {}'.format(square_if_positive(tf.constant(2))))
print('square_if_positive(-2) = {}'.format(square_if_positive(tf.constant(-2))))
square_if_positive(2) = 4
square_if_positive(-2) = 0

AutoGraph は while, for, if, break, continue, return といった典型的なPythonの構文をサポートしています。また、これらを入れ子にして利用する場合もサポートしています。つまり、Tensorを返す式をwhile 文や if 文の条件式として用いることが可能です。また、for 文で Tensor の要素に渡って反復することも可能です。

@tf.function
def sum_even(items):
  s = 0
  for c in items:
    if c % 2 > 0:
      print(c)
      continue
    s += c
  return s


sum_even(tf.constant([10, 12, 15, 20]))
Tensor("TensorArrayV2Read/TensorListGetItem:0", shape=(), dtype=int32)

<tf.Tensor: shape=(), dtype=int32, numpy=42>

より高度な使い方をするユーザーのために、AutoGraph は低レベルAPIも提供しています。次の例では AutoGraph が生成したコードを確認できます。

print(tf.autograph.to_code(sum_even.python_function))
def tf__sum_even(items):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  with ag__.FunctionScope('sum_even', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
    s = 0

    def get_state_2():
      return ()

    def set_state_2(_):
      pass

    def loop_body(iterates, s):
      c = iterates
      continue_ = False

      def get_state():
        return ()

      def set_state(_):
        pass

      def if_true():
        print(c)
        continue_ = True
        return continue_

      def if_false():
        return continue_
      cond = c % 2 > 0
      continue_ = ag__.if_stmt(cond, if_true, if_false, get_state, set_state, ('continue_',), ())

      def get_state_1():
        return ()

      def set_state_1(_):
        pass

      def if_true_1():
        s_1, = s,
        s_1 += c
        return s_1

      def if_false_1():
        return s
      cond_1 = ag__.not_(continue_)
      s = ag__.if_stmt(cond_1, if_true_1, if_false_1, get_state_1, set_state_1, ('s',), ())
      return s,
    s, = ag__.for_stmt(items, None, loop_body, get_state_2, set_state_2, (s,), ('s',), ())
    do_return = True
    retval_ = fscope.mark_return_value(s)
  do_return,
  return ag__.retval(retval_)


次はより複雑な制御フローの例です。

@tf.function
def fizzbuzz(n):
  for i in tf.range(n):
    if i % 3 == 0:
      tf.print('Fizz')
    elif i % 5 == 0:
      tf.print('Buzz')
    else:
      tf.print(i)

fizzbuzz(tf.constant(15))
Fizz
1
2
Fizz
4
Buzz
Fizz
7
8
Fizz
Buzz
11
Fizz
13
14

Keras での AutoGraph の利用

動的でない Keras モデルでは AutoGraph をデフォルトで利用できます。より詳細が必要な場合、tf.keras を確認してください。

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      return input_data // 2


model = CustomModel()

model(tf.constant([-2, -4]))
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([-1, -2], dtype=int32)>

副作用

Eager モードのように、通常の場合 tf.function の中で、tf.assigntf.print といった副作用のある命令を実行できます。また、実行時の順序を保つために、処理順について必要な依存関係を書き加えます。

v = tf.Variable(5)

@tf.function
def find_next_odd():
  v.assign(v + 1)
  if v % 2 == 0:
    v.assign(v + 1)


find_next_odd()
v
<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=7>

デバッグ

tf.function や AutoGraph はコードを生成し、それを TensorFlow のグラフにトレースします。この機構は現在のところ、pdb のようなステップ・バイ・ステップのデバッガーをサポートしていません。ですが、tf.config.experimental_run_functions_eagerly(True) を呼び出すことで、一時的に Eager Execution を tf.function の中で有効にして、好きなデバッガーを用いることができます。

@tf.function
def f(x):
  if x > 0:
    # ブレークポイントを設定してみてください!
    # 例:
    #   import pdb
    #   pdb.set_trace()
    x = x + 1
  return x

tf.config.experimental_run_functions_eagerly(True)

# ブレークポイントを設定し、デバッガーでコードを動かせるようになります
f(tf.constant(1))

tf.config.experimental_run_functions_eagerly(False)

応用例: グラフ内での訓練ループ

前のセクションでは AutoGraph を Keras のレイヤーやモデルの中で利用できることを見てきました。 AutoGraph の中で Keras モデルを用いることもできます。

このサンプルはシンプルな Keras モデルを MNIST を用いて、訓練のすべてのプロセス――バッチを読み込み、勾配を計算し、パラメーターを更新し、検証データでの正確度を計算し、これらを収束するまで繰り返すこと――をグラフ内部で行う方法を示しています。

データのダウンロード

def prepare_mnist_features_and_labels(x, y):
  x = tf.cast(x, tf.float32) / 255.0
  y = tf.cast(y, tf.int64)
  return x, y

def mnist_dataset():
  (x, y), _ = tf.keras.datasets.mnist.load_data()
  ds = tf.data.Dataset.from_tensor_slices((x, y))
  ds = ds.map(prepare_mnist_features_and_labels)
  ds = ds.take(20000).shuffle(20000).batch(100)
  return ds

train_dataset = mnist_dataset()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step

モデルの定義

model = tf.keras.Sequential((
    tf.keras.layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(10)))
model.build()
optimizer = tf.keras.optimizers.Adam()

学習のためのループ処理の定義

compute_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

compute_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


def train_one_step(model, optimizer, x, y):
  with tf.GradientTape() as tape:
    logits = model(x)
    loss = compute_loss(y, logits)

  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))

  compute_accuracy(y, logits)
  return loss


@tf.function
def train(model, optimizer):
  train_ds = mnist_dataset()
  step = 0
  loss = 0.0
  accuracy = 0.0
  for x, y in train_ds:
    step += 1
    loss = train_one_step(model, optimizer, x, y)
    if step % 10 == 0:
      tf.print('Step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
  return step, loss, accuracy

step, loss, accuracy = train(model, optimizer)
print('Final step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
Step 10 : loss 1.85564053 ; accuracy 0.363
Step 20 : loss 1.1814419 ; accuracy 0.4965
Step 30 : loss 0.873045206 ; accuracy 0.590333343
Step 40 : loss 0.621631742 ; accuracy 0.64675
Step 50 : loss 0.672275484 ; accuracy 0.6838
Step 60 : loss 0.476123571 ; accuracy 0.713333309
Step 70 : loss 0.37669152 ; accuracy 0.735857129
Step 80 : loss 0.41039294 ; accuracy 0.754875
Step 90 : loss 0.257115066 ; accuracy 0.772
Step 100 : loss 0.503718317 ; accuracy 0.783
Step 110 : loss 0.344334573 ; accuracy 0.79218179
Step 120 : loss 0.344536632 ; accuracy 0.800833344
Step 130 : loss 0.4136177 ; accuracy 0.808538437
Step 140 : loss 0.212624252 ; accuracy 0.815642834
Step 150 : loss 0.322427183 ; accuracy 0.821733356
Step 160 : loss 0.477276951 ; accuracy 0.826625
Step 170 : loss 0.234079644 ; accuracy 0.831647038
Step 180 : loss 0.369053036 ; accuracy 0.836722195
Step 190 : loss 0.203556687 ; accuracy 0.841947377
Step 200 : loss 0.250807196 ; accuracy 0.84575
Final step tf.Tensor(200, shape=(), dtype=int32) : loss tf.Tensor(0.2508072, shape=(), dtype=float32) ; accuracy tf.Tensor(0.84575, shape=(), dtype=float32)

バッチ処理

実際のアプリケーションにおいて、処理をバッチにまとめることはパフォーマンスの観点から重要です。AutoGraphを用いるのにもっとも適しているコードは、制御フローを バッチ の単位で決定するようなコードです。もし、個々の 要素 の単位で制御を決定する場合、パフォーマンスを保つために batch API を試してみてください。

一例として、次の Python コードががあったとします。

def square_if_positive(x):
  return [i ** 2 if i > 0 else i for i in x]


square_if_positive(range(-5, 5))
[-5, -4, -3, -2, -1, 0, 1, 4, 9, 16]

TensorFlowに同等の処理を行わせる場合、次のように記述したくなるかもしれません。 (これは実際には動作します!)

@tf.function
def square_if_positive_naive(x):
  result = tf.TensorArray(tf.int32, size=x.shape[0])
  for i in tf.range(x.shape[0]):
    if x[i] > 0:
      result = result.write(i, x[i] ** 2)
    else:
      result = result.write(i, x[i])
  return result.stack()


square_if_positive_naive(tf.range(-5, 5))
<tf.Tensor: shape=(10,), dtype=int32, numpy=array([-5, -4, -3, -2, -1,  0,  1,  4,  9, 16], dtype=int32)>

しかし、この場合、次のように書くこともできます。

def square_if_positive_vectorized(x):
  return tf.where(x > 0, x ** 2, x)


square_if_positive_vectorized(tf.range(-5, 5))
<tf.Tensor: shape=(10,), dtype=int32, numpy=array([-5, -4, -3, -2, -1,  0,  1,  4,  9, 16], dtype=int32)>

再トレーシング

要点を次に示します:

  • 関数に Tensor でない引数や、shape を変えるような引数を渡す際には注意が必要
  • モジュール化された関数や、モジュール化されたクラスのメソッドに対してデコレートし、ローカルの関数やメソッドに対するデコレートは避ける

tf.function は最初の実行時に遅くなるというコストが必要ですが、Eager Execution よりも格段に早く実行されます。これは、最初の実行時に関数がトレースされ、TensorFlow のグラフに変換されるためです。最適化されたグラフを構築するのには、大抵それを実行するよりもずっと時間がかかります。

import timeit


@tf.function
def f(x, y):
  return tf.matmul(x, y)

print(
    "First invocation:",
    timeit.timeit(lambda: f(tf.ones((10, 10)), tf.ones((10, 10))), number=1))

print(
    "Second invocation:",
    timeit.timeit(lambda: f(tf.ones((10, 10)), tf.ones((10, 10))), number=1))
First invocation: 0.0495843059999288
Second invocation: 0.0011184449999745993

関数がトレースされたことの確認は、関数の最初に print を付け加えることで確認できます。任意の Python のコードはトレースされる際にだけ実行されるため、print の出力が見られるのは関数がトレースされたときに限られます。

@tf.function
def f():
  print('Tracing!')
  tf.print('Executing')

print('First invocation:')
f()

print('Second invocation:')
f()
First invocation:
Tracing!
Executing
Second invocation:
Executing

Tensor でない異なる値を引数として渡したとき、tf.function再トレースを行うことがあります。

@tf.function
def f(n):
  print(n, 'Tracing!')
  tf.print(n, 'Executing')

f(1)
f(1)

f(2)
f(2)
1 Tracing!
1 Executing
1 Executing
2 Tracing!
2 Executing
2 Executing

再トレースinput_signature を指定しない場合、 Tensor の shape が異なるものを引数に渡したときも発生します。

@tf.function
def f(x):
  print(x.shape, 'Tracing!')
  tf.print(x, 'Executing')

f(tf.constant([1]))
f(tf.constant([2]))

f(tf.constant([1, 2]))
f(tf.constant([3, 4]))
(1,) Tracing!
[1] Executing
[2] Executing
(2,) Tracing!
[1 2] Executing
[3 4] Executing

加えて、tf.function を呼び出したときには常に、トレースを行い新しいグラフを作成します。

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

これはネストされた関数の中で @tf.function デコレーターを用いた場合に、予想外の振る舞いを引き起こす場合があります。

def outer():
  @tf.function
  def f():
    print('Tracing!')
    tf.print('Executing')
  f()

outer()
outer()
Tracing!
Executing
Tracing!
Executing