tf.function으로 성능 향상하기

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드

텐서플로 2에서는 즉시 실행(eager execution)이 기본적으로 활성화되어 있습니다. 직관적이고 유연한 사용자 인터페이스를 제공하지만 성능과 배포에 비용이 더 듭니다(하나의 연산을 실행할 때는 훨씬 간단하고 빠릅니다).

성능을 높이고 이식성이 좋은 모델을 만들려면 tf.function을 사용해 그래프로 변환하세요. 하지만 조심해야 할 점이 있습니다. tf.function은 무조건 속도를 높여주는 마법의 은총알이 아닙니다!

이 가이드는 tf.function의 이면에 있는 개념을 이해하고 효과적으로 사용할 수 있도록 돕습니다.

여기서 배울 주요 내용과 권고 사항은 다음과 같습니다:

  • 즉시 실행 모드에서 디버깅한 다음 @tf.function으로 데코레이팅하세요.
  • 객체 변경(object mutation)이나 리스트 요소 추가 같은 Python의 부수 효과에 의존하지 마세요.
  • tf.function은 텐서플로 연산과 가장 잘 동작합니다: 넘파이와 파이썬 호출은 상수로 바뀝니다.

설정

# Update TensorFlow, as this notebook requires version 2.9 or later
!pip install -q -U tensorflow>=2.9.0
import tensorflow as tf
2022-12-14 22:12:06.346776: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:12:06.346872: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:12:06.346882: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

에러 출력을 위한 헬퍼 함수를 정의합니다:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

기초

사용법

정의하는 Function(예: @tf.function 데코레이터를 적용하는 예시)은 핵심 TensorFlow 연산과 매우 비슷합니다. 즉, 즉시 실행할 수 있으며 그래디언트 계산과 같은 작업이 가능합니다.

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

다른 함수 내부에 사용할 수 있습니다.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function은 Eager 코드보다 빠릅니다. 특히 그래프에 작은 ops가 많을 때 그렇습니다. 하지만 (합성곱처럼) 계산량이 많은 ops 몇 개로 이루어진 그래프는 속도 향상이 크지 않습니다.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006060872000489326
Function conv: 0.006282959999225568
Note how there's not much difference in performance for convolutions

추적

이 섹션에서는 향후 변경될 수 있는 구현 세부 정보를 포함하여 내부에서 Function이 작동하는 방식을 노출합니다. 그러나 추적이 발생하는 이유와 시기를 이해하면 tf.function을 효과적으로 사용하기가 훨씬 쉽습니다!

"추적"이란 무엇입니까?

FunctionTensorFlow Graph에서 프로그램을 실행합니다. 그러나 tf.Graph는 사용자가 즉시 실행 TensorFlow 프로그램에서 작성하고자 하는 모든 요소를 나타낼 수는 없습니다. 예를 들어 Python은 다형성을 지원하지만 tf.Graph는 입력에 데이터 유형과 차원의 지정을 요구합니다. 또는 사용자가 명령줄 인수 읽기, 오류 발생 또는 더 복잡한 Python 객체 작업과 같은 부수적인 작업을 수행할 수도 있지만, 이 중 어떤 작업도 tf.Graph에서 실행할 수 없습니다.

Function은 코드를 두 단계로 분리하여 이러한 문제를 해소합니다.

  1. "추적"이라고 하는 첫 번째 단계에서 Function은 새 tf.Graph를 만듭니다. Python 코드는 정상적으로 실행되지만 모든 TensorFlow 연산(예: 두 개의 텐서 추가)이 지연되어, 결국 실행되지 않고 tf.Graph에 의해 캡처됩니다.

  2. 두 번째 단계에서는 첫 번째 단계에서 지연된 모든 부분을 포함하는 tf.Graph가 실행됩니다. 이 단계는 추적 단계보다 훨씬 빠릅니다.

입력에 따라 Function이 호출시 항상 첫 번째 단계를 실행하지는 않습니다. 이 결정이 내려지는 방식을 더 잘 이해하려면 아래의 "추적 규칙"을 참조합니다. 첫 번째 단계를 건너뛰고 두 번째 단계만 실행하면 TensorFlow가 높은 성능을 발휘합니다.

Function이 추적하기로 결정하면 추적 단계 바로 다음에 두 번째 단계가 이어지므로 Function 호출로tf.Graph가 만들어지는 동시에 실행됩니다. 나중에 get_concrete_function으로 추적 단계만 실행하는 방법을 볼 수 있습니다.

다른 유형의 인수를 Function으로 전달하면 두 단계가 모두 실행됩니다.

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

같은 인수 유형으로 Function을 반복해서 호출하는 경우, 생성되는 그래프가 동일하므로 TensorFlow는 추적 단계를 건너뛰고 이전에 추적한 그래프를 재사용합니다.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

