Trang này được dịch bởi Cloud Translation API.
Switch to English

Hiệu suất tốt hơn với tf. Chức năng

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ tay

Trong TensorFlow 2, thực thi háo hức được bật theo mặc định. Giao diện người dùng trực quan và linh hoạt (chạy các hoạt động một lần dễ dàng hơn và nhanh hơn nhiều), nhưng điều này có thể phải trả giá bằng hiệu suất và khả năng triển khai.

Bạn có thể sử dụng tf.function để tạo đồ thị từ các chương trình của mình. Nó là một công cụ chuyển đổi tạo đồ thị luồng dữ liệu độc lập với Python từ mã Python của bạn. Điều này sẽ giúp bạn tạo ra các mô hình hiệu quả và di động, và bắt buộc phải sử dụng SavedModel .

Hướng dẫn này sẽ giúp bạn hình dung cách hoạt động của tf.function dưới mui xe để bạn có thể sử dụng nó một cách hiệu quả.

Các điểm rút ra và khuyến nghị chính là:

  • Gỡ lỗi ở chế độ háo hức, sau đó trang trí bằng @tf.function . @tf.function .
  • Đừng dựa vào các tác dụng phụ của Python như đột biến đối tượng hoặc nối danh sách.
  • tf.function hoạt động tốt nhất với hoạt động TensorFlow; Các cuộc gọi NumPy và Python được chuyển đổi thành hằng số.

Thiết lập

import tensorflow as tf

Xác định một hàm trợ giúp để chứng minh các loại lỗi bạn có thể gặp phải:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

Khái niệm cơ bản

Sử dụng

Một Function bạn xác định cũng giống như một hoạt động TensorFlow cốt lõi: Bạn có thể thực thi nó một cách hăng hái; bạn có thể tính toán độ dốc; và như thế.

@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Bạn có thể sử dụng các Function bên trong các Function khác.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function s có thể nhanh hơn mã háo hức, đặc biệt là đối với đồ thị có nhiều hoạt động nhỏ. Nhưng đối với các đồ thị có một vài hoạt động đắt tiền (như chập), bạn có thể không thấy tốc độ tăng nhiều.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

Eager conv: 0.002407395000091128
Function conv: 0.004000883000117028
Note how there's not much difference in performance for convolutions

Truy tìm

Kiểu gõ động của Python có nghĩa là bạn có thể gọi các hàm với nhiều kiểu đối số khác nhau và Python có thể làm điều gì đó khác nhau trong mỗi trường hợp.

Tuy nhiên, để tạo Đồ thị TensorFlow, cần có các dtypes tĩnh và kích thước hình dạng. tf.function thu hẹp khoảng cách này bằng cách gói một hàm Python để tạo một đối tượng Function . Dựa trên các đầu vào đã cho, Function chọn đồ thị thích hợp cho các đầu vào đã cho, thực hiện lại hàm Python nếu cần. Một khi bạn hiểu tại sao và khi nào việc truy tìm xảy ra, việc sử dụng tf.function một cách hiệu quả sẽ dễ dàng hơn nhiều!

Bạn có thể gọi một Function với các đối số thuộc các kiểu khác nhau để xem hành vi đa hình này đang hoạt động.

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()

Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)


Lưu ý rằng nếu bạn gọi nhiều lần một Function có cùng loại đối số, TensorFlow sẽ sử dụng lại một biểu đồ đã theo dõi trước đó, vì biểu đồ được tạo sẽ giống hệt nhau.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

Bạn có thể sử dụng pretty_printed_concrete_signatures() để xem tất cả các dấu vết có sẵn:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Cho đến nay, bạn đã thấy rằng tf.function tạo ra một lớp điều phối động, được lưu trong bộ nhớ cache trên logic theo dõi đồ thị của TensorFlow. Để cụ thể hơn về thuật ngữ:

  • Một tf.Graph là bản trình bày thô, không thể ngôn ngữ, di động cho tính toán của bạn.
  • ConcreteFunction là một trình bao bọc nhanh chóng thực thi xung quanh một tf.Graph .
  • Function quản lý bộ nhớ cache của các ConcreteFunction và chọn một bộ nhớ đệm phù hợp cho đầu vào của bạn.
  • tf.function kết thúc một hàm Python, trả về một đối tượng Function .

Có được các chức năng cụ thể

