
Better performance with tf.function


In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.

You can use tf.function to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This will help you create performant and portable models, and it is required to use SavedModel.

This guide will help you conceptualize how tf.function works under the hood so you can use it effectively.

The main takeaways and recommendations are:

  • Debug in eager mode, then decorate with @tf.function.
  • Don't rely on Python side effects like object mutation or list appends.
  • tf.function works best with TensorFlow ops; NumPy and Python calls are converted to constants.

Setup

 import tensorflow as tf
 

Define a helper function to demonstrate the kinds of errors you might encounter:

 import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))
 

Basics

Usage

Functions you define are just like core TensorFlow operations: you can execute them eagerly; you can compute gradients; and so on.

 @tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
 
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
 v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
 
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

You can use Functions inside of other Functions.

 @tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
 
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Functions can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.

 import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

 
Eager conv: 0.0023194860004878137
Function conv: 0.0036776439992536325
Note how there's not much difference in performance for convolutions

Tracing

Python's dynamic typing means that you can call functions with a variety of argument types, and Python can do something different in each scenario.

Yet, to create a TensorFlow Graph, static dtypes and shape dimensions are required. tf.function bridges this gap by wrapping a Python function to create a Function object. Based on the given inputs, the Function selects the appropriate graph for them, retracing the Python function as necessary. Once you understand why and when tracing happens, it's much easier to use tf.function effectively!

You can call a Function with arguments of different types to see this polymorphic behavior in action.

 @tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()

 
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)


Note that if you repeatedly call a Function with the same argument types, TensorFlow will reuse a previously traced graph, as the generated graph would be identical.

 # This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
 
tf.Tensor(b'bb', shape=(), dtype=string)

(The following change is available in TensorFlow nightly, and will be available in TensorFlow 2.3.)

You can use pretty_printed_concrete_signatures() to see all of the available traces:

 print(double.pretty_printed_concrete_signatures())
 
double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

So far, you've seen that tf.function creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:

  • A tf.Graph is the raw, language-agnostic, portable representation of your computation.
  • A ConcreteFunction is an eagerly-executing wrapper around a tf.Graph.
  • A Function manages a cache of ConcreteFunctions and picks the right one for your inputs.
  • tf.function wraps a Python function, returning a Function object.

Obtaining concrete functions

Every time a function is traced, a new concrete function is created. You can directly obtain a concrete function by using get_concrete_function.

 print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))

 
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)

 # You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
 
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'cc', shape=(), dtype=string)

(The following change is available in TensorFlow nightly, and will be available in TensorFlow 2.3.)

Printing a ConcreteFunction displays a summary of its input arguments (with types) and its output type.

 print(double_strings)
 
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

You can also directly retrieve a concrete function's signature.

 print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
 
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

Using a concrete trace with incompatible types will throw an error.

 with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
 
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-15-e4e2860a4364>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_168 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_168]

You may notice that Python arguments are given special treatment in a concrete function's input signature. Prior to TensorFlow 2.3, Python arguments were simply removed from the concrete function's signature. Starting with TensorFlow 2.3, Python arguments remain in the signature, but are constrained to take the value set during tracing.

 @tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
 
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>

 assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
 
Caught expected exception 
  <class 'TypeError'>:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1669, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1714, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-17-d163f3d206cb>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3

Obtaining graphs

Each concrete function is a callable wrapper around a tf.Graph. Although retrieving the actual tf.Graph object is not something you'll normally need to do, you can obtain it easily from any concrete function.

 graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')

 
[] -> a
['a', 'a'] -> add
['add'] -> Identity

Debugging

In general, debugging code is easier in eager mode than inside tf.function. You should ensure that your code executes error-free in eager mode before decorating with tf.function. To assist in the debugging process, you can call tf.config.run_functions_eagerly(True) to globally disable and reenable tf.function (a minimal sketch of this toggle follows the tips below).

When tracking down issues that only appear within tf.function, here are some tips:

  • Plain old Python print calls only execute during tracing, helping you track down when your functions get (re)traced.
  • tf.print calls will execute every time, and can help you track down intermediate values during execution.
  • tf.debugging.enable_check_numerics is an easy way to track down where NaNs and Inf are created.
  • pdb can help you understand what's going on during tracing. (Caveat: pdb will drop you into AutoGraph-transformed source code.)
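
For example, here is a minimal sketch of the eager toggle mentioned above; the square helper is hypothetical:

 @tf.function
def square(x):
  return x * x  # Set a pdb breakpoint here to step through the Python code.

tf.config.run_functions_eagerly(True)   # Functions now execute as plain Python.
print(square(tf.constant(3)))           # Runs eagerly; print and pdb behave normally.
tf.config.run_functions_eagerly(False)  # Restore graph execution when done debugging.
 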

Tracing semantics

Cache key rules

A Function determines whether to reuse a traced concrete function by computing a cache key from an input's args and kwargs (the sketch after this list shows these rules in action).

  • The key generated for a tf.Tensor argument is its shape and dtype.
  • Starting in TensorFlow 2.3, the key generated for a tf.Variable argument is its id().
  • The key generated for a Python primitive is its value. The key generated for nested dicts, lists, tuples, namedtuples, and attrs is the flattened tuple. (As a result of this flattening, calling a concrete function with a different nesting structure than the one used during tracing will result in a TypeError.)
  • For all other Python types, the keys are based on the object id() so that methods are traced independently for each instance of a class.
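
A minimal sketch of the first and third rules, using a hypothetical helper whose trace-time print reveals when a new cache entry is created:

 @tf.function
def cache_demo(x):
  print("Tracing with", x)
  return x

# Python primitives are keyed by value: each new value triggers a trace.
cache_demo(1)               # Tracing with 1
cache_demo(2)               # Tracing with 2
cache_demo(1)               # Trace for 1 is reused; nothing is printed.

# Tensors are keyed by shape and dtype, not value: no retrace here.
cache_demo(tf.constant(1))  # Tracing with Tensor(...)
cache_demo(tf.constant(2))  # Same shape and dtype; trace is reused.
 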

Controlling retracing

Retracing helps ensure that TensorFlow generates correct graphs for each set of inputs. However, tracing is an expensive operation! If your Function retraces a new graph for every call, you'll find that your code executes more slowly than if you didn't use tf.function.

To control the tracing behavior, you can use the following techniques:

  • Specify input_signature in tf.function to limit tracing.

 @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# We specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))

 
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-19-20f544b8adbf>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-19-20f544b8adbf>", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))

  • Specify a [None] dimension in tf.TensorSpec to allow for flexibility in trace reuse.

    Since TensorFlow matches tensors based on their shape, using a None dimension as a wildcard will allow Functions to reuse traces for variably-sized input. Variably-sized input can occur if you have sequences of different length, or images of different sizes for each batch (see the Transformer and Deep Dream tutorials for example).

 @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))

 
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

  • Cast Python arguments to Tensors to reduce retracing.

    Often, Python arguments are used to control hyperparameters and graph constructions - for example, num_layers=10 or training=True or nonlinearity='relu'. So if the Python argument changes, it makes sense that you'd have to retrace the graph.

    However, it's possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so retracing is unnecessary.

 def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
 
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

If you need to force retracing, create a new Function. Separate Function objects are guaranteed not to share traces.

 def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
 
Tracing!
Executing
Tracing!
Executing

Python side effects

Python side effects like printing, appending to lists, and mutating globals only happen the first time you call a Function with a set of inputs. Afterwards, the traced tf.Graph is reexecuted, without executing the Python code.

The general rule of thumb is to only use Python side effects to debug your traces. Otherwise, TensorFlow ops like tf.Variable.assign, tf.print, and tf.summary are the best way to ensure your code will be traced and executed by the TensorFlow runtime with each call (see the sketch after the example below).

 @tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)

 
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2
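
For instance, here is a minimal sketch of replacing a Python-side counter with tf.Variable.assign_add, so the update runs on every call rather than only during tracing; the counter variable is illustrative:

 counter = tf.Variable(0)

@tf.function
def increment():
  counter.assign_add(1)                # A TensorFlow op: executes on every call.
  tf.print("counter is now", counter)

increment()  # counter is now 1
increment()  # counter is now 2
 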

Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, many unexpected things can happen inside a Function.

To give one example, advancing iterator state is a Python side effect and therefore only happens during tracing.

 external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)

 
Value of external_var: 0
Value of external_var: 0
Value of external_var: 0

Some iteration constructs are supported through AutoGraph. See the section on AutoGraph Transformations for an overview.

If you would like to execute Python code during each invocation of a Function, tf.py_function is an exit hatch. The drawback of tf.py_function is that it's not portable or particularly performant, nor does it work well in distributed (multi-GPU, TPU) setups. Also, since tf.py_function has to be wired into the graph, it casts all inputs/outputs to tensors.

tf.gathertf.stacktf.TensorArray这样的API可以帮助您在本机TensorFlow中实现常见的循环模式。

 external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
# The list append happens all three times!
assert len(external_list) == 3
# The list contains tf.constant(1), not 1, because py_function casts everything to tensors.
assert external_list[0].numpy() == 1

 
Python side effect
Python side effect
Python side effect
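
As a native alternative to the Python list above, here is a minimal sketch that accumulates into a tf.TensorArray instead; the collect_squares helper is illustrative:

 @tf.function
def collect_squares(n):
  ta = tf.TensorArray(tf.int32, size=n)  # Graph-friendly accumulator.
  for i in tf.range(n):                  # Converted to tf.while_loop by AutoGraph.
    ta = ta.write(i, i * i)
  return ta.stack()                      # Pack the entries into a single tensor.

print(collect_squares(tf.constant(5)))   # tf.Tensor([ 0  1  4  9 16], ...)
 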

Variables

You may encounter an error when creating a new tf.Variable in a function. This error guards against behavior divergence on repeated calls: in eager mode, a function creates a new variable with each call, but in a Function, a new variable may not be created due to trace reuse.

 @tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
 
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-26-73e410646579>", line 8, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-26-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.


You can create variables inside a Function as long as those variables are only created the first time the function is executed.

 class Count(tf.Module):
  def __init__(self):
    self.count = None
  
  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
 
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Another error you may encounter is a garbage-collected variable. Unlike normal Python functions, concrete functions only retain WeakRefs to the variables they close over, so you must retain a reference to any variables.

 external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var

traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

del external_var
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
 
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-28-304a18524b57>", line 14, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-28-304a18524b57>:4) ]]
     [[ReadVariableOp/_2]]
  (1) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-28-304a18524b57>:4) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference_f_514]

Function call stack:
f -> f


AutoGraph Transformations

AutoGraph is a library that is on by default in tf.function, and transforms a subset of Python eager code into graph-compatible TensorFlow ops. This includes control flow like if, for, and while.

tf.condtf.while_loop这样的TensorFlow操作继续起作用,但是使用Python编写时,控制流通常更易于编写和理解。

 # Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
 
[0.448926926 0.896036148 0.703306437 0.446930766 0.20440042]
[0.421016544 0.714362323 0.6064623 0.419372857 0.201600626]
[0.397786468 0.613405049 0.541632056 0.396401972 0.198913112]
[0.378053397 0.546519518 0.494222373 0.376866162 0.196330562]
[0.361015767 0.497907132 0.457561225 0.359982818 0.1938463]
[0.346108437 0.460469633 0.428094476 0.3451989 0.191454232]
[0.332919776 0.43046692 0.403727621 0.332110822 0.189148799]
[0.321141869 0.405711472 0.383133948 0.320416152 0.18692489]
[0.310539037 0.384825289 0.365426034 0.309883147 0.184777796]
[0.300927401 0.366890609 0.349984437 0.300330788 0.182703182]
[0.292161077 0.351268977 0.336361736 0.291615278 0.180697069]
[0.284122646 0.337500453 0.324225426 0.283620834 0.178755745]
[0.276716352 0.325244069 0.313322544 0.276252925 0.176875815]
[0.269863278 0.314240903 0.303456694 0.269433528 0.175054088]
[0.263497591 0.304290265 0.294472754 0.263097644 0.17328763]
[0.257564 0.295233846 0.2862463 0.257190555 0.171573699]
[0.25201565 0.286944896 0.278676242 0.25166589 0.169909731]
[0.246812463 0.279320478 0.271679461 0.246483982 0.168293342]
[0.24192 0.272276044 0.265186876 0.241610721 0.166722313]
[0.237308443 0.265741408 0.259140551 0.237016559 0.165194541]
[0.23295185 0.25965777 0.253491491 0.232675791 0.163708091]
[0.228827521 0.253975391 0.248197898 0.228565902 0.162261128]
[0.224915475 0.248651937 0.243223906 0.224667087 0.160851941]
[0.221198082 0.243651047 0.238538548 0.220961839 0.159478888]
[0.217659682 0.238941342 0.23411487 0.217434615 0.158140466]
[0.214286327 0.23449555 0.229929343 0.214071587 0.156835243]
[0.211065561 0.230289876 0.225961298 0.210860386 0.155561864]
[0.207986191 0.226303399 0.222192511 0.207789883 0.154319063]
[0.20503816 0.222517684 0.2186068 0.204850093 0.153105617]

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.20221236, 0.2189164 , 0.21518978, 0.20203198, 0.15192041],
      dtype=float32)>

If you're curious, you can inspect the code AutoGraph generates.

 print(tf.autograph.to_code(f.python_function))
 
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)


Conditionals

AutoGraph will convert some if <condition> statements into the equivalent tf.cond calls. This substitution is made if <condition> is a Tensor. Otherwise, the if statement is executed as a Python conditional.

A Python conditional executes during tracing, so exactly one branch of the conditional will be added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.

tf.cond traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; see AutoGraph tracing effects for more information.

 @tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
 
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

See the reference documentation for additional restrictions on AutoGraph-converted if statements.

Loops

AutoGraph will convert some for and while statements into the equivalent TensorFlow looping ops, like tf.while_loop. If not converted, the for or while loop is executed as a Python loop.

This substitution is made in the following situations:

  • for x in y: if y is a Tensor, convert to tf.while_loop. In the special case where y is a tf.data.Dataset, a combination of tf.data.Dataset ops are generated.
  • while <condition>: if <condition> is a Tensor, convert to tf.while_loop.

A Python loop executes during tracing, adding additional ops to the tf.Graph for every iteration of the loop.

A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time. The loop body only appears once in the generated tf.Graph, as the sketch below illustrates.
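
A minimal sketch contrasting the two behaviors; both functions return the same value, and the difference is in the traced graph:

 @tf.function
def python_loop():
  x = tf.constant(0)
  for _ in range(3):     # Python range: unrolled during tracing,
    x += 1               # so three separate add ops land in the graph.
  return x

@tf.function
def tensor_loop():
  x = tf.constant(0)
  for _ in tf.range(3):  # Tensor range: converted to tf.while_loop,
    x += 1               # so the loop body appears only once in the graph.
  return x

print(python_loop(), tensor_loop())  # Both print tf.Tensor(3, ...).
 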

See the reference documentation for additional restrictions on AutoGraph-converted for and while statements.

Looping over Python data

A common pitfall is to loop over Python/Numpy data within a tf.function. This loop will execute during the tracing process, adding a copy of your model to the tf.Graph for each iteration of the loop.

If you want to wrap the entire training loop in tf.function, the safest way to do this is to wrap your data as a tf.data.Dataset so that AutoGraph will dynamically unroll the training loop.

 def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
 
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph

When wrapping Python/Numpy data in a Dataset, be mindful of tf.data.Dataset.from_generator versus tf.data.Dataset.from_tensors. The former will keep the data in Python and fetch it via tf.py_function, which can have performance implications, while the latter will bundle a copy of the data as one large tf.constant() node in the graph, which can have memory implications.
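
A minimal sketch of the two options side by side; the toy data is illustrative:

 data = [[1, 2], [3, 4]]

# Keeps the data in Python; elements are fetched via tf.py_function at runtime.
ds_gen = tf.data.Dataset.from_generator(lambda: data, tf.int32)

# Embeds a copy of the data in the graph as one large tf.constant() node.
ds_const = tf.data.Dataset.from_tensors(tf.constant(data))

for element in ds_gen:
  print(element)  # tf.Tensor([1 2] ...), then tf.Tensor([3 4] ...)
for element in ds_const:
  print(element)  # One element holding the whole array.
 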

Reading data from files via TFRecordDataset/CsvDataset/etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of data, without having to involve Python. To learn more, see the tf.data guide.
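
As a rough sketch of that file-based pattern, assuming a hypothetical data.tfrecord file and omitting record parsing (which depends on how the records were written):

 # "data.tfrecord" is a hypothetical path; parsing is dataset-specific.
dataset = (tf.data.TFRecordDataset("data.tfrecord")
           .batch(32)
           .prefetch(tf.data.experimental.AUTOTUNE))  # Overlap I/O with compute.

@tf.function
def train(dataset):
  for batch in dataset:  # Generates tf.data ops, not an unrolled Python loop.
    pass                 # A real training step would go here.
 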

Accumulating values in a loop

A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, as these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use tf.TensorArray to accumulate results from a dynamically unrolled loop.

 batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
 
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.2486304 , 0.0612042 , 0.69624186, 0.28587592],
        [1.2193475 , 0.2389338 , 1.5216837 , 0.38649392],
        [1.7640524 , 1.1970762 , 2.3265643 , 0.81419575]],

       [[0.36599267, 0.41830885, 0.73540664, 0.63987565],
        [0.48354673, 1.1808103 , 1.7210082 , 0.8333106 ],
        [0.7138835 , 1.2030114 , 1.8544207 , 1.1647347 ]]], dtype=float32)>

Further reading

To learn about how to export and load a Function, see the SavedModel guide. To learn more about graph optimizations that are performed after tracing, see the Grappler guide. To learn how to optimize your data pipeline and profile your model, see the Profiler guide.