텐서플로 2.0의 tf.function과 오토그래프 (AutoGraph)

TensorFlow.org 에서 보기 구글 코랩(Google Colab)에서 실행하기 깃헙(GitHub) 소스 보기

TF 2.0 버전은 즉시 실행 (eager execution)의 편리함과 TF 1.0의 성능을 합쳤습니다. 이러한 결합의 중심에는 tf.function 이 있는데, 이는 파이썬 문법의 일부를 이식 가능하고 높은 성능의 텐서플로 그래프 코드로 변환시켜줍니다.

tf.function의 멋지고 새로운 특징은 오토그래프 (AutoGraph)입니다. 이는 자연스러운 파이썬 문법을 활용해서 그래프 코드를 작성할 수 있도록 돕습니다. 오토그래프로 사용할 수 있는 파이썬 특징들의 목록을 보려면 오토그래프 지원 범위를 참고하세요. tf.function에 관한 더 자세한 내용을 확인하려면 RFC TF 2.0: Functions, not Sessions을 참고하세요. 오토그래프에 대한 더 자세한 내용은 tf.autograph를 참고하세요.

본 튜토리얼은 tf.function와 오토그래프의 기초적인 특징에 대해서 설명할 것입니다.

설정

텐서플로 2.0 프리뷰 나이틀리 (Preview Nightly) 버전을 임포트(import)하고, TF 2.0 모드를 설정합니다:

import numpy as np
!pip install -q tensorflow==2.0.0-beta1
import tensorflow as tf
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])

tf.function 데코레이터

tf.function을 함수에 붙여줄 경우, 여전히 다른 일반 함수들처럼 사용할 수 있습니다. 하지만 그래프 내에서 컴파일 되었을 때는 더 빠르게 실행하고, GPU나 TPU를 사용해서 작동하고, 세이브드모델(SavedModel)로 내보내는 것이 가능해집니다.

@tf.function
def simple_nn_layer(x, y):
  return tf.nn.relu(tf.matmul(x, y))


x = tf.random.uniform((3, 3))
y = tf.random.uniform((3, 3))

simple_nn_layer(x, y)
<tf.Tensor: id=23, shape=(3, 3), dtype=float32, numpy=
array([[0.4161443 , 0.98615474, 0.75076133],
       [0.48243555, 1.1921146 , 0.89628875],
       [0.34842876, 0.14859593, 0.17541337]], dtype=float32)>

데코레이터를 붙인 결과를 확인해보면, 텐서플로 런타임시의 모든 상호작용들을 다룰 수 있다는 것을 알 수 있습니다.

simple_nn_layer
<tensorflow.python.eager.def_function.Function at 0x7febe0e35a58>

만일 여러분의 코드가 여러 함수들을 포함하고 있다면, 그것들에 모두 데코레이터를 붙일 필요는 없습니다. 데코레이터가 붙은 함수로부터 호출된 모든 함수들은 그래프 모드에서 동작합니다.

def linear_layer(x):
  return 2 * x + 1


@tf.function
def deep_net(x):
  return tf.nn.relu(linear_layer(x))


deep_net(tf.constant((1, 2, 3)))
<tf.Tensor: id=36, shape=(3,), dtype=int32, numpy=array([3, 5, 7], dtype=int32)>

작은 연산들을 많이 포함한 그래프의 경우 함수들은 즉시 실행 코드 (eager code) 보다 더 빠르게 동작합니다. 하지만 무거운 연산들을 조금 포함한 그래프의 경우 (컨볼루션 등), 그렇게 빠른 속도 향상은 기대하기 어렵습니다.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# 데이터 준비 (warm up)
conv_layer(image); conv_fn(image)
print("컨볼루션 즉시 실행:", timeit.timeit(lambda: conv_layer(image), number=10))
print("컨볼루션 함수:", timeit.timeit(lambda: conv_fn(image), number=10))
print("컨볼루션의 성능에는 큰 차이가 없다는 것을 확인할 수 있습니다")