Mỗi khi một chức năng được truy tìm, một chức năng cụ thể mới được tạo ra. Bạn có thể lấy trực tiếp một hàm cụ thể bằng cách sử dụng get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))

Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)

# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'cc', shape=(), dtype=string)

(Thay đổi sau có trong TensorFlow hàng đêm và sẽ có trong TensorFlow 2.3.)

In một ConcreteFunction hiển thị một bản tóm tắt các đối số đầu vào của nó (với các kiểu) và kiểu đầu ra của nó.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Bạn cũng có thể truy xuất trực tiếp chữ ký của một hàm cụ thể.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

Sử dụng dấu vết cụ thể với các loại không tương thích sẽ gây ra lỗi

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-e4e2860a4364>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_168 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_168]

Bạn có thể nhận thấy rằng các đối số Python được xử lý đặc biệt trong chữ ký đầu vào của một hàm cụ thể. Trước TensorFlow 2.3, các đối số trong Python chỉ đơn giản là bị xóa khỏi chữ ký của hàm cụ thể. Bắt đầu với TensorFlow 2.3, các đối số Python vẫn còn trong chữ ký, nhưng bị hạn chế lấy giá trị được đặt trong quá trình truy tìm.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>

assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1669, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1714, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d163f3d206cb>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3

Lấy đồ thị

Mỗi hàm cụ thể là một trình bao bọc có thể gọi xung quanh một tf.Graph . Mặc dù truy xuất đối tượng tf.Graph thực tế không phải là điều bạn thường cần làm, nhưng bạn có thể lấy nó dễ dàng từ bất kỳ hàm cụ thể nào.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')

[] -> a
['a', 'a'] -> add
['add'] -> Identity

Gỡ lỗi

Nói chung, mã gỡ lỗi trong chế độ háo hức dễ dàng hơn bên trong tf.function . tf.function . Bạn nên đảm bảo rằng mã của bạn thực thi không có lỗi ở chế độ háo hức trước khi trang trí bằng tf.function . tf.function . Để hỗ trợ quá trình gỡ lỗi, bạn có thể gọi tf.config.run_functions_eagerly(True) để tắt và kích hoạt lại tf.function . tf.function trên toàn cầu.

Khi theo dõi các vấn đề chỉ xuất hiện trong tf.function , đây là một số mẹo:

  • Các print gọi print Python cũ đơn thuần chỉ thực thi trong quá trình theo dõi, giúp bạn theo dõi khi hàm của bạn được truy tìm (lại).
  • tf.print gọi tf.print sẽ thực hiện mọi lúc và có thể giúp bạn theo dõi các giá trị trung gian trong quá trình thực thi.
  • tf.debugging.enable_check_numerics là một cách dễ dàng để theo dõi nơi tạo NaN và Inf.
  • pdb có thể giúp bạn hiểu những gì đang xảy ra trong quá trình theo dõi. (Lưu ý: PDB sẽ đưa bạn vào mã nguồn được chuyển đổi AutoGraph.)

Truy tìm ngữ nghĩa

Quy tắc khóa bộ nhớ cache

Một Function xác định xem có sử dụng lại một hàm cụ thể theo dõi hay không bằng cách tính toán khóa bộ nhớ cache từ các args và kwargs của đầu vào.

  • Khóa được tạo cho một đối số tf.Tensor là hình dạng và kiểu của nó.
  • Bắt đầu từ TensorFlow 2.3, khóa được tạo cho một đối số tf.Variableid() của nó.
  • Khóa được tạo cho một nguyên thủy Python là giá trị của nó. Khóa được tạo cho các dict lồng nhau, list s, tuple s, namedtuple s và attr s là bộ tuple được làm phẳng. (Kết quả của việc làm phẳng này, việc gọi một hàm bê tông có cấu trúc lồng khác với cấu trúc được sử dụng trong quá trình truy tìm sẽ dẫn đến Lỗi loại).
  • Đối với tất cả các kiểu Python khác, các khóa dựa trên đối tượng id() để các phương thức được truy tìm độc lập cho mỗi phiên bản của một lớp.

Kiểm soát việc kiểm tra lại

Retracing giúp đảm bảo rằng TensorFlow tạo ra các đồ thị chính xác cho từng tập hợp đầu vào. Tuy nhiên, truy tìm là một hoạt động tốn kém! Nếu Function của bạn truy xuất lại một biểu đồ mới cho mỗi lần gọi, bạn sẽ thấy rằng mã của bạn thực thi chậm hơn so với khi bạn không sử dụng tf.function . tf.function .