pretty_printed_concrete_signatures()를 사용하여 사용 가능한 모든 추적을 볼 수 있습니다.

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

지금까지 tf.function이 TensorFlow의 그래프 추적 로직을 통해 캐시된 동적 디스패치 레이어를 생성하는 과정을 확인했습니다. 다음은 용어에 대한 보충 설명입니다.

  • tf.Graph는 언어에 구애받지 않고 TensorFlow 계산을 이식 가능하게 원시 형태로 표현한 것입니다.
  • ConcreteFunctiontf.Graph를 래핑합니다.
  • FunctionConcreteFunction의 캐시를 관리하고 입력에 적합한 캐시를 선택합니다.
  • tf.function은 Python 함수를 래핑하여 Function 객체를 반환합니다.
  • 추적(tracing)은 tf.Graph를 생성하고 추적(trace)이라고도 하는 ConcreteFunction에서 이를 래핑합니다.

추적 규칙

호출하면 Function이 각 인수의 tf.types.experimental.TraceType을 사용하여 기존 ConcreteFunction에 호출 인수를 일치시킵니다. 일치하는 ConcreteFunction이 발견되면 호출이 전달됩니다. 일치하는 항목이 없으면 새 ConcreteFunction이 추적됩니다.

일치하는 항목이 여러 개 있는 경우 가장 구체적인 서명이 선택됩니다. 즉, C++ 또는 Java의 일반 함수 호출과 마찬가지로 매칭이 서브타이핑으로 수행됩니다. 예를 들어 TensorShape([1, 2])TensorShape([None, None])의 하위 유형이므로 TensorShape([1, 2])TensorShape([None, None])로 생성한 ConcreteFunction에 전달할 수 있지만 TensorShape([1, None])를 사용하는 ConcreteFunction가 존재하고 더 구체적일 경우 더 높은 우선순위를 갖습니다.

TraceType은 다음과 같이 입력 인수에서 결정됩니다.

  • Tensor의 경우 유형이 Tensordtypeshape에 의해 매개변수화됩니다. 순위 형상은 순위가 지정되지 않은 형상의 하위 유형입니다. 고정 차원은 알 수 없는 차원의 하위 유형입니다.
  • Variable의 경우 유형이 Tensor와 유사하지만 제어 종속성을 올바르게 연결하는 데 필요한 변수의 고유 리소스 ID도 포함합니다.
  • Python 기본 값의 경우 유형은 자체에 해당합니다. 예를 들어 3 값의 TraceTypeint가 아니라 LiteralTraceType<3>입니다.
  • listtuple 등과 같은 순서가 유지되는 Python 컨테이너의 경우 유형이 요소 유형에 따라 매개변수화됩니다. 예를 들어 [1, 2]의 유형은 ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>>이고 [2, 1]의 유형은 앞선 유형과는 달리 ListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>입니다.
  • dict와 같은 Python 매핑의 경우 유형은 동일한 키에서 실제 값 대신의 값 유형으로의 매핑이기도 합니다. 예를 들어 {1: 2, 3: 4}의 유형은 MappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>입니다. 순서가 정해져 있는 컨테이너와 달리 {1: 2, 3: 4}{3: 4, 1: 2}는 동일한 유형을 갖습니다.
  • __tf_tracing_type__ 메서드를 구현하는 Python 객체의 경우 해당 메소드가 반환하는 모든 항목이 유형으로 지정됩니다.
  • 다른 Python 개체의 경우 유형은 매칭을 위해 객체의 Python 동등성 및 해싱을 사용하는 제네릭 TraceType입니다(참고: 객체에 대한 weakref에 의존하므로 객체가 범위 내에 있거나 삭제되지 않은 경우에만 작동합니다).

참고: TraceTypeFunction 입력 매개변수를 기반으로 하므로 전역 및 자유 변수에 대한 변경만으로는 새 추적이 생성되지 않습니다. Python 전역 및 자유 변수를 처리할 때 권장되는 방법은 이 섹션을 참고합니다.

재추적 제어

Function이 두 개 이상의 추적을 생성하는 경우 재추적을 수행하면 TensorFlow가 각 입력 세트에 대해 올바른 그래프를 생성하는 데 도움이 됩니다. 그러나 추적은 비용이 많이 드는 작업입니다! 호출할 때마다 Function이 새 그래프를 재추적하면 tf.function을 사용하지 않는 경우보다 코드가 더 느리게 실행됩니다.

추적 동작을 제어하기 위해 다음 방법을 사용할 수 있습니다.

고정된 input_signaturetf.function에 전달하기

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_184655/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_184655/1851403433.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_184655/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_184655/1851403433.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).