컨볼루션 즉시 실행: 0.21622591999766883
컨볼루션 함수: 0.20913486199788167
컨볼루션의 성능에는 큰 차이가 없다는 것을 확인할 수 있습니다

lstm_cell = tf.keras.layers.LSTMCell(10)

@tf.function
def lstm_fn(input, state):
  return lstm_cell(input, state)

input = tf.zeros([10, 10])
state = [tf.zeros([10, 10])] * 2
# 데이터 준비 (warm up)
lstm_cell(input, state); lstm_fn(input, state)
print("lstm 즉시 실행:", timeit.timeit(lambda: lstm_cell(input, state), number=10))
print("lstm 함수:", timeit.timeit(lambda: lstm_fn(input, state), number=10))

lstm 즉시 실행: 0.005552468999667326
lstm 함수: 0.0037669710000045598

파이썬의 제어 흐름 사용하기

tf.function 내에서 데이터 기반 제어 흐름을 사용할 때, 파이썬의 제어 흐름 문을 사용할 수 있고, 오토그래프 기능은 그것들을 모두 적절한 텐서플로 연산으로 변환할 수 있습니다. 예를 들어, if 문은 Tensor를 기반으로 작동해야할 때 tf.cond() 로 변환될 수 있습니다.

아래 예시에서, xTensor이지만 if문이 예상한대로 정상 작동합니다:

@tf.function
def square_if_positive(x):
  if x > 0:
    x = x * x
  else:
    x = 0
  return x


print('square_if_positive(2) = {}'.format(square_if_positive(tf.constant(2))))
print('square_if_positive(-2) = {}'.format(square_if_positive(tf.constant(-2))))
square_if_positive(2) = 4
square_if_positive(-2) = 0

오토그래프는 기본적인 파이썬 문인 while, for, if, break, continue, return과 네스팅(nesting)을 지원합니다. 이는 Tensor 표현을 whileif 문의 조건 부분에서 사용하거나 for 루프에서 Tensor를 반복할 수 있다는 것을 의미합니다.

@tf.function
def sum_even(items):
  s = 0
  for c in items:
    if c % 2 > 0:
      continue
    s += c
  return s


sum_even(tf.constant([10, 12, 15, 20]))
<tf.Tensor: id=606, shape=(), dtype=int32, numpy=42>

또한 오토그래프는 고급 사용자를 위해 낮은 수준의 API를 제공합니다. 예를 들어, 여러분은 생성된 코드를 확인하기 위해 다음과 같이 작성할 수 있습니다.

print(tf.autograph.to_code(sum_even.python_function))
def tf__sum_even(items):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  s = 0

  def loop_body(loop_vars, s_2):
    c = loop_vars
    continue_ = False
    cond = c % 2 > 0

    def get_state():
      return ()

    def set_state(_):
      pass

    def if_true():
      continue_ = True
      return continue_

    def if_false():
      return continue_
    continue_ = ag__.if_stmt(cond, if_true, if_false, get_state, set_state)
    cond_1 = ag__.not_(continue_)

    def get_state_1():
      return ()

    def set_state_1(_):
      pass

    def if_true_1():
      s_1, = s_2,
      s_1 += c
      return s_1

    def if_false_1():
      return s_2
    s_2 = ag__.if_stmt(cond_1, if_true_1, if_false_1, get_state_1, set_state_1)
    return s_2,
  s, = ag__.for_stmt(items, None, loop_body, (s,))
  do_return = True
  retval_ = s
  cond_2 = ag__.is_undefined_return(retval_)

  def get_state_2():
    return ()

  def set_state_2(_):
    pass

  def if_true_2():
    retval_ = None
    return retval_

  def if_false_2():
    return retval_
  retval_ = ag__.if_stmt(cond_2, if_true_2, if_false_2, get_state_2, set_state_2)
  return retval_


다음은 더 복잡한 제어 흐름의 예시입니다:

