Missed TensorFlow World? Check out the recap. Learn more

TensorFlow 2.0 での tf.function と AutoGraph

View on TensorFlow.org Run in Google Colab View source on GitHub

TensorFlow 2.0 では Eager Execution の使いやすさとTensorFlow 1.0 のパワーとを同時に提供します。この統合の中核となるのは tf.function です。これは Python の構文のサブセットを移植可能でハイパフォーマンスな TensorFlow のグラフに変換します。

tf.functionの魅力的な特徴に AutoGraph があります。これはグラフを Python の構文そのものを用いて記述できるようにします。 AutoGraph で利用可能な Python の機能の一覧は、AutoGraph Capabilities and Limitations (Autograph の性能と制限事項) で確認できます。また、tf.functionの詳細については RFC TF 2.0: Functions, not Sessions を参照してください。AutoGraph の詳細については tf.autograph を参照してください。

このチュートリアルでは tf.function と AutoGraph の基本的な特徴についてひととおり確認します。

セットアップ

TensorFlow 2.0 をインポートして、TF 2.0 モードを有効にしてください。

from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
try:
  !pip install -q tf-nightly-2.0-preview
except Exception:
  pass
import tensorflow as tf

tf.function デコレータ

tf.functionを用いてある関数にアノテーションを付けたとしても、一般の関数と変わらずに呼び出せます。一方、実行時にはその関数はグラフへとコンパイルされます。これにより、より高速な実行や、 GPU や TPU での実行、SavedModel へのエクスポートといった利点が得られます。

@tf.function
def simple_nn_layer(x, y):
  return tf.nn.relu(tf.matmul(x, y))


x = tf.random.uniform((3, 3))
y = tf.random.uniform((3, 3))

simple_nn_layer(x, y)
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0.41838184, 0.30984652, 0.56670266],
       [0.79515874, 0.8423839 , 0.91438687],
       [0.29930425, 0.31829548, 0.3870223 ]], dtype=float32)>

アノテーションの結果を調べてみると、 TensorFlow ランタイムとのやり取りのすべてを処理する特別な呼び出し可能オブジェクトを確認できます。

simple_nn_layer
<tensorflow.python.eager.def_function.Function at 0x7f7aeebd1710>

記述したコードで複数の関数を利用していたとしても、すべての関数にアノテーションを付ける必要はありません。アノテーションをつけた関数から呼び出されるすべての関数は、グラフモードで実行されます。

def linear_layer(x):
  return 2 * x + 1


@tf.function
def deep_net(x):
  return tf.nn.relu(linear_layer(x))


deep_net(tf.constant((1, 2, 3)))
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([3, 5, 7], dtype=int32)>

グラフが大量の軽量な演算から構成される場合、関数は Eager Execution で実行するコードよりも高速になる場合があります。しかし、 graph が少量の (畳み込み演算のような) 計算に時間のかかる演算からなる場合、高速化はそれほど見込めないでしょう。

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.3120034640014637
Function conv: 0.26202421199923265
Note how there's not much difference in performance for convolutions
lstm_cell = tf.keras.layers.LSTMCell(10)

@tf.function
def lstm_fn(input, state):
  return lstm_cell(input, state)

input = tf.zeros([10, 10])
state = [tf.zeros([10, 10])] * 2
# warm up
lstm_cell(input, state); lstm_fn(input, state)
print("eager lstm:", timeit.timeit(lambda: lstm_cell(input, state), number=10))
print("function lstm:", timeit.timeit(lambda: lstm_fn(input, state), number=10))
eager lstm: 0.006133716999102035
function lstm: 0.0038258340009633685

Python の制御フローの利用

tf.functionの内部でデータに依存した制御フローを用いる場合、Pythonの制御フロー構文を用いることができます。AutoGraph はそれらの構文を TensorFlow の Ops に書き換えます。たとえば、 Tensor に依存する if 文は、tf.cond() に変換されます。

次の例では xTensor です。ですが、 if 文は期待するどおりに動作しています。

@tf.function
def square_if_positive(x):
  if x > 0:
    x = x * x
  else:
    x = 0
  return x