유연성을 위해 알 수 없는 차원 사용하기

TensorFlow는 형상에 따라 텐서를 일치시키므로 None 차원을 와일드카드로 사용하면 Function이 크기가 가변적인 입력에 대한 추적을 재사용할 수 있습니다. 길이가 다른 시퀀스 또는 각 배치에 대해 다른 크기의 이미지가 있는 경우에 크기가 가변적인 입력이 발생할 수 있습니다(TransformerDeep Dream 튜토리얼의 예제 참조).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

파이썬 리터럴 대신 텐서 전달하기

종종 Python 인수는 하이퍼파라미터와 그래프 구성을 제어하는 데 사용됩니다(예: num_layers=10, training=True 또는 nonlinearity='relu'). 따라서 Python 인수가 변경되면 그래프를 다시 추적해야 합니다.

그러나 그래프 구성을 제어하는 데 Python 인수를 사용하지 않을 수도 있습니다. 이러한 경우 Python 값이 변경되면 불필요한 재추적이 실행될 수 있습니다. 예를 들어, AutoGraph가 동적으로 펼쳐지는 훈련 루프를 생각해봅니다. 여러 추적에도 불구하고 생성된 그래프는 실제로 동일하므로 다시 추적할 필요가 없습니다.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

강제로 다시 추적해야 하는 경우 새 Function을 만듭니다. 별도의 Function 객체는 추적을 공유하지 않을 것이 보장됩니다.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

추적 프로토콜 사용하기

가능한 경우 대신 Python 유형을 tf.experimental.ExtensionType으로 변환하는 것이 좋습니다. 또한 ExtensionTypeTraceType은 이와 연결되어 있는 tf.TypeSpec입니다. 따라서 필요한 경우 기본 tf.TypeSpec을 재정의하여 ExtensionTypeTracing Protocol을 제어할 수 있습니다. 자세한 내용은 확장 유형 가이드의 ExtensionType의 TypeSpec 사용자 정의하기섹션을 참고합니다.

그 외에는 특정 Python 유형과 관련하여 Function이 재추적해야 하는 시기를 직접 제어하기 위해 이에 대한 Tracing Protocol을 직접 구현할 수 있습니다.

@tf.function
def get_mixed_flavor(fruit_a, fruit_b):
  return fruit_a.flavor + fruit_b.flavor

class Fruit:
  flavor = tf.constant([0, 0])

class Apple(Fruit):
  flavor = tf.constant([1, 2])

class Mango(Fruit):
  flavor = tf.constant([3, 4])

# As described in the above rules, a generic TraceType for `Apple` and `Mango`
# is generated (and a corresponding ConcreteFunction is traced) but it fails to 
# match the second function call since the first pair of Apple() and Mango() 
# have gone out out of scope by then and deleted.
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again

# However, each subclass of the `Fruit` class has a fixed flavor, and you
# can reuse an existing traced concrete function if it was the same
# subclass. Avoiding such unnecessary tracing of concrete functions
# can have significant performance benefits.

class FruitTraceType(tf.types.experimental.TraceType):
  def __init__(self, fruit_type):
    self.fruit_type = fruit_type

  def is_subtype_of(self, other):
      return (type(other) is FruitTraceType and
              self.fruit_type is other.fruit_type)

  def most_specific_common_supertype(self, others):
      return self if all(self == other for other in others) else None

  def __eq__(self, other):
    return type(other) is FruitTraceType and self.fruit_type == other.fruit_type

  def __hash__(self):
    return hash(self.fruit_type)

class FruitWithTraceType:

  def __tf_tracing_type__(self, context):
    return FruitTraceType(type(self))

class AppleWithTraceType(FruitWithTraceType):
  flavor = tf.constant([1, 2])

class MangoWithTraceType(FruitWithTraceType):
  flavor = tf.constant([3, 4])

# Now if you try calling it again:
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Traces a new concrete function
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Re-uses the traced concrete function
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 6], dtype=int32)>

구체적인 함수 얻기

get_concrete_function 메서드를 사용해 트레이싱된 특정 함수를 얻을 수 있습니다.

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

ConcreteFunction를 인쇄하면 입력 인수(유형 포함)와 그 출력 유형의 요약이 표시됩니다.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

구체적인 함수의 서명을 직접 검색할 수도 있습니다.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

호환되지 않는 유형의 구체적인 추적을 사용하면 오류가 발생합니다.

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_184655/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_184655/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_166 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_166]