Để kiểm soát hành vi theo dõi, bạn có thể sử dụng các kỹ thuật sau:

  • Chỉ định input_signature trong tf.function để hạn chế việc theo dõi.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# We specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-20f544b8adbf>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-20f544b8adbf>", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))

  • Chỉ định thứ nguyên [Không có] trong tf.TensorSpec để cho phép sử dụng lại dấu vết một cách linh hoạt.

    Vì TensorFlow so khớp các tensor dựa trên hình dạng của chúng, việc sử dụng thứ nguyên None làm ký tự đại diện sẽ cho phép các Function sử dụng lại dấu vết cho đầu vào có kích thước thay đổi. Đầu vào có kích thước khác nhau có thể xảy ra nếu bạn có các chuỗi có độ dài khác nhau hoặc hình ảnh có kích thước khác nhau cho mỗi lô (Ví dụ: Xem hướng dẫn về TransformerDeep Dream ).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

  • Truyền các đối số Python đến Tensors để giảm việc rút lại.

    Thông thường, các đối số Python được sử dụng để kiểm soát các siêu tham số và cấu trúc đồ thị - ví dụ: num_layers=10 hoặc training=True hoặc nonlinearity='relu' . Vì vậy, nếu đối số Python thay đổi, có nghĩa là bạn phải truy xuất lại biểu đồ.

    Tuy nhiên, có thể một đối số Python không được sử dụng để kiểm soát việc xây dựng đồ thị. Trong những trường hợp này, một sự thay đổi trong giá trị Python có thể kích hoạt việc kiểm tra lại không cần thiết. Lấy ví dụ, vòng lặp đào tạo này, AutoGraph sẽ tự động hủy cuộn. Mặc dù có nhiều dấu vết, nhưng biểu đồ được tạo ra thực sự giống hệt nhau, vì vậy việc kiểm tra lại là không cần thiết.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Nếu bạn cần buộc phải rút lại, hãy tạo một Function mới. Các đối tượng Function riêng biệt được đảm bảo không chia sẻ dấu vết.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

Tác dụng phụ của Python

Các tác dụng phụ trong Python như in, thêm vào danh sách và thay đổi hình cầu chỉ xảy ra lần đầu tiên bạn gọi một Function với một tập hợp các đầu vào. Sau đó, tf.Graph dõi được thực thi lại mà không thực thi mã Python.

Nguyên tắc chung là chỉ sử dụng các tác dụng phụ của Python để gỡ lỗi dấu vết của bạn. Nếu không, các hoạt động TensorFlow như tf.Variable.assign , tf.printtf.summary là cách tốt nhất để đảm bảo mã của bạn sẽ được truy tìm và thực thi bởi thời gian chạy TensorFlow với mỗi lần gọi.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)

Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Nhiều tính năng của Python, chẳng hạn như trình tạo và trình vòng lặp, dựa vào thời gian chạy Python để theo dõi trạng thái. Nói chung, trong khi các cấu trúc này hoạt động như mong đợi ở chế độ háo hức, nhiều điều không mong muốn có thể xảy ra bên trong một Function .

Để đưa ra một ví dụ, việc tăng trạng thái trình lặp là một tác dụng phụ của Python và do đó chỉ xảy ra trong quá trình truy tìm.

external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)

Value of external_var: 0
Value of external_var: 0
Value of external_var: 0

Một số cấu trúc lặp được hỗ trợ thông qua AutoGraph. Xem phần về Biến đổi tự động để biết tổng quan.

Nếu bạn muốn thực thi mã Python trong mỗi lần gọi một Function , tf.py_function Function là một lối thoát. Hạn chế của tf.py_function là nó không di động hoặc đặc biệt hiệu quả, cũng như không hoạt động tốt trong các thiết lập phân tán (đa GPU, TPU). Ngoài ra, vì tf.py_function phải được kết nối với biểu đồ, nó chuyển tất cả các đầu vào / đầu ra thành tensor.

Các API như tf.gather , tf.stacktf.TensorArray có thể giúp bạn triển khai các mẫu lặp phổ biến trong TensorFlow gốc.

external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
# The list append happens all three times!
assert len(external_list) == 3
# The list contains tf.constant(1), not 1, because py_function casts everything to tensors.
assert external_list[0].numpy() == 1