print('square_if_positive(2) = {}'.format(square_if_positive(tf.constant(2))))
print('square_if_positive(-2) = {}'.format(square_if_positive(tf.constant(-2))))
square_if_positive(2) = 4
square_if_positive(-2) = 0

AutoGraph は while, for, if, break, continue, return といった典型的なPythonの構文をサポートしています。また、これらを入れ子にして利用する場合もサポートしています。つまり、Tensorを返す式をwhile 文や if 文の条件式として用いることが可能です。また、for 文で Tensor の要素に渡って反復することも可能です。

@tf.function
def sum_even(items):
  s = 0
  for c in items:
    if c % 2 > 0:
      print(c)
      continue
    s += c
  return s


sum_even(tf.constant([10, 12, 15, 20]))
Tensor("TensorArrayV2Read/TensorListGetItem:0", shape=(), dtype=int32)

<tf.Tensor: shape=(), dtype=int32, numpy=42>

より高度な使い方をするユーザーのために、AutoGraph は低レベルAPIも提供しています。次の例では AutoGraph が生成したコードを確認できます。

print(tf.autograph.to_code(sum_even.python_function))
def tf__sum_even(items):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  with ag__.FunctionScope('sum_even', 'sum_even_scope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as sum_even_scope:
    s = 0

    def get_state_2():
      return ()

    def set_state_2(_):
      pass

    def loop_body(iterates, s):
      c = iterates
      continue_ = False

      def get_state():
        return ()

      def set_state(_):
        pass

      def if_true():
        print(c)
        continue_ = True
        return continue_

      def if_false():
        return continue_
      cond = c % 2 > 0
      continue_ = ag__.if_stmt(cond, if_true, if_false, get_state, set_state, ('continue_',), ())

      def get_state_1():
        return ()

      def set_state_1(_):
        pass

      def if_true_1():
        s_1, = s,
        s_1 += c
        return s_1

      def if_false_1():
        return s
      cond_1 = ag__.not_(continue_)
      s = ag__.if_stmt(cond_1, if_true_1, if_false_1, get_state_1, set_state_1, ('s',), ())
      return s,
    s, = ag__.for_stmt(items, None, loop_body, get_state_2, set_state_2, (s,), ('s',), ())
    do_return = True
    retval_ = sum_even_scope.mark_return_value(s)
  do_return,
  return ag__.retval(retval_)

次はより複雑な制御フローの例です。

@tf.function
def fizzbuzz(n):
  msg = tf.constant('')
  for i in tf.range(n):
    if i % 3 == 0:
      tf.print('Fizz')
    elif i % 5 == 0:
      tf.print('Buzz')
    else:
      tf.print(i)

fizzbuzz(tf.constant(15))
Fizz
1
2
Fizz
4
Buzz
Fizz
7
8
Fizz
Buzz
11
Fizz
13
14

Keras での AutoGraph の利用

tf.function はオブジェクトのメソッドに対しても利用できます。たとえば、カスタムしたKeras モデルにデコレーターを適用できます、典型的には call 関数にアノテーションを付けることで実現できるでしょう。より詳細が必要な場合、tf.keras を確認してください。

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      return input_data // 2


model = CustomModel()

model(tf.constant([-2, -4]))
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([-1, -2], dtype=int32)>

副作用

Eager モードのように、通常の場合 tf.function の中で、tf.assigntf.print といった副作用のある命令を実行できます。また、実行時の順序を保つために、処理順について必要な依存関係を書き加えます。

v = tf.Variable(5)

@tf.function
def find_next_odd():
  v.assign(v + 1)
  if v % 2 == 0:
    v.assign(v + 1)


find_next_odd()
v
<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=7>

例: シンプルなモデルの学習

AutoGraph はこれまで見てきたよりもずっと多くの演算を TensorFlow の内部で実行できます。たとえば、学習のためのループ処理は単に制御フローなので、実際にそれを TensorFlow に持ち込んで処理できます。

データのダウンロード

def prepare_mnist_features_and_labels(x, y):
  x = tf.cast(x, tf.float32) / 255.0
  y = tf.cast(y, tf.int64)
  return x, y

def mnist_dataset():
  (x, y), _ = tf.keras.datasets.mnist.load_data()
  ds = tf.data.Dataset.from_tensor_slices((x, y))
  ds = ds.map(prepare_mnist_features_and_labels)
  ds = ds.take(20000).shuffle(20000).batch(100)
  return ds

train_dataset = mnist_dataset()

モデルの定義

model = tf.keras.Sequential((
    tf.keras.layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(10)))
model.build()
optimizer = tf.keras.optimizers.Adam()

学習のためのループ処理の定義

compute_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

compute_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


def train_one_step(model, optimizer, x, y):
  with tf.GradientTape() as tape:
    logits = model(x)
    loss = compute_loss(y, logits)

  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))

  compute_accuracy(y, logits)
  return loss