구체적인 함수의 입력 서명에서 Python 인수가 특별하게 처리된다는 것을 알 수 있습니다. TensorFlow 2.3 이전에는 Python 인수가 구체적인 함수의 서명에서 제거되었습니다. TensorFlow 2.3부터 Python 인수는 서명에 남아 있지만 추적 중에 설정된 값을 사용하도록 제한됩니다.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py", line 1487, in _call_impl
    return self._call_with_flat_signature(args, kwargs,
  File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py", line 1532, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_184655/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_184655/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.

그래프 얻기

각 구체적인 함수는 tf.Graph를 감싸는 호출 가능한 래퍼입니다. tf.Graph 객체를 검색하는 것이 일반적으로 수행해야 하는 작업은 아니지만 구체적인 함수에서 쉽게 얻을 수 있습니다.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

디버깅

일반적으로 tf.function 내부에서 할 때보다 Eager 모드가 디버깅하기 쉽습니다. tf.function으로 데코레이팅하기 전에 Eager 모드에서 에러가 없는지 확인합니다. 디버깅 과정을 위해 tf.config.run_functions_eagerly(True)으로 전체 tf.function을 비활성화하고 나중에 다시 활성화할 수 있습니다.

다음은 tf.function 내에서만 나타나는 문제를 추적할 때 사용할 수 있는 몇 가지 팁입니다.

  • Python print 함수는 추적(tracing)하는 동안에만 호출되므로 함수가 (재)추적될 때 추적하는데 도움이 됩니다.
  • tf.print 함수는 언제나 실행되므로 실행하는 동안 중간 값을 추적할 때 도움이 됩니다.
  • tf.debugging.enable_check_numerics을 사용하면 쉽게 NaN과 Inf가 발생되는 곳을 추적할 수 있습니다.
  • pdb(Python 디버거)는 추적 중에 어떤 일이 일어나는지 이해하는데 도움이 될 수 있습니다(주의: pdb는 사용자를 AutoGraph로 변환된 소스 코드로 이동시킵니다).

AutoGraph 변환

AutoGraph는 tf.function안에 기본으로 활성화되어 있는 라이브러리이며 Python의 Eager 코드를 그래프 호환 TensorFlow ops로 변환합니다. 여기에는 if, for, while과 같은 제어 흐름이 포함됩니다.

tf.condtf.while_loop 같은 텐서플로 연산을 여전히 사용할 수 있지만 파이썬으로 제어 흐름을 작성하는 것이 만들기도 이해하기도 쉽습니다.

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.395097852 0.319370866 0.79926908 0.448182821 0.534545779]
[0.375746667 0.308937907 0.663628 0.420404136 0.48884818]
[0.359007984 0.299470544 0.580772758 0.397270799 0.453301787]
[0.344339937 0.290828 0.523226857 0.377611309 0.424609303]
[0.331346363 0.282896698 0.480186611 0.360631227 0.400806427]
[0.31973 0.275583893 0.446393 0.345769882 0.380638719]
[0.309262753 0.268812954 0.418929547 0.332618743 0.363262028]
[0.299766243 0.262519956 0.39602825 0.32087183 0.348084033]
[0.291098654 0.256651 0.376545459 0.310295016 0.334675252]
[0.283145696 0.251160443 0.35970363 0.300705433 0.322715402]
[0.275814 0.246009186 0.344952911 0.291958034 0.311960101]
[0.269026428 0.241163582 0.331891924 0.283935964 0.302219182]
[0.262718678 0.236594483 0.320219725 0.276543975 0.293342143]
[0.256836593 0.232276499 0.309705526 0.269703448 0.285208]
[0.25133431 0.228187427 0.30016917 0.263348848 0.277718306]
[0.246172518 0.224307656 0.291467398 0.257425129 0.270792]
[0.241317406 0.220619902 0.283484846 0.251885563 0.26436162]
[0.23673968 0.217108801 0.276127309 0.246690288 0.258370548]
[0.232413873 0.213760659 0.269317031 0.241804957 0.252770811]
[0.228317648 0.210563242 0.262989193 0.237199858 0.247521505]
[0.224431336 0.207505539 0.257089257 0.232849166 0.242587417]
[0.220737576 0.20457764 0.251571 0.228730202 0.237938166]
[0.217220932 0.201770619 0.246394828 0.224823058 0.2335473]
[0.213867664 0.199076355 0.241526753 0.221110165 0.229391694]
[0.210665509 0.196487486 0.236937299 0.217575923 0.225451022]
[0.20760341 0.193997309 0.232600808 0.214206412 0.221707344]
[0.204671398 0.191599682 0.228494823 0.210989177 0.218144745]
[0.201860532 0.189289033 0.22459957 0.207913101 0.214749053]
[0.199162707 0.187060192 0.2208976 0.204968125 0.211507618]
[0.196570486 0.18490845 0.217373401 0.202145159 0.208409086]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.19407718, 0.18282945, 0.21401317, 0.19943602, 0.20544322],
      dtype=float32)>

