tf.function으로 성능 향상하기

TensorFlow.org 에서 보기 구글 코랩(Google Colab)에서 실행하기 깃헙(GitHub) 소스 보기 Download notebook

텐서플로 2에서는 즉시 실행(eager execution)이 기본적으로 활성화되어 있습니다. 직관적이고 유연한 사용자 인터페이스를 제공하지만 성능과 배포에 비용이 더 듭니다(하나의 연산을 실행할 때는 훨씬 간단하고 빠릅니다).

성능을 높이고 이식성이 좋은 모델을 만들려면 tf.function을 사용해 그래프로 변환하세요. 하지만 조심해야 할 점이 있습니다. tf.function은 무조건 속도를 높여주는 마법의 은총알이 아닙니다!

이 가이드는 tf.function의 이면에 있는 개념을 이해하고 사용법을 완전히 터득할 수 있도록 도울 것입니다.

여기서 배울 주요 내용과 권고 사항은 다음과 같습니다:

  • 즉시 실행 모드에서 디버깅한 다음 @tf.function으로 데코레이팅하세요.
  • 객체 변경(object mutation)이나 리스트 요소 추가 같은 파이썬의 부수 효과에 의존하지 마세요.
  • tf.function은 텐서플로 연산과 가장 잘 동작합니다: 넘파이와 파이썬 호출은 상수로 바뀝니다.

설정

import tensorflow as tf

에러 출력을 위한 헬퍼 함수를 정의합니다:

import traceback
import contextlib

# 에러 출력을 위한 헬퍼 함수
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('기대하는 예외 발생 \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('{}를 기대했지만 아무런 에러도 발생되지 않았습니다!'.format(
        error_class))

기초

tf.function으로 정의한 함수는 기본 텐서플로 연산과 같습니다. 즉시 실행 모드로 실행하거나 그레이디언트를 계산할 수 있습니다.

@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

다른 함수 내부에 사용할 수 있습니다.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

tf.function은 즉시 실행 모드 보다 빠릅니다. 특히 그래프에 작은 연산이 많을 때 그렇습니다. 하지만 (합성곱처럼) 계산량이 많은 연산 몇 개로 이루어진 그래프는 속도 향상이 크지 않습니다.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# 워밍 업
conv_layer(image); conv_fn(image)
print("즉시 실행 합성곱:", timeit.timeit(lambda: conv_layer(image), number=10))
print("tf.function 합성곱:", timeit.timeit(lambda: conv_fn(image), number=10))
print("합성곱 연산 속도에 큰 차이가 없습니다.")
즉시 실행 합성곱: 0.0029387790000328096
tf.function 합성곱: 0.0031409620000886207
합성곱 연산 속도에 큰 차이가 없습니다.

디버깅

일반적으로 tf.function 보다 즉시 실행 모드가 디버깅하기 쉽습니다. tf.function으로 데코레이팅하기 전에 즉시 실행 모드에서 에러가 없는지 확인하세요. 디버깅 과정을 위해 tf.config.run_functions_eagerly(True)으로 전체 tf.function을 비활성화하고 나중에 다시 활성화할 수 있습니다.

tf.function 함수에서 버그를 추적할 때 다음 팁을 참고하세요:

  • 파이썬 print 함수는 트레이싱(tracing)하는 동안에만 호출되므로 함수가 (재)트레이싱될 때 추적하는데 도움이 됩니다.
  • tf.print 함수는 언제나 실행되므로 실행하는 동안 중간 값을 추적할 때 도움이 됩니다.
  • tf.debugging.enable_check_numerics을 사용하면 쉽게 NaN과 Inf가 발생되는 곳을 추적할 수 있습니다.
  • pdb는 어떻게 트레이싱이 일어나는지 이해하는데 도움이 됩니다(주의: pdb는 오토그래프(AutoGraph)가 변환한 소스 코드를 보여줄 것입니다).

트레이싱과 다형성

파이썬의 동적 타이핑 덕분에 여러 종류의 매개변수 타입을 사용해 함수를 호출할 수 있고 파이썬은 각기 다르게 수행됩니다.

반면 텐서플로 그래프는 정적인 dtype과 shape 차원이 필요합니다. tf.function은 올바른 그래프를 생성하기 위해 필요하면 함수를 다시 트레이싱하여 이 문제를 해결합니다. tf.function을 사용할 때 발생하는 문제점은 대부분 이런 재트레이싱(retracing) 동작에서 옵니다.