@tf.function
def fizzbuzz(n):
  msg = tf.constant('')
  for i in tf.range(n):
    if tf.equal(i % 3, 0):
      tf.print('Fizz')
    elif tf.equal(i % 5, 0):
      tf.print('Buzz')
    else:
      tf.print(i)

fizzbuzz(tf.constant(15))
Fizz
1
2
Fizz
4
Buzz
Fizz
7
8
Fizz
Buzz
11
Fizz
13
14

케라스와 오토그래프

오토그래프는 기본적으로 비동적(non-dynamic) 케라스 모델에서 사용 가능합니다. 더 자세한 정보를 원한다면, tf.keras를 참고하세요.

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      return input_data // 2


model = CustomModel()

model(tf.constant([-2, -4]))
<tf.Tensor: id=723, shape=(2,), dtype=int32, numpy=array([-1, -2], dtype=int32)>

부수 효과 (Side effects)

즉시 실행 모드 (eager mode)처럼 부수 효과를 사용할 수 있습니다. 예를 들면, tf.function 내에 있는 tf.assign이나 tf.print이 있습니다. 또한 부수 효과들은 작업들이 순서대로 실행된다는 것을 보장하기 위해 필수적인 제어 의존성 (control dependency)을 추가합니다.

v = tf.Variable(5)

@tf.function
def find_next_odd():
  v.assign(v + 1)
  if tf.equal(v % 2, 0):
    v.assign(v + 1)


find_next_odd()
v
<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=7>

디버깅

tf.function 과 오토그래프는 코드를 생성하고 텐서플로 그래프 내에서 해당 코드를 추적함으로써 동작합니다. 이 메커니즘은 아직까지는 pdb같은 단계적 (step-by-step) 디버거를 지원하지 않습니다. 하지만 일시적으로 tf.function 내에서 즉시 실행 (eager execution)을 가능하게 하는 tf.config.run_functions_eagerly(True)을 사용하고 가장 선호하는 디버거를 사용할 수 있습니다:

@tf.function
def f(x):
  if x > 0:
    # 여기에 중단점(breakpoint)을 설정해 보세요!
    # 예시:
    #   import pdb
    #   pdb.set_trace()
    x = x + 1
  return x

tf.config.experimental_run_functions_eagerly(True)

# 이제 중단점을 설정하고 디버거 내에서 코드를 실행할 수 있습니다.
f(tf.constant(1))

tf.config.experimental_run_functions_eagerly(False)

고급 예제: 그래프 내 훈련 루프

이전 섹션은 케라스 레이어나 모델 내부에서 오토그래프를 활용할 수 있는 것을 보여주었습니다. 오토그래프 코드 안에서 케라스 모델을 활용할 수도 있습니다.

이 예제는 배치 불러오기, 그래디언트 계산, 매개변수 갱신, 검증 정확도 계산, 수렴까지 반복 등 그래프 내에서 수행되는 전체 훈련 과정을 통해 간단한 케라스 모델이 어떻게 MNIST 데이터셋에 훈련되는지 보여줍니다.

데이터 다운로드

def prepare_mnist_features_and_labels(x, y):
  x = tf.cast(x, tf.float32) / 255.0
  y = tf.cast(y, tf.int64)
  return x, y

def mnist_dataset():
  (x, y), _ = tf.keras.datasets.mnist.load_data()
  ds = tf.data.Dataset.from_tensor_slices((x, y))
  ds = ds.map(prepare_mnist_features_and_labels)
  ds = ds.take(20000).shuffle(20000).batch(100)
  return ds

train_dataset = mnist_dataset()

모델 정의하기

model = tf.keras.Sequential((
    tf.keras.layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(100, activation='relu'),
    tf.keras.layers.Dense(10)))
model.build()
optimizer = tf.keras.optimizers.Adam()

훈련 (training) 루프 정의하기

compute_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

compute_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


def train_one_step(model, optimizer, x, y):
  with tf.GradientTape() as tape:
    logits = model(x)
    loss = compute_loss(y, logits)

  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))

  compute_accuracy(y, logits)
  return loss