관심있다면 오토그래프가 생성한 코드를 확인해 볼 수 있습니다.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

조건문

AutoGraph는 if <condition> 문장을 이와 대등한 tf.cond 호출로 변경합니다. 이런 대체는 <condition>이 텐서일 때 수행됩니다. 그렇지 않다면 if 문장은 Python 조건문으로 실행됩니다.

추적하는 동안 Python 조건문을 실행하기 때문에 정확히 하나의 조건 분기만 그래프에 추가됩니다. Autograph가 없다면 이렇게 추적된 그래프는 데이터 종속 제어 흐름이 있는 경우 대체 분기를 사용할 수 없습니다.

tf.cond는 조건문의 두 분기를 모두 추적하고 그래프에 추가하여 실행 시 분기를 동적으로 선택합니다. 추적에는 의도하지 않은 부작용이 있을 수 있습니다. 자세한 내용은 AutoGraph 추적 효과를 확인하세요.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Autograph가 변환한 if 문장에 대한 추가 제약 사항은 참조 문서를 참고하세요.

반복문

Autograph는 일부 forwhile 문장을 tf.while_loop와 같은 동등한 TensorFlow 루프 ops로 바꿉니다. 변환되지 않으면 Python 루프로 forwhile 루프가 실행됩니다.

이런 대체는 다음과 같은 경우에 일어납니다:

  • for x in y: y가 텐서이면 tf.while_loop로 변환됩니다. 특별히 ytf.data.Dataset인 경우에는 tf.data.Dataset 연산의 조합이 생성됩니다.
  • while <condition>: <condition>이 텐서라면 tf.while_loop로 변환됩니다.

추적하는 동안 Python 루프가 실행되므로 매 루프 반복 때마다 tf.Graph에 추가적인 ops가 추가됩니다.

TensorFlow 루프는 루프 블럭을 추적하여 실행 시 얼마나 많은 반복을 수행할지 동적으로 선택합니다. 루프 블럭은 생성된 tf.Graph에 한 번만 포함됩니다.

Autograph가 변환한 forwhile 문장에 대한 추가 제약 사항은 참조 문서를 참고하세요.

Python 데이터로 루핑하기

일반적인 함정은 tf.function 내에서 Python/NumPy 데이터를 루핑하는 것입니다. 이 루프는 추적 프로세스 중에 실행되어 루프의 각 반복에 대한 모델 복사본을 tf.Graph에 추가합니다.

tf.function으로 전체 훈련 루핑을 래핑하고 싶은 경우, 안전한 방법은 데이터를 tf.data.Dataset으로 래핑하여 Autograph가 동적으로 훈련 루프를 펼치게 하는 것입니다.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph
train(<FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph

데이터 세트에서 Python/NumPy 데이터를 래핑할 때 tf.data.Dataset.from_generatortf.data.Dataset.from_tensors에 유의해야 합니다. 전자는 데이터를 Python에 유지하고 성능에 영향을 미칠 수 있는 tf.py_function을 통해 가져오는 반면 후자는 데이터 복사본을 그래프에서 하나의 큰 tf.constant() 노드로 묶으며 이는 메모리에 영향을 미칠 수 있습니다.

TFRecordDataset, CsvDataset 등을 통해 파일에서 데이터를 읽는 것은 데이터를 소비하는 가장 효과적인 방법이며 이렇게 할 경우 Python을 사용하지 않아도 TensorFlow 자체적으로 데이터의 비동기 로드 및 프리페치를 관리할 수 있습니다. ​자세한 내용은 tf.data: TensorFlow 입력 파이프라인 빌드 가이드를 참조하세요

반복하면서 값을 누적하기

반복하면서 중간 값을 누적하는 패턴은 자주 있습니다. 보통 Python 목록이나 사전에 입력 항목을 추가하는 방식을 사용합니다. 하지만 Python 부수 효과 때문에 동적으로 펼쳐지는 반복에서는 기대대로 동작하지 않습니다. 대신 tf.TensorArray를 사용해 동적으로 펼쳐지는 반복에서 결과를 누적하세요.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.31763458, 0.5524497 , 0.24966824, 0.5294143 ],
        [0.5924542 , 0.7307738 , 0.45017755, 0.53719497],
        [1.5015074 , 1.4625493 , 1.1735668 , 0.6612781 ]],

       [[0.88830245, 0.45373   , 0.25992   , 0.9537363 ],
        [1.6396775 , 0.4752848 , 0.4163171 , 1.2923994 ],
        [2.3875687 , 1.2606081 , 1.3037155 , 1.5534725 ]]], dtype=float32)>

한계