다른 종류의 매개변수를 함수를 호출할 때 무슨 일이 일어나는지 확인해 보죠.

# 함수와 다형성

@tf.function
def double(a):
  print("트레이싱:", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
트레이싱: Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

트레이싱: Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

트레이싱: Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)


트레이싱 동작을 제어하기 위해 다음 기법을 사용할 수 있습니다:

새로운 tf.function을 만듭니다. 별도의 tf.function 객체는 트레이싱이 따로 일어납니다.

def f():
  print('트레이싱!')
  tf.print('실행')

tf.function(f)()
tf.function(f)()
트레이싱!
실행
트레이싱!
실행

get_concrete_function 메서드를 사용해 트레이싱된 특정 함수를 얻을 수 있습니다.

print("콘크리트 함수 얻기")
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("트레이싱된 함수 실행")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("콘크리트 함수에 다른 타입을 사용하면 예외가 발생합니다")
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
콘크리트 함수 얻기
트레이싱: Tensor("a:0", dtype=string)
트레이싱된 함수 실행
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
콘크리트 함수에 다른 타입을 사용하면 예외가 발생합니다
기대하는 예외 발생 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:

Traceback (most recent call last):
  File "<ipython-input-1-0b7df3a10ecd>", line 8, in assert_raises
    yield
  File "<ipython-input-1-4d5317f9d67a>", line 8, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_182 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_182]