Python side effect
Python side effect
Python side effect

Biến

Bạn có thể gặp lỗi khi tạo tf.Variable mới trong một hàm. Lỗi này bảo vệ chống lại sự phân kỳ hành vi đối với các cuộc gọi lặp lại: Trong chế độ háo hức, một hàm tạo một biến mới với mỗi lệnh gọi, nhưng trong một Function , một biến mới có thể không được tạo do sử dụng lại dấu vết.

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-73e410646579>", line 8, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-1-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.


Bạn có thể tạo các biến bên trong một Function miễn là các biến đó chỉ được tạo vào lần đầu tiên hàm được thực thi.

class Count(tf.Module):
  def __init__(self):
    self.count = None
  
  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Một lỗi khác mà bạn có thể gặp phải là biến thu thập rác. Không giống như các hàm Python thông thường, các hàm cụ thể chỉ giữ lại WeakRefs cho các biến mà chúng đóng lại, vì vậy bạn phải giữ lại một tham chiếu đến bất kỳ biến nào.

external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var

traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

del external_var
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-304a18524b57>", line 14, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-304a18524b57>:4) ]]
  (1) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-304a18524b57>:4) ]]
     [[ReadVariableOp/_2]]
0 successful operations.
0 derived errors ignored. [Op:__inference_f_514]

Function call stack:
f -> f


Chuyển đổi AutoGraph

AutoGraph là một thư viện được bật theo mặc định trong tf.function và chuyển một tập hợp con của mã háo hức Python thành các hoạt động TensorFlow tương thích với đồ thị. Điều này bao gồm luồng điều khiển như if , for , while .

Các hoạt động TensorFlow như tf.condtf.while_loop tiếp tục hoạt động, nhưng luồng điều khiển thường dễ viết và dễ hiểu hơn khi được viết bằng Python.

# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.224704742 0.895507693 0.0398198366 0.98112452 0.278468847]
[0.220997646 0.71410346 0.0397988036 0.753552318 0.271487355]
[0.217468739 0.61324358 0.0397778042 0.637263417 0.265008271]
[0.214104146 0.546406269 0.0397568382 0.563033342 0.258973926]
[0.210891485 0.497821957 0.0397359058 0.510224521 0.253335565]
[0.207819641 0.460402519 0.0397150069 0.470120102 0.248051569]
[0.204878598 0.430412233 0.0396941416 0.438296348 0.243086234]
[0.202059314 0.405665785 0.039673306 0.412231296 0.2384087]
[0.199353606 0.384786367 0.039652504 0.39036563 0.23399213]
[0.196754038 0.366856933 0.0396317355 0.371675402 0.229813099]
[0.194253832 0.351239443 0.039611 0.355456293 0.225851]
[0.191846803 0.337474287 0.0395902954 0.341205537 0.222087651]
[0.189527303 0.325220674 0.0395696238 0.3285532 0.218506947]
[0.187290132 0.314219803 0.0395489857 0.317220151 0.215094551]
[0.185130537 0.304271102 0.0395283774 0.30699119 0.211837649]
[0.183044136 0.295216352 0.0395078026 0.297697395 0.208724767]
[0.181026861 0.286928833 0.0394872613 0.289204 0.205745578]

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.17907499, 0.27930567, 0.03946675, 0.281402  , 0.20289075],
      dtype=float32)>

Nếu bạn tò mò, bạn có thể kiểm tra mã tạo ra chữ ký.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)


Điều kiện

AutoGraph sẽ chuyển đổi một số câu lệnh if <condition> thành các tf.cond gọi tf.cond tương đương. Sự thay thế này được thực hiện nếu <condition> là một Tensor. Nếu không, if tuyên bố được thực hiện như một điều kiện Python.

Một điều kiện Python thực thi trong quá trình theo dõi, vì vậy chính xác một nhánh của điều kiện sẽ được thêm vào biểu đồ. Nếu không có AutoGraph, biểu đồ theo dõi này sẽ không thể sử dụng nhánh thay thế nếu có luồng điều khiển phụ thuộc vào dữ liệu.

tf.cond theo dõi và thêm cả hai nhánh của điều kiện vào biểu đồ, chọn động một nhánh tại thời điểm thực thi. Truy tìm có thể có tác dụng phụ ngoài ý muốn; xem các hiệu ứng theo dõi AutoGraph để biết thêm.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Xem tài liệu tham khảo để biết các hạn chế bổ sung đối với các câu lệnh if được chuyển đổi tự động.