TensorFlow Function에는 기본적으로 몇 가지 한계가 있기에 Python 함수를 Function으로 변환할 때 이에 대해 알고 있어야 합니다.

Python 부작용 실행

인쇄, 목록에 추가 및 전역 변경과 같은 부작용은 Function 내에서 예기치 않게 동작할 수 있으며 때로 두 번 실행되거나 전혀 실행되지 않을 수 있습니다. 이러한 동작은 입력 세트를 사용하여 Function을 처음 호출할 때만 발생합니다. 그 후에는 추적된 tf.Graph가 Python 코드를 실행하지 않고 다시 실행됩니다.

경험에 의한 일반적인 규칙은 논리에서 Python 부작용에 의존하지 않고 추적을 디버그하는 데만 사용하는 것입니다. 그렇지 않으면 tf.data, tf.print, tf.summary, tf.Variable.assigntf.TensorArray와 같은 TensorFlow API를 사용하는 것이 각 호출을 통해 TensorFlow 런타임에서 코드가 실행되도록 하는 가장 좋은 방법입니다.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Function을 호출할 때마다 Python 코드를 실행하려는 경우 tf.py_function을 종료 해치로 사용할 수 있습니다. tf.py_function의 단점은 이것이 이식 가능하거나 특별히 성능이 뛰어나지 않고, SavedModel로 저장할 수 없으며, 분산(다중 GPU, TPU) 환경에서 제대로 작동하지 않는다는 것입니다. 또한 tf.py_function은 그래프에 연결되어야 하므로 모든 입력/출력을 텐서로 캐스팅합니다.

Python 전역 및 자유 변수 변경

Python 전역 및 자유 변수 변경은 Python 부작용으로 간주되므로 추적 중에만 발생합니다.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

때로는 예상치 못한 작업을 알아차리기가 매우 어렵습니다. 아래 예제의 counter는 변수의 증가를 보호하기 위한 목적으로 사용되었습니다. 그러나 이것은 TensorFlow 객체가 아니라 Python 정수이기 때문에 첫 번째 추적 중에 값을 캡처합니다. tf.function을 사용하면 assign_add가 기본 그래프에 무조건 기록됩니다. 따라서 tf.function을 호출할 때마다 v가 1씩 증가합니다. 이 문제는 Python 부작용(예제에서 counter)을 사용하여 실행할 ops를 결정(예제에서 assign_add)할 때
tf.function 데코레이터를 사용하여 그래프 모드 Tensorflow를 Tensorflow 2로 마이그레이션하려는 사용자 사이에서 일반적입니다. 일반적으로 사용자는 의심스러운 수치 결과 또는 예상보다 현저히 낮은 성능을 본 후에야 이를 깨닫게 됩니다(예: 보호된 작업에 비용이 많이 드는 경우).

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

예상 동작을 달성하기 위한 해결 방법은 tf.init_scope를 사용하여 함수 그래프 외부에서 작업을 수행하는 것입니다. 이렇게 하면 추적 시간 동안 변수 증가가 한 번만 수행됩니다. init_scope에는 명확한 제어 플로 및 그래디언트 테이프 등의 기타 부작용이 있습니다. 때때로 init_scope의 사용법은 관리하기에는 현실적으로 너무 복잡해질 수 있습니다.

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

요약하면 경험상 Function 외부에 있는 목록과 같은 정수 또는 컨테이너 등의 Python 객체는 변경하지 않아야 합니다. 대신 인수와 TF 객체를 사용하도록 합니다. 예를 들어, "루핑하면서 값을 누적하기" 섹션에는 목록과 유사한 연산을 구현할 수 있는 방법에 대한 한 가지 예제가 있습니다.

tf.Variable일 경우, 상태를 캡처하고 조작할 수 있는 경우도 있습니다. 이것이 동일한 ConcreteFunction에 대한 반복 호출로 Keras 모델의 가중치가 업데이트되는 방식입니다.

Python 반복기 및 생성기 사용

생성기 및 반복기와 같은 많은 Python 기능은 상태를 추적하기 위해 Python 런타임에 의존합니다. 일반적으로, 이러한 구조는 Eager 모드에서 예상대로 작동하지만 Python 부작용의 예이므로 추적 중에만 발생합니다.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

TensorFlow가 목록 구성을 위한 tf.TensorArray를 가지고 있는 것과 마찬가지로 반복 구성을 위한 tf.data.Iterator도 가지고 있습니다. 개괄적인 내용은 AutoGraph 변환 섹션을 참조합니다. tf.data API도 생성기 패턴을 구현하는 데 도움이 될 수 있습니다.

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

tf.function의 모든 출력은 반환 값이어야 합니다.