tf.functioninput_signature를 지정하여 트레이싱을 제한할 수도 있습니다.

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("트레이싱", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# input_signature에 1-D 텐서를 지정했기 때문에 다음은 실패합니다.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))
트레이싱 Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
기대하는 예외 발생 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-1-0b7df3a10ecd>", line 8, in assert_raises
    yield
  File "<ipython-input-1-4de5c726506e>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))

언제 다시 트레이싱되나요?

다형성을 지원하는 tf.function은 트레이싱으로 생성된 콘크리트 함수를 캐싱합니다. 이 캐시의 키는 함수의 위치 매개변수(args)와 키워드 매개변수(kwargs)에서 생성된 키의 튜플입니다. tf.Tensor 매개변수를 위해 생성된 키는 차원 개수와 타입이 됩니다. 파이썬 기본 자료형(정수, 실수, 문자열, 불리언)으로 생성된 키는 해당 변수의 값이 됩니다. 그외 다른 파이썬 타입에서 키는 id()를 기반으로 합니다. 따라서 클래스 메서드는 인스턴스마다 독립적으로 트레이싱됩니다. 향후 텐서플로는 파이썬 객체를 안전하게 텐서로 변환하기 위한 고급 캐싱 기능을 제공할 수 있습니다.

콘크리트 함수를 참고하세요.

파이썬 매개변수 vs 텐서 매개변수

하이퍼파라미터 조작하고 그래프를 구성하기 위해 파이썬 매개변수가 자주 사용됩니다. 예를 들면 num_layers=10이나 training=True, nonlinearity='relu'입니다. 파이썬 매개변수가 바뀌면 그래프가 다시 트레이싱됩니다.

하지만 파이썬 매개변수가 그래프 구성에 사용되지 않을 수 있습니다. 이런 경우 파이썬 값이 변하면 불필요한 재트레이싱을 일으킵니다. 예를 들어 다음은 오토그래프가 동적으로 펼치는 훈련 반복 루프입니다. 다중 트레이싱이 되었지만 생성된 그래프는 실제로 동일하기 때문에 조금 비효율적입니다.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("트레이싱 num_steps = {}".format(num_steps))
  for _ in tf.range(num_steps):
    train_one_step()

train(num_steps=10)
train(num_steps=20)
트레이싱 num_steps = 10
트레이싱 num_steps = 20

이를 해결하는 간단한 방법은 생성된 그래프에 영향을 미치지 않도록 매개변수를 Tensor로 바꾸는 것입니다.

train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
트레이싱 num_steps = Tensor("num_steps:0", shape=(), dtype=int32)

tf.function의 부수 효과

일반적으로 (출력이나 객체 변경 같은) 파이썬 부수 효과(side effect)는 트레이싱 동안에만 일어납니다. 어떻게 tf.function에서 안정적으로 부수 효과를 일으킬 수 있을까요?

일반적인 규칙은 파이썬 부수 효과만을 사용하여 트레이싱을 디버깅하는 것입니다. 그외에는 tf.Variable.assign, tf.print, tf.summary 같은 텐서플로 연산이 텐서플로 런타임에 의해 코드가 트레이싱되고 실행되는지 확인하는 가장 좋은 방법입니다. 일반적으로 함수 스타일을 사용하는 것이 가장 좋습니다.

@tf.function
def f(x):
  print("트레이싱", x)
  tf.print("실행", x)

f(1)
f(1)
f(2)
트레이싱 1
실행 1
실행 1
트레이싱 2
실행 2

tf.function을 호출할 때마다 파이썬 코드를 실행하려면 tf.py_function이 해결책입니다. tf.py_function의 단점은 이식성과 성능이 좋지 않고 분산 환경(다중 GPU나 다중 TPU)에서 잘 동작하지 않는다는 것입니다. 또한 tf.py_function은 미분 가능하도록 그래프를 만들기 때문에 모든 입력/출력을 텐서로 변환합니다.

external_list = []

def side_effect(x):
  print('파이썬 부수 효과')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
assert len(external_list) == 3
# py_function이 1을 tf.constant(1)로 바꾸기 때문에 .numpy()를 호출해야 합니다.
assert external_list[0].numpy() == 1
파이썬 부수 효과
파이썬 부수 효과
파이썬 부수 효과

파이썬 상태 주의하기

제러네이터와 반복자(iterator) 같은 파이썬의 많은 기능은 상태 추적을 위해 파이썬 런타임에 의존합니다. 일반적으로 이런 요소들은 즉시 실행 모드와 같이 동작하지만 트레이싱 동작 때문에 tf.function 안에서는 예상밖의 일이 일어날 수 있습니다.

예를 하나 들면, 다음 반복자 값을 얻는 것이 파이썬 부수 효과이기 때문에 트레이싱 동안에만 일어납니다.

external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("external_var의 값:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# 다음은 반복자의 다음 값을 추출하지 않고 첫 번째 값을 재사용합니다.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
external_var의 값: 0
external_var의 값: 0
external_var의 값: 0

변수

코드가 의도한 순서대로 실행되는 것처럼 tf.function에서 매우 쉽게 변수를 생성하고 사용할 수 있습니다. 하지만 아주 중요한 주의 사항이 있습니다. 변수는 즉시 실행 모드와 그래프 모드에서 다르게 동작하는 코드를 만들 수 있습니다.

특히 호출마다 새로운 변수를 만들 때 일어납니다. 트레이싱 구조 때문에 tf.function은 호출마다같은 변수를 재사용합니다. 하지만 즉시 실행 모드에서는 호출마다 새로운 변수가 생성됩니다. 이런 실수를 방지하기 위해 tf.function은 위험한 변수 생성이 감지되면 에러를 발생합니다.

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
기대하는 예외 발생 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-1-0b7df3a10ecd>", line 8, in assert_raises
    yield
  File "<ipython-input-1-73e410646579>", line 8, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-1-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.


하지만 모호하지 않은 코드는 괜찮습니다.

v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

print(f(1.0))  # 2.0
print(f(2.0))  # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)

함수가 처음 호출될 때만 변수가 생성되도록 tf.function 안에서 변수를 생성할 수 있습니다.

class C:
  pass

obj = C()
obj.v = None

@tf.function
def g(x):
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

print(g(1.0))  # 2.0
print(g(2.0))  # 4.0
tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)

변수 초기화가 함수 매개변수와 다른 변수 값에 의존할 수 있습니다. 올바른 초기화 순서를 찾기 위해 제어 의존성을 생성하는 메서드를 사용할 수 있습니다.

state = []
@tf.function
def fn(x):
  if not state:
    state.append(tf.Variable(2.0 * x))
    state.append(tf.Variable(state[0] * 3.0))
  return state[0] * x * state[1]

print(fn(tf.constant(1.0)))
print(fn(tf.constant(3.0)))
tf.Tensor(12.0, shape=(), dtype=float32)
tf.Tensor(36.0, shape=(), dtype=float32)

오토그래프 변환

오토그래프(AutoGraph)는 tf.function안에 기본으로 활성화되어 있습니다. 파이썬의 즉시 실행 코드를 그래프 호환 텐서플로 연산으로 변환합니다. 여기에는 if, for, while 같은 제어 흐름이 포함됩니다.

tf.condtf.while_loop 같은 텐서플로 연산을 여전히 사용할 수 있지만 파이썬으로 제어 흐름을 작성하는 것이 만들기도 이해하기도 쉽습니다.

# 간단한 루프

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.718375444 0.318117738 0.180776238 0.81085825 0.47323966]
[0.615902 0.307803959 0.178832397 0.670063436 0.440813184]
[0.548268318 0.298437953 0.176950067 0.585021555 0.414318264]
[0.499221236 0.289882481 0.175126061 0.526305676 0.392133117]
[0.461504459 0.282026649 0.173357442 0.482552052 0.373197705]
[0.431309611 0.274779737 0.171641454 0.448285133 0.356785536]
[0.406415194 0.268066764 0.169975519 0.420488387 0.34237951]
[0.385424644 0.261825025 0.168357268 0.397341818 0.329600066]
[0.36740917 0.256001741 0.16678445 0.377672225 0.318161368]
[0.351723462 0.250552028 0.165255 0.360684216 0.307843477]
[0.337903112 0.245437503 0.163766921 0.345816553 0.298473954]
[0.325604081 0.240625083 0.162318408 0.332660228 0.289915442]
[0.314565331 0.236086085 0.16090773 0.320909083 0.282057]
[0.304584622 0.231795505 0.159533262 0.310328662 0.274807781]
[0.295502543 0.227731436 0.158193484 0.30073604 0.268092781]
[0.287191451 0.223874584 0.156886965 0.291986048 0.261849254]
[0.279547781 0.2202079 0.155612335 0.283961743 0.25602439]
[0.272486478 0.216716215 0.154368326 0.276567787 0.250573277]
[0.265936971 0.213386014 0.153153732 0.269725531 0.24545747]
[0.259840131 0.210205197 0.151967406 0.263369411 0.240643889]
[0.25414598 0.207162902 0.150808275 0.257444352 0.236103833]
[0.24881196 0.204249337 0.149675295 0.251903594 0.231812298]
[0.243801564 0.201455668 0.148567513 0.246707231 0.227747351]
[0.23908326 0.198773876 0.147484 0.241820917 0.223889709]
[0.234629661 0.196196675 0.146423891 0.237214938 0.220222279]
[0.230416864 0.193717435 0.145386353 0.232863426 0.216729909]
[0.226423889 0.19133009 0.144370586 0.228743732 0.213399082]

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.22263221, 0.1890291 , 0.14337584, 0.22483593, 0.21021768],
      dtype=float32)>