@tf.function
def train(model, optimizer):
  train_ds = mnist_dataset()
  step = 0
  loss = 0.0
  accuracy = 0.0
  for x, y in train_ds:
    step += 1
    loss = train_one_step(model, optimizer, x, y)
    if tf.equal(step % 10, 0):
      tf.print('스텝', step, ': 손실', loss, '; 정확도', compute_accuracy.result())
  return step, loss, accuracy

step, loss, accuracy = train(model, optimizer)
print('최종 스텝', step, ': 손실', loss, '; 정확도', compute_accuracy.result())
스텝 10 : 손실 1.75508547 ; 정확도 0.307
스텝 20 : 손실 1.2521 ; 정확도 0.453
스텝 30 : 손실 0.786321282 ; 정확도 0.558333337
스텝 40 : 손실 0.605602443 ; 정확도 0.629
스텝 50 : 손실 0.495852113 ; 정확도 0.672
스텝 60 : 손실 0.582676411 ; 정확도 0.700666666
스텝 70 : 손실 0.541174173 ; 정확도 0.726142883
스텝 80 : 손실 0.274435937 ; 정확도 0.746875
스텝 90 : 손실 0.341384798 ; 정확도 0.765111089
스텝 100 : 손실 0.282570839 ; 정확도 0.7782
스텝 110 : 손실 0.40306896 ; 정확도 0.789181828
스텝 120 : 손실 0.231302723 ; 정확도 0.79925
스텝 130 : 손실 0.305687934 ; 정확도 0.808307707
스텝 140 : 손실 0.220763892 ; 정확도 0.8155
스텝 150 : 손실 0.314131081 ; 정확도 0.8214
스텝 160 : 손실 0.354495198 ; 정확도 0.82743752
스텝 170 : 손실 0.261185437 ; 정확도 0.832705855
스텝 180 : 손실 0.242653638 ; 정확도 0.837166667
스텝 190 : 손실 0.261884 ; 정확도 0.84094739
스텝 200 : 손실 0.318644553 ; 정확도 0.84535
최종 스텝 tf.Tensor(200, shape=(), dtype=int32) : 손실 tf.Tensor(0.31864455, shape=(), dtype=float32) ; 정확도 tf.Tensor(0.84535, shape=(), dtype=float32)

배치 (Batching)

실제 적용시에 배치 (batch) 는 성능을 위해 필수적입니다. 오토그래프로 변환하기 가장 좋은 코드는 제어 흐름이 배치 수준에서 결정되는 코드입니다. 만일 제어 흐름이 개별적인 예제 (example) 수준에서 결정된다면, 성능을 유지하기 위해서 배치 API들을 사용해야합니다.

예를 들어, 파이썬으로 다음과 같은 코드를 작성했다면:

def square_if_positive(x):
  return [i ** 2 if i > 0 else i for i in x]


square_if_positive(range(-5, 5))
[-5, -4, -3, -2, -1, 0, 1, 4, 9, 16]

텐서플로에서는 다음과 같이 작성하고 싶을 것입니다. (그리고 다음 코드는 실제로 동작합니다!):

@tf.function
def square_if_positive_naive(x):
  result = tf.TensorArray(tf.int32, size=x.shape[0])
  for i in tf.range(x.shape[0]):
    if x[i] > 0:
      result = result.write(i, x[i] ** 2)
    else:
      result = result.write(i, x[i])
  return result.stack()


square_if_positive_naive(tf.range(-5, 5))
<tf.Tensor: id=1840, shape=(10,), dtype=int32, numpy=array([-5, -4, -3, -2, -1,  0,  1,  4,  9, 16], dtype=int32)>

하지만 이 경우는 아래와 같이 작성할 수도 있습니다:

def square_if_positive_vectorized(x):
  return tf.where(x > 0, x ** 2, x)


square_if_positive_vectorized(tf.range(-5, 5))
<tf.Tensor: id=1850, shape=(10,), dtype=int32, numpy=array([-5, -4, -3, -2, -1,  0,  1,  4,  9, 16], dtype=int32)>