tf.Variable을 제외하고 tf.function은 모든 출력을 반환해야 합니다. 반환 값을 거치지 않고 함수의 텐서에 직접 액세스하려고 하면 "누출"이 발생합니다.

예를 들어 아래의 함수는 Python 전역 x를 통해 텐서 a를 '누출'합니다.

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'Tensor' object has no attribute 'numpy'

이는 누출된 값이 반환된 경우에도 마찬가지입니다.

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'Tensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_184655/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_184655/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: <tf.Tensor 'add:0' shape=() dtype=int32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

<tf.Tensor 'add:0' shape=() dtype=int32> was defined here:
    File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
      exec(code, run_globals)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py", line 992, in launch_instance
      app.start()
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 711, in start
      self.io_loop.start()
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
      self._run_once()
    File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
      handle._run()
    File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
      await result
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 729, in execute_request
      reply_content = await reply_content
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 411, in do_execute
      res = shell.run_cell(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 531, in run_cell
      return super().run_cell(*args, **kwargs)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2940, in run_cell
      result = self._run_cell(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2995, in _run_cell
      return runner(coro)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3194, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3373, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "/tmpfs/tmp/ipykernel_184655/566849597.py", line 7, in <module>
      correct_a = leaky_function(tf.constant(1))
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
      return fn(*args, **kwargs)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 880, in __call__
      result = self._call(*args, **kwds)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 928, in _call
      self._initialize(args, kwds, add_initializers_to=initializers)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 749, in _initialize
      self._variable_creation_fn    # pylint: disable=protected-access
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 162, in _get_concrete_function_internal_garbage_collected
      concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 157, in _maybe_define_concrete_function
      return self._maybe_define_function(args, kwargs)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 360, in _maybe_define_function
      concrete_function = self._create_concrete_function(args, kwargs)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 284, in _create_concrete_function
      func_graph_module.func_graph_from_py_func(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1283, in func_graph_from_py_func
      func_outputs = python_func(*func_args, **func_kwargs)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 645, in wrapped_fn
      out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1258, in autograph_handler
      return autograph.converted_call(
    File "/tmpfs/tmp/ipykernel_184655/566849597.py", line 4, in leaky_function
      x = a + 1  # Bad - leaks local tensor
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
      return fn(*args, **kwargs)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1407, in binary_op_wrapper
      return func(x, y, name=name)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
      return fn(*args, **kwargs)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 1176, in op_dispatch_handler
      return dispatch_target(*args, **kwargs)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1757, in _add_dispatch
      return gen_math_ops.add_v2(x, y, name=name)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py", line 475, in add_v2
      _, _, _op, _outputs = _op_def_library._apply_op_helper(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 795, in _apply_op_helper
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 749, in _create_op_internal
      return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 3798, in _create_op_internal
      ret = Operation(

The tensor <tf.Tensor 'add:0' shape=() dtype=int32> cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=139766631324832), which is out of scope.

일반적으로 이러한 누출은 Python 구문이나 데이터 구조를 사용할 때 발생합니다. 이러한 명령문은 액세스할 수 없는 텐서를 누출하는 것 외에도 Python 부작용으로 간주되고 모든 함수 호출에서 실행되는 것이 보장되지 않기 때문에 잘못될 가능성이 있습니다.

로컬 텐서를 누출하는 일반적인 방법에는 다음과 같이 외부 Python 컬렉션 또는 객체를 변경하는 것 등이 있습니다.

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

재귀 tf.functions는 지원되지 않습니다.

재귀 Function은 지원되지 않으며 무한 루프가 발생할 수 있습니다. 예를 들면 다음과 같습니다.

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_184655/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmpfs/tmp/ipykernel_184655/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/usr/lib/python3.9/abc.py", line 119, in __instancecheck__
        return _abc_instancecheck(cls, instance)
    File "/usr/lib/python3.9/abc.py", line 123, in __subclasscheck__
        return _abc_subclasscheck(cls, subclass)

    RecursionError: maximum recursion depth exceeded while calling a Python object

재귀 Function이 작동하는 것처럼 보이더라도 Python 함수는 여러 번 추적되며 성능에 영향을 미칠 수 있습니다. 예를 들면 다음과 같습니다.

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

알려진 문제

Function이 올바르게 평가되지 않는 경우, 오류는 이러한 알려진 문제에 의해 설명될 수 있으며, 이 부분은 향후에 수정될 예정입니다.

Python 전역 및 자유 변수에 의존

Function은 Python 인수의 새 값으로 호출될 때 새로운 ConcreteFunction을 생성합니다. 그러나 해당 Function의 Python 클로저, 전역 또는 비로컬에 대해서는 그렇게 하지 않습니다. Function 호출 사이에 값이 변경되면 Function은 추적되었을 때 가지고 있던 값을 계속 사용합니다. 이것은 일반 Python 함수가 작동하는 방식과 다릅니다.

따라서 외부 이름을 닫는 대신 인수를 사용하는 함수형 프로그래밍 방식을 따라야 합니다.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

전역 값을 업데이트하는 또 다른 방법은 tf.Variable로 만들고 대신 Variable.assign 메서드를 사용하는 것입니다.

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

Python 객체에 의존

Python 객체를 tf.function에 인수로 전달하라는 권장 사항에는 여러 가지 알려진 문제가 있으며 향후 수정될 것으로 예상합니다. 일반적으로, Python 기본 형식 또는 tf.nest 호환 구조를 인수로 사용하거나 객체의 다른 인스턴스에서 Function으로 전달하는 경우, 일관된 추적에 의존할 수 있습니다. 그러나 동일한 객체를 전달하고 해당 속성만 변경하는 경우 Function은 새 추적을 생성하지 않습니다.

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

동일한 Function을 사용하여 모델의 업데이트된 인스턴스를 평가하는 것은 오류의 위험이 있습니다. 이는 업데이트된 모델이 원래 모델과 동일한 캐시 키를 갖기 때문입니다.

따라서 변경 가능한 객체 속성에 의존하지 않도록 Function을 작성하거나 새 객체를 생성하는 것이 좋습니다.

이것이 가능하지 않은 경우 한 가지 해결 방법은 재추적을 강제 실행하도록 객체를 수정할 때마다 새로운 Function을 만드는 것입니다.

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

재추적은 비용이 많이 들기 때문에 tf.Variable을 객체 속성으로 사용할 수 있습니다. 그러면 다시 추적할 필요 없이 이를 변형(하지만 변경되지는 않음에 주의!)하여 비슷한 효과를 거둘 수 있습니다.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

tf.Variables 만들기

Function은 첫 번째 호출에서 한 번 생성되고 후속 함수 호출에서 재사용되는 싱글톤 tf.Variable만 지원합니다. 아래 코드 조각은 모든 함수 호출에서 새로운 tf.Variable을 생성하므로 ValueError 예외가 발생합니다.

예시:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_184655/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_184655/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmpfs/tmp/ipykernel_184655/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

이 제한을 해결하는 데 사용되는 일반적인 패턴은 Python None 값으로 시작한 다음, 값이 None인 경우 조건부로 tf.Variable을 생성하는 것입니다.

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

여러 Keras 옵티마이저 프로그램과 함께 사용

tf.function과 함께 둘 이상의 Keras 옵티마이저를 사용할 경우 ValueError: tf.function only supports singleton tf.Variables created on the first call.이 발생할 수 있습니다. 이 오류는 옵티마이저가 처음으로 그래디언트를 적용할 때 내부적으로 tf.Variables를 생성하기 때문에 발생합니다.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmpfs/tmp/ipykernel_184655/3551158538.py", line 8, in assert_raises
    yield
  File "/tmpfs/tmp/ipykernel_184655/3167358578.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmpfs/tmp/ipykernel_184655/3167358578.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1140, in apply_gradients  **
        return super().apply_gradients(grads_and_vars, name=name)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 621, in apply_gradients
        self.build(trainable_variables)
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py", line 139, in build
        self.add_variable_from_reference(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1072, in add_variable_from_reference
        return super().add_variable_from_reference(
    File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 496, in add_variable_from_reference
        variable = tf.Variable(

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

훈련 중에 옵티마이저를 변경해야 하는 경우, 해결 방법은 각 옵티마이저에 새 Function을 만들어 ConcreteFunction을 직접 호출하는 것입니다.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

여러 Keras 모델과 함께 사용

동일한 Function에 다른 모델 인스턴스를 전달할 때에도 ValueError: tf.function only supports singleton tf.Variables created on the first call.이 발생할 수 있습니다.

이 오류는 Keras 모델(입력 형상이 정의되지 않음)과 Keras 레이어가 처음 호출될 때 tf.Variables를 만들기 때문에 발생합니다. 이미 호출된 Function 내에서 이러한 변수를 초기화하려고 할 수도 있습니다. 이 오류를 방지하려면 model.build(input_shape)를 호출하여 모델을 훈련하기 전에 모든 가중치를 초기화합니다.

더 읽을 거리

Function을 내보내고 로드하는 방법을 알고 싶은 경우 SavedModel 가이드를 참조합니다. 추적 후 수행되는 그래프 최적화에 대해 자세히 알아보려면 Grappler 가이드를 참조합니다. 데이터 파이프라인을 최적화하고 모델을 프로파일링하는 방법을 알아보려면 프로파일러 가이드를 참조합니다.