관심있다면 오토그래프가 생성한 코드를 확인해 볼 수 있습니다.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)


조건문

오토그래프는 if <condition> 문장을 이와 대등한 tf.cond 호출로 변경합니다. 이런 대체는 <condition>이 텐서일 때 수행됩니다. 그렇지 않다면 if 문장은 파이썬 조건문으로 실행됩니다.

트레이싱하는 동안 파이썬 조건문을 실행하기 때문에 정확히 하나의 조건 분기만 그래프에 추가됩니다. 오토그래프가 없다면 이렇게 트레이싱된 그래프는 데이터에 따라 제어 흐름을 바꿀 수 없습니다.

tf.cond는 조건 분기를 트레이싱하고 그래프에 추가하여 실행시 동적으로 분기를 선택합니다. 트레이싱때문에 의도치 않은 부수 효과가 발생될 수 있습니다. 더 자세한 내용은 오토그래프 트레이싱 효과를 참고하세요.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('루프 트레이싱')
    if i % 15 == 0:
      print('fizzbuzz 브랜치 트레이싱')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('fizz 브랜치 트레이싱')
      tf.print('fizz')
    elif i % 5 == 0:
      print('buzz 브랜치 트레이싱')
      tf.print('buzz')
    else:
      print('디폴트 브랜치 트레이싱')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
루프 트레이싱
fizzbuzz 브랜치 트레이싱
fizz 브랜치 트레이싱
buzz 브랜치 트레이싱
디폴트 브랜치 트레이싱
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