Vòng lặp

AutoGraph sẽ chuyển đổi một số câu lệnh forwhile thành các hoạt động lặp TensorFlow tương đương, như tf.while_loop . tf.while_loop . Nếu không được chuyển đổi, for hay while vòng lặp được thực hiện như một vòng lặp Python.

Sự thay thế này được thực hiện trong các trường hợp sau:

  • for x in y : nếu y là Tensor, hãy chuyển đổi thành tf.while_loop . tf.while_loop . Trong trường hợp đặc biệt khi ytf.data.Dataset , kết hợp các hoạt động tf.data.Dataset được tạo ra.
  • while <condition> : nếu <condition> là Tensor, hãy chuyển đổi thành tf.while_loop . tf.while_loop .

Một vòng lặp Python thực thi trong quá trình truy tìm, thêm các hoạt động bổ sung vào tf.Graph cho mỗi lần lặp lại của vòng lặp.

Một vòng lặp TensorFlow theo dõi phần thân của vòng lặp và tự động chọn số lần lặp để chạy tại thời điểm thực thi. Phần thân của vòng lặp chỉ xuất hiện một lần trong tf.Graph được tạo.

Xem tài liệu tham khảo để biết thêm các hạn chế for câu lệnh forwhile chuyển đổi tự động.

Vòng qua dữ liệu Python

Một lỗi phổ biến là lặp qua dữ liệu Python / Numpy trong một tf.function . Vòng lặp này sẽ thực thi trong quá trình truy tìm, thêm một bản sao mô hình của bạn vào tf.Graph cho mỗi lần lặp lại của vòng lặp.

Nếu bạn muốn gói toàn bộ vòng huấn luyện trong tf.function , cách an toàn nhất để làm điều này là bọc dữ liệu của bạn dưới dạng tf.data.Dataset để AutoGraph sẽ tự động bỏ cuộn vòng huấn luyện.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph

Khi gói dữ liệu Python / Numpy trong Dataset, hãy lưu ý đến tf.data.Dataset.from_generator so với tf.data.Dataset.from_tensors . Cái trước sẽ giữ dữ liệu bằng Python và tìm nạp nó qua tf.py_function năng có thể có ý nghĩa về hiệu suất, trong khi cái sau sẽ gói một bản sao dữ liệu dưới dạng một nút tf.constant() trong biểu đồ, có thể có ý nghĩa về bộ nhớ.

Đọc dữ liệu từ các tập tin qua TFRecordDataset / CsvDataset / etc. là cách hiệu quả nhất để sử dụng dữ liệu, vì bản thân TensorFlow có thể quản lý việc tải và tìm nạp trước dữ liệu không đồng bộ mà không cần phải liên quan đến Python. Để tìm hiểu thêm, hãy xem hướng dẫn tf.data .

Tích lũy các giá trị trong một vòng lặp

Một mô hình phổ biến là tích lũy các giá trị trung gian từ một vòng lặp. Thông thường, điều này được thực hiện bằng cách thêm vào danh sách Python hoặc thêm các mục nhập vào từ điển Python. Tuy nhiên, vì đây là các tác dụng phụ của Python, chúng sẽ không hoạt động như mong đợi trong một vòng lặp được mở động. Sử dụng tf.TensorArray để tích lũy kết quả từ một vòng lặp không được cuộn động.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.9854791 , 0.5162524 , 0.14062047, 0.04950547],
        [1.8820469 , 0.67421603, 0.40786874, 0.7679055 ],
        [2.8815444 , 1.1567757 , 1.0627073 , 0.8880433 ]],

       [[0.94119024, 0.19776726, 0.24890792, 0.4663092 ],
        [1.4591933 , 1.123581  , 0.35438073, 1.4392309 ],
        [2.0026946 , 1.9165647 , 0.37988353, 1.8128917 ]]], dtype=float32)>

đọc thêm

Để tìm hiểu về cách xuất và tải một Function , hãy xem hướng dẫn SavedModel . Để tìm hiểu thêm về tối ưu hóa biểu đồ được thực hiện sau khi theo dõi, hãy xem hướng dẫn Grappler . Để tìm hiểu cách tối ưu hóa đường ống dữ liệu và lập hồ sơ cho mô hình của bạn, hãy xem hướng dẫn về Hồ sơ .