@tf.function
def train(model, optimizer):
  train_ds = mnist_dataset()
  step = 0
  loss = 0.0
  accuracy = 0.0
  for x, y in train_ds:
    step += 1
    loss = train_one_step(model, optimizer, x, y)
    if step % 10 == 0:
      tf.print('Step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
  return step, loss, accuracy

step, loss, accuracy = train(model, optimizer)
print('Final step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
Step 10 : loss 1.78478038 ; accuracy 0.355
Step 20 : loss 1.38942516 ; accuracy 0.505
Step 30 : loss 0.791106 ; accuracy 0.592666686
Step 40 : loss 0.597629786 ; accuracy 0.65025
Step 50 : loss 0.405860215 ; accuracy 0.6918
Step 60 : loss 0.408941984 ; accuracy 0.722833335
Step 70 : loss 0.371362537 ; accuracy 0.743714273
Step 80 : loss 0.411753 ; accuracy 0.7625
Step 90 : loss 0.322867423 ; accuracy 0.777444422
Step 100 : loss 0.424844652 ; accuracy 0.7901
Step 110 : loss 0.380501479 ; accuracy 0.800272703
Step 120 : loss 0.440484554 ; accuracy 0.808666646
Step 130 : loss 0.365675151 ; accuracy 0.816307664
Step 140 : loss 0.347084612 ; accuracy 0.823857129
Step 150 : loss 0.18940562 ; accuracy 0.830933332
Step 160 : loss 0.353083879 ; accuracy 0.836187482
Step 170 : loss 0.262934148 ; accuracy 0.840647042
Step 180 : loss 0.328646421 ; accuracy 0.84577775
Step 190 : loss 0.240262657 ; accuracy 0.850210547
Step 200 : loss 0.160538524 ; accuracy 0.85455
Final step tf.Tensor(200, shape=(), dtype=int32) : loss tf.Tensor(0.16053852, shape=(), dtype=float32) ; accuracy tf.Tensor(0.85455, shape=(), dtype=float32)

バッチ処理

実際のアプリケーションにおいて、処理をバッチにまとめることはパフォーマンスの観点から重要です。AutoGraphを用いるのにもっとも適しているコードは、制御フローを バッチ の単位で決定するようなコードです。もし、個々の 要素 の単位で制御を決定する場合、パフォーマンスを保つために batch API を試してみてください。

一例として、次の Python コードががあったとします。

def square_if_positive(x):
  return [i ** 2 if i > 0 else i for i in x]


square_if_positive(range(-5, 5))
[-5, -4, -3, -2, -1, 0, 1, 4, 9, 16]

TensorFlowに同等の処理を行わせる場合、次のように記述したくなるかもしれません。 (これは実際には動作します!)

@tf.function
def square_if_positive_naive(x):
  result = tf.TensorArray(tf.int32, size=x.shape[0])
  for i in tf.range(x.shape[0]):
    if x[i] > 0:
      result = result.write(i, x[i] ** 2)
    else:
      result = result.write(i, x[i])
  return result.stack()


square_if_positive_naive(tf.range(-5, 5))
<tf.Tensor: shape=(10,), dtype=int32, numpy=array([-5, -4, -3, -2, -1,  0,  1,  4,  9, 16], dtype=int32)>

しかし、この場合、次のように書くこともできます。

def square_if_positive_vectorized(x):
  return tf.where(x > 0, x ** 2, x)


square_if_positive_vectorized(tf.range(-5, 5))
<tf.Tensor: shape=(10,), dtype=int32, numpy=array([-5, -4, -3, -2, -1,  0,  1,  4,  9, 16], dtype=int32)>