오토그래프가 변환한 if 문장에 대한 추가 제약 사항에 대해서는 레퍼런스 문서를 참고하세요.

반복문

오토그래프는 일부 forwhile 문장을 tf.while_loop와 같은 동등한 텐서플로 반복 연산으로 바꿉니다. 변환되지 않으면 파이썬 반복문으로 forwhile 반복문이 실행됩니다.

이런 대체는 다음과 같은 경우에 일어납니다:

  • for x in y: y가 텐서이면 tf.while_loop로 변환됩니다. 특별히 ytf.data.Dataset인 경우에는 tf.data.Dataset 연산의 조합이 생성됩니다.
  • while <condition>: <condition>이 텐서라면 tf.while_loop로 변환됩니다.

파이썬 반복문이 트레이싱 동안 실행되므로 매 반복마다 tf.Graph에 추가적인 연산이 포함됩니다.

텐서플로는 반복문 블럭을 트레이싱하여 실행시 얼마나 많은 반복이 수행될지 동적으로 선택합니다. 반복문 블럭은 생성된 tf.Graph에 한 번만 포함됩니다.

오토그래프가 변환한 forwhile 문장에 대한 추가 제약 사항에 대해서는 레퍼런스 문서를 참고하세요.

파이썬 데이터로 반복하기

흔히 저지르기 쉬운 실수는 tf.function 안에서 파이썬이나 넘파이 데이터로 반복하는 것입니다. 트레이싱 과정 동안 반복이 수행되기 때문에 반복마다 tf.Graph에 복사된 모델이 추가될 것입니다.

tf.function으로 전체 훈련 반복을 감싸고 싶다면 안전한 방법은 데이터를 tf.data.Dataset으로 감싸서 오토그래프가 동적으로 훈련 반복을 펼치게 하는 것입니다.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({})는 그래프에 {}개의 노드를 포함합니다".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # 의미없는 연산
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)])는 그래프에 11개의 노드를 포함합니다
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)])는 그래프에 32개의 노드를 포함합니다
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>)는 그래프에 8개의 노드를 포함합니다
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>)는 그래프에 8개의 노드를 포함합니다

데이터셋으로 파이썬/넘파이 데이터를 감쌀 때 tf.data.Dataset.from_generatortf.data.Dataset.from_tensors의 차이를 주의하세요. 전자는 파이썬에서 데이터를 유지하고 tf.py_function으로 데이터를 가져오므로 성능에 영향을 미칠 수 있습니다. 후자는 그래프에 있는 하나의 큰 tf.constant() 노드로 데이터를 복사하므로 메모리에 영향을 미칠 수 있습니다.

TFRecordDataset, CsvDataset 등으로 파일에서 데이터를 읽는 것이 가장 효율적으로 데이터를 소비하는 방법입니다. 텐서플로는 파이썬을 거치지 않고 비동기적으로 데이터를 적재하고 프리페칭할 수 있기 때문입니다. 조금 더 자세한 정보는 tf.data guide를 참고하세요.

반복하면서 값을 누적하기

반복하면서 중간 값을 누적하는 패턴은 자주 있습니다. 보통 파이썬 리스트나 딕셔너리에 원소를 추가하는 방식을 사용합니다. 하지만 파이썬 부수 효과 때문에 동적으로 펼쳐지는 반복에서는 기대대로 동작하지 않습니다. 대신 tf.TensorArray를 사용해 동적으로 펼쳐지는 반복에서 결과를 누적하세요.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.3780445 , 0.26010072, 0.5198095 , 0.34903514],
        [1.2200823 , 0.91877973, 0.65711   , 0.7826295 ],
        [2.0425332 , 1.2292334 , 0.6671903 , 1.3648179 ]],

       [[0.7623887 , 0.37711632, 0.5713397 , 0.10503888],
        [1.3576788 , 1.116627  , 1.1307144 , 0.2367965 ],
        [1.3598431 , 1.5509181 , 1.6007028 , 0.9369447 ]]], dtype=float32)>

더 읽을 거리

tf.function을 트레이싱한 후 수행되는 그래프 최적화에 자세히 알고 싶다면 그래플러(Grappler) 가이드를 참고하세요. 데이터 파이프라인을 최적화하고 모델 프로파일링 방법에 대해 알고 싶다면 프로파일러(Profiler) 가이드를 참고하세요.