Tham dự Hội nghị chuyên đề Women in ML vào ngày 7 tháng 12 Đăng ký ngay

Hiệu suất tốt hơn với tf. Chức năng

Sử dụng bộ sưu tập để sắp xếp ngăn nắp các trang Lưu và phân loại nội dung dựa trên lựa chọn ưu tiên của bạn.

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHubTải xuống sổ ghi chép

Trong TensorFlow 2, thực thi háo hức được bật theo mặc định. Giao diện người dùng trực quan và linh hoạt (chạy các hoạt động một lần dễ dàng hơn và nhanh hơn nhiều), nhưng điều này có thể phải trả giá bằng hiệu suất và khả năng triển khai.

Bạn có thể sử dụng tf.function để tạo đồ thị từ các chương trình của mình. Nó là một công cụ chuyển đổi tạo đồ thị luồng dữ liệu độc lập với Python từ mã Python của bạn. Điều này sẽ giúp bạn tạo các mô hình hiệu quả và di động, đồng thời bắt buộc phải sử dụng SavedModel .

Hướng dẫn này sẽ giúp bạn hình dung cách hoạt động của tf.function , vì vậy bạn có thể sử dụng nó một cách hiệu quả.

Các điểm rút ra và khuyến nghị chính là:

  • Gỡ lỗi ở chế độ háo hức, sau đó trang trí với @tf.function . function.
  • Đừng dựa vào các tác dụng phụ của Python như đột biến đối tượng hoặc nối thêm danh sách.
  • tf.function hoạt động tốt nhất với hoạt động TensorFlow; Các cuộc gọi NumPy và Python được chuyển đổi thành hằng số.

Thành lập

import tensorflow as tf

Xác định một chức năng trợ giúp để chứng minh các loại lỗi bạn có thể gặp phải:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

Khái niệm cơ bản

Cách sử dụng

Một Function bạn xác định (ví dụ bằng cách áp dụng trình trang trí @tf.function năng) cũng giống như một hoạt động TensorFlow cốt lõi: Bạn có thể thực thi nó một cách hăng hái; bạn có thể tính toán độ dốc; và như thế.

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Bạn có thể sử dụng các Function bên trong các Function khác.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function s có thể nhanh hơn mã háo hức, đặc biệt đối với đồ thị có nhiều hoạt động nhỏ. Nhưng đối với các đồ thị có một vài hoạt động đắt tiền (như chập), bạn có thể không thấy tốc độ tăng lên nhiều.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735
Function conv: 0.005791576000774512
Note how there's not much difference in performance for convolutions

Truy tìm

Phần này trình bày cách Function hoạt động ngầm, bao gồm các chi tiết triển khai có thể thay đổi trong tương lai . Tuy nhiên, một khi bạn hiểu tại sao và khi nào thì việc truy tìm xảy ra, việc sử dụng tf.function một cách hiệu quả sẽ dễ dàng hơn nhiều!

"Truy tìm" là gì?

Một Function chạy chương trình của bạn trong một Đồ thị TensorFlow . Tuy nhiên, một tf.Graph không thể đại diện cho tất cả những thứ bạn viết trong một chương trình TensorFlow háo hức. Ví dụ: Python hỗ trợ tính đa hình, nhưng tf.Graph yêu cầu đầu vào của nó phải có kiểu dữ liệu và thứ nguyên được chỉ định. Hoặc bạn có thể thực hiện các tác vụ phụ như đọc các đối số dòng lệnh, nêu lỗi hoặc làm việc với một đối tượng Python phức tạp hơn; không có thứ nào trong số này có thể chạy trong tf.Graph .

Function thu hẹp khoảng cách này bằng cách tách mã của bạn thành hai giai đoạn:

1) Trong giai đoạn đầu tiên, được gọi là " truy tìm ", Function tạo một tf.Graph mới. Mã Python chạy bình thường, nhưng tất cả các hoạt động TensorFlow (như thêm hai Tensor) đều bị hoãn lại : chúng bị tf.Graph nắm bắt và không chạy.

2) Trong giai đoạn thứ hai, một tf.Graph chứa mọi thứ đã bị trì hoãn trong giai đoạn đầu tiên được chạy. Giai đoạn này nhanh hơn nhiều so với giai đoạn truy tìm.

Tùy thuộc vào các đầu vào của nó, Function sẽ không phải lúc nào cũng chạy ở giai đoạn đầu tiên khi nó được gọi. Xem "Quy tắc theo dõi" bên dưới để hiểu rõ hơn về cách nó đưa ra quyết định đó. Bỏ qua giai đoạn đầu tiên và chỉ thực hiện giai đoạn thứ hai là những gì mang lại cho bạn hiệu suất cao của TensorFlow.

Khi Function quyết định theo dõi, giai đoạn theo dõi ngay lập tức được theo sau bởi giai đoạn thứ hai, vì vậy việc gọi Function vừa tạo và chạy tf.Graph . Sau đó, bạn sẽ thấy cách bạn có thể chạy chỉ giai đoạn theo dõi với get_concrete_function .

Khi bạn chuyển các đối số của các kiểu khác nhau vào một Function , cả hai giai đoạn đều được chạy:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

Lưu ý rằng nếu bạn gọi nhiều lần một Function có cùng loại đối số, TensorFlow sẽ bỏ qua giai đoạn theo dõi và sử dụng lại một biểu đồ đã theo dõi trước đó, vì biểu đồ được tạo sẽ giống hệt nhau.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

Bạn có thể sử dụng pretty_printed_concrete_signatures() để xem tất cả các dấu vết có sẵn:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Cho đến nay, bạn đã thấy rằng tf.function tạo ra một lớp điều phối động, được lưu trong bộ nhớ cache trên logic theo dõi đồ thị của TensorFlow. Để cụ thể hơn về thuật ngữ:

  • Một tf.Graph là bản đại diện thô, không có ngôn ngữ, di động của một phép tính TensorFlow.
  • Một ConcreteFunction kết thúc một tf.Graph .
  • Function quản lý một bộ nhớ cache của các ConcreteFunction và chọn một bộ nhớ đệm phù hợp cho đầu vào của bạn.
  • tf.function kết thúc một hàm Python, trả về một đối tượng Function .
  • Tracing tạo ra một tf.Graph và bao bọc nó trong một ConcreteFunction , còn được gọi là một dấu vết.

Quy tắc truy tìm

Một Function xác định xem có sử dụng lại ConcreteFunction được theo dõi hay không bằng cách tính toán khóa bộ nhớ cache từ các args và kwargs của đầu vào. Khóa bộ nhớ cache là khóa xác định ConcreteFunction dựa trên các args và kwargs đầu vào của lệnh gọi Function , theo các quy tắc sau (có thể thay đổi):

  • Chìa khóa được tạo ra cho tf.Tensor là hình dạng và kiểu của nó.
  • Khóa được tạo cho tf.Variable là một id biến duy nhất.
  • Khóa được tạo cho một nguyên thủy Python (như int , float , str ) là giá trị của nó.
  • Khóa được tạo cho các dict lồng nhau, list s, tuple s, namedtuple s và attr s là bộ khóa lá được làm phẳng (xem nest.flatten ). (Kết quả của việc làm phẳng này, việc gọi một hàm bê tông có cấu trúc lồng khác với cấu trúc được sử dụng trong quá trình truy tìm sẽ dẫn đến Lỗi loại).
  • Đối với tất cả các kiểu Python khác, khóa là duy nhất cho đối tượng. Bằng cách này, một hàm hoặc phương thức được truy tìm độc lập cho mỗi trường hợp mà nó được gọi.

Kiểm soát việc kiểm tra lại

Retracing, đó là khi Function của bạn tạo nhiều hơn một dấu vết, giúp đảm bảo rằng TensorFlow tạo ra các đồ thị chính xác cho từng nhóm đầu vào. Tuy nhiên, truy tìm là một hoạt động tốn kém! Nếu Function của bạn truy xuất lại một biểu đồ mới cho mỗi lần gọi, bạn sẽ thấy rằng mã của bạn thực thi chậm hơn so với khi bạn không sử dụng tf.function . function.

Để kiểm soát hành vi theo dõi, bạn có thể sử dụng các kỹ thuật sau:

  • Chỉ định input_signature trong tf.function để hạn chế việc theo dõi.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
  • Chỉ định thứ nguyên [Không có] trong tf.TensorSpec để cho phép sử dụng lại dấu vết một cách linh hoạt.

    Vì TensorFlow so khớp các tensor dựa trên hình dạng của chúng, việc sử dụng thứ nguyên None làm ký tự đại diện sẽ cho phép các Function sử dụng lại dấu vết cho đầu vào có kích thước thay đổi. Đầu vào có kích thước khác nhau có thể xảy ra nếu bạn có các chuỗi có độ dài khác nhau hoặc hình ảnh có kích thước khác nhau cho mỗi lô (Ví dụ: Xem hướng dẫn về TransformerDeep Dream ).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • Truyền các đối số Python đến Tensors để giảm việc rút lại.

    Thông thường, các đối số trong Python được sử dụng để kiểm soát các siêu tham số và cấu trúc đồ thị - ví dụ: num_layers=10 hoặc training=True hoặc nonlinearity='relu' . Vì vậy, nếu đối số Python thay đổi, có nghĩa là bạn phải truy xuất lại biểu đồ.

    Tuy nhiên, có thể một đối số Python không được sử dụng để kiểm soát việc xây dựng đồ thị. Trong những trường hợp này, một sự thay đổi trong giá trị Python có thể kích hoạt việc kiểm tra lại không cần thiết. Lấy ví dụ, vòng lặp đào tạo này, AutoGraph sẽ tự động hủy cuộn. Mặc dù có nhiều dấu vết, nhưng biểu đồ được tạo thực sự giống hệt nhau, vì vậy việc kiểm tra lại là không cần thiết.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Nếu bạn cần buộc phải rút lại, hãy tạo một Function mới. Các đối tượng Function riêng biệt được đảm bảo không chia sẻ dấu vết.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

Có được các chức năng cụ thể

Mỗi khi một chức năng được truy tìm, một chức năng cụ thể mới được tạo ra. Bạn có thể lấy trực tiếp một hàm cụ thể bằng cách sử dụng get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

Việc in một ConcreteFunction hiển thị một bản tóm tắt các đối số đầu vào của nó (với các kiểu) và kiểu đầu ra của nó.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Bạn cũng có thể truy xuất trực tiếp chữ ký của một hàm cụ thể.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

Sử dụng dấu vết cụ thể với các loại không tương thích sẽ gây ra lỗi

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

Bạn có thể nhận thấy rằng các đối số Python được xử lý đặc biệt trong chữ ký đầu vào của một hàm cụ thể. Trước TensorFlow 2.3, các đối số trong Python chỉ đơn giản là bị xóa khỏi chữ ký của hàm cụ thể. Bắt đầu với TensorFlow 2.3, các đối số Python vẫn còn trong chữ ký, nhưng bị hạn chế lấy giá trị được đặt trong quá trình truy tìm.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.

Lấy đồ thị

Mỗi hàm cụ thể là một trình bao bọc có thể gọi xung quanh một tf.Graph . Mặc dù truy xuất đối tượng tf.Graph thực tế không phải là điều bạn thường cần làm, nhưng bạn có thể lấy nó dễ dàng từ bất kỳ hàm cụ thể nào.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

Gỡ lỗi

Nói chung, mã gỡ lỗi trong chế độ háo hức dễ dàng hơn bên trong tf.function . Chức năng. Bạn nên đảm bảo rằng mã của bạn thực thi không có lỗi ở chế độ háo hức trước khi trang trí bằng tf.function . function. Để hỗ trợ quá trình gỡ lỗi, bạn có thể gọi tf.config.run_functions_eagerly(True) để tắt và kích tf.function . Chức năng trên toàn cầu.

Khi theo dõi các vấn đề chỉ xuất hiện trong tf.function . function, đây là một số mẹo:

  • Các lệnh gọi print Python cũ đơn thuần chỉ thực thi trong quá trình truy tìm, giúp bạn theo dõi khi hàm của bạn được truy tìm (lại).
  • Các lệnh gọi tf.print sẽ thực hiện mọi lúc và có thể giúp bạn theo dõi các giá trị trung gian trong quá trình thực thi.
  • tf.debugging.enable_check_numerics là một cách dễ dàng để theo dõi nơi tạo NaN và Inf.
  • pdb ( trình gỡ lỗi Python ) có thể giúp bạn hiểu những gì đang xảy ra trong quá trình truy tìm. (Lưu ý: pdb sẽ đưa bạn vào mã nguồn được chuyển đổi AutoGraph.)

Các phép biến đổi AutoGraph

AutoGraph là một thư viện được bật theo mặc định trong tf.function . function và chuyển đổi một tập hợp con của mã háo hức Python thành các hoạt động TensorFlow tương thích với đồ thị. Điều này bao gồm luồng điều khiển như if , for , while .

Các hoạt động TensorFlow như tf.condtf.while_loop tiếp tục hoạt động, nhưng luồng điều khiển thường dễ viết và dễ hiểu hơn khi được viết bằng Python.

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753]
[0.582645297 0.613145649 0.619306684 0.319202513 0.182036072]
[0.524585426 0.546337605 0.550645113 0.308785647 0.18005164]
[0.481231302 0.497770309 0.501003504 0.299331933 0.178130865]
[0.447229207 0.460361809 0.462906033 0.290701121 0.176270396]
[0.419618756 0.430379033 0.432449728 0.282779962 0.174467146]
[0.396609187 0.405638 0.407366514 0.275476 0.172718227]
[0.377043903 0.384762734 0.386234313 0.268712848 0.17102097]
[0.360137492 0.366836458 0.368109286 0.262426734 0.169372901]
[0.345335096 0.351221472 0.352336824 0.256563932 0.167771652]
[0.332231969 0.337458342 0.338446289 0.251078814 0.166215062]
[0.320524871 0.325206399 0.326089561 0.24593246 0.164701089]
[0.309981436 0.314206958 0.31500268 0.241091311 0.163227797]
[0.300420195 0.304259449 0.304981351 0.236526251 0.161793426]
[0.291697085 0.295205742 0.295864582 0.232211992 0.160396278]
[0.283696055 0.286919087 0.287523568 0.228126258 0.159034774]
[0.276322395 0.279296666 0.27985391 0.224249557 0.157707423]
[0.269497961 0.272254 0.272769839 0.220564634 0.15641281]
[0.263157606 0.265720904 0.266200244 0.21705614 0.155149609]
[0.257246554 0.259638608 0.260085613 0.213710397 0.153916568]
[0.251718313 0.25395745 0.254375577 0.210515186 0.152712509]
[0.246533215 0.248635098 0.249027327 0.207459539 0.151536316]
[0.241657034 0.243635193 0.244004101 0.204533577 0.15038693]
[0.237060249 0.238926381 0.239274174 0.201728329 0.149263337]
[0.232717097 0.234481394 0.234810054 0.199035719 0.148164615]
[0.228605017 0.230276451 0.230587661 0.196448416 0.147089839]
[0.224704206 0.226290658 0.22658591 0.193959698 0.14603813]
[0.220997125 0.222505584 0.222786173 0.191563457 0.145008713]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077],
      dtype=float32)>

Nếu bạn tò mò, bạn có thể kiểm tra mã tạo ra chữ ký.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

Điều kiện

AutoGraph sẽ chuyển đổi một số câu lệnh if <condition> thành các lệnh gọi tf.cond tương đương. Sự thay thế này được thực hiện nếu <condition> là một Tensor. Ngược lại, câu lệnh if được thực thi dưới dạng điều kiện Python.

Một điều kiện Python thực thi trong quá trình theo dõi, vì vậy chính xác một nhánh của điều kiện sẽ được thêm vào biểu đồ. Nếu không có AutoGraph, biểu đồ theo dõi này sẽ không thể lấy nhánh thay thế nếu có luồng điều khiển phụ thuộc vào dữ liệu.

tf.cond theo dõi và thêm cả hai nhánh của điều kiện vào biểu đồ, chọn động một nhánh tại thời điểm thực thi. Truy tìm có thể có tác dụng phụ ngoài ý muốn; kiểm tra các hiệu ứng theo dõi AutoGraph để biết thêm thông tin.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Xem tài liệu tham khảo để biết các hạn chế bổ sung đối với các câu lệnh if được chuyển đổi tự động.

Vòng lặp

AutoGraph sẽ chuyển đổi một số câu lệnh forwhile thành các hoạt động lặp TensorFlow tương đương, như tf. tf.while_loop . Nếu không được chuyển đổi, vòng lặp for hoặc while được thực thi như một vòng lặp Python.

Sự thay thế này được thực hiện trong các trường hợp sau:

  • for x in y : nếu y là Tensor, hãy chuyển đổi thành tf. tf.while_loop . Trong trường hợp đặc biệt khi ytf.data.Dataset , kết hợp các hoạt tf.data.Dataset được tạo ra.
  • while <condition> : nếu <condition> là Tensor, hãy chuyển đổi thành tf. tf.while_loop .

Một vòng lặp Python thực thi trong quá trình truy tìm, thêm các hoạt động bổ sung vào tf.Graph cho mỗi lần lặp lại của vòng lặp.

Một vòng lặp TensorFlow theo dõi phần nội dung của vòng lặp và tự động chọn số lần lặp để chạy tại thời điểm thực thi. Phần thân của vòng lặp chỉ xuất hiện một lần trong tf.Graph tạo.

Xem tài liệu tham khảo để biết thêm các hạn chế đối với các câu lệnh forwhile được chuyển đổi tự động.

Vòng qua dữ liệu Python

Một lỗi phổ biến là lặp qua dữ liệu Python / NumPy trong một tf.function . Vòng lặp này sẽ thực thi trong quá trình truy tìm, thêm một bản sao mô hình của bạn vào tf.Graph cho mỗi lần lặp lại của vòng lặp.

Nếu bạn muốn gói toàn bộ vòng huấn luyện trong tf.function . function, cách an toàn nhất để thực hiện việc này là bọc dữ liệu của bạn dưới dạng tf.data.Dataset để AutoGraph sẽ tự động bỏ cuộn vòng huấn luyện.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph

Khi gói dữ liệu Python / NumPy trong Dataset, hãy lưu ý đến tf.data.Dataset.from_generator so với tf.data.Dataset.from_tensors . Cái trước sẽ giữ dữ liệu bằng Python và tìm nạp nó qua tf.py_function năng có thể có ý nghĩa về hiệu suất, trong khi cái sau sẽ gói một bản sao dữ liệu dưới dạng một nút tf.constant() lớn trong biểu đồ, có thể có ý nghĩa về bộ nhớ.

Đọc dữ liệu từ các tệp thông qua TFRecordDataset , CsvDataset , v.v. là cách hiệu quả nhất để tiêu thụ dữ liệu, vì bản thân TensorFlow có thể quản lý tải không đồng bộ và tìm nạp trước dữ liệu mà không cần phải liên quan đến Python. Để tìm hiểu thêm, hãy xem tf.data : Hướng dẫn xây dựng đường ống đầu vào TensorFlow .

Tích lũy các giá trị trong một vòng lặp

Một mô hình phổ biến là tích lũy các giá trị trung gian từ một vòng lặp. Thông thường, điều này được thực hiện bằng cách thêm vào danh sách Python hoặc thêm các mục nhập vào từ điển Python. Tuy nhiên, vì đây là các tác dụng phụ của Python, chúng sẽ không hoạt động như mong đợi trong một vòng lặp được mở động. Sử dụng tf.TensorArray để tích lũy kết quả từ một vòng lặp không được cuộn động.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216],
        [0.44997275, 1.9107027 , 1.0716251 , 0.717237  ],
        [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]],

       [[0.04946005, 0.69127274, 0.56848884, 0.22406638],
        [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ],
        [0.9178308 , 1.320889  , 0.989761  , 2.0120025 ]]], dtype=float32)>

Hạn chế

Function TensorFlow có một vài hạn chế theo thiết kế mà bạn nên biết khi chuyển đổi một hàm Python thành một Function .

Thực thi các tác dụng phụ của Python

Các tác dụng phụ, như in, thêm vào danh sách và thay đổi hình cầu, có thể hoạt động không mong muốn bên trong một Function , đôi khi thực thi hai lần hoặc không phải tất cả. Chúng chỉ xảy ra lần đầu tiên bạn gọi một Function với một tập hợp các đầu vào. Sau đó, tf.Graph được truy tìm được thực thi lại mà không thực thi mã Python.

Nguyên tắc chung là tránh dựa vào các tác dụng phụ của Python trong logic của bạn và chỉ sử dụng chúng để gỡ lỗi các dấu vết của bạn. Mặt khác, các API TensorFlow như tf.data , tf.print , tf.summary , tf.Variable.assigntf.TensorArray là cách tốt nhất để đảm bảo mã của bạn sẽ được thực thi bởi thời gian chạy TensorFlow với mỗi lần gọi.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Nếu bạn muốn thực thi mã Python trong mỗi lần gọi một Function , tf.py_function Chức năng là một lối thoát. Hạn chế của tf.py_function là nó không di động hoặc đặc biệt hiệu quả, không thể lưu bằng SavedModel và không hoạt động tốt trong các thiết lập phân tán (đa GPU, TPU). Ngoài ra, vì tf.py_function phải được kết nối với biểu đồ, nó chuyển tất cả các đầu vào / đầu ra thành tensor.

Thay đổi các biến toàn cầu và miễn phí của Python

Việc thay đổi các biến toàn cầu và miễn phí của Python được coi là một tác dụng phụ của Python, vì vậy nó chỉ xảy ra trong quá trình truy tìm.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

Đôi khi những hành vi bất ngờ rất khó nhận thấy. Trong ví dụ dưới đây, bộ counter được dùng để bảo vệ số gia của một biến. Tuy nhiên, vì nó là một số nguyên python và không phải là một đối tượng TensorFlow, nên giá trị của nó được ghi lại trong lần theo dõi đầu tiên. Khi hàm tf.function được sử dụng, assign_add sẽ được ghi lại vô điều kiện trong biểu đồ bên dưới. Do đó v sẽ tăng 1, mỗi khi hàm tf.function được gọi. Sự cố này thường xảy ra ở những người dùng cố gắng di chuyển mã Tensorflow ở chế độ Grpah của họ sang Tensorflow 2 bằng cách sử dụng trình trang trí tf.function , khi các tác dụng phụ của python (bộ counter trong ví dụ) được sử dụng để xác định các hoạt động sẽ chạy (trong ví dụ là assign_add ). Thông thường, người dùng chỉ nhận ra điều này sau khi nhìn thấy kết quả số đáng ngờ hoặc hiệu suất thấp hơn đáng kể so với mong đợi (ví dụ: nếu hoạt động được bảo vệ rất tốn kém).

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

Một giải pháp để đạt được hành vi mong đợi là sử dụng tf.init_scope để nâng các hoạt động bên ngoài đồ thị hàm số. Điều này đảm bảo rằng việc tăng biến chỉ được thực hiện một lần trong thời gian truy tìm. Cần lưu ý rằng init_scope có các tác dụng phụ khác bao gồm luồng điều khiển bị xóa và băng gradient. Đôi khi việc sử dụng init_scope có thể trở nên quá phức tạp để quản lý một cách thực tế.

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

Tóm lại, theo nguyên tắc chung, bạn nên tránh làm thay đổi các đối tượng python như số nguyên hoặc vùng chứa như danh sách nằm bên ngoài Function . Thay vào đó, hãy sử dụng các đối số và đối tượng TF. Ví dụ: phần "Tích lũy các giá trị trong một vòng lặp" có một ví dụ về cách các hoạt động giống danh sách có thể được triển khai.

Trong một số trường hợp, bạn có thể nắm bắt và thao tác trạng thái nếu nó là tf.Variable . Đây là cách các trọng số của các mô hình Keras được cập nhật với các lệnh gọi lặp lại đến cùng một ConcreteFunction .

Sử dụng trình tạo và trình tạo Python

Nhiều tính năng của Python, chẳng hạn như trình tạo và trình vòng lặp, dựa vào thời gian chạy Python để theo dõi trạng thái. Nói chung, trong khi các cấu trúc này hoạt động như mong đợi ở chế độ háo hức, chúng là ví dụ về các tác dụng phụ của Python và do đó chỉ xảy ra trong quá trình truy tìm.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

Cũng giống như cách TensorFlow có tf.TensorArray chuyên biệt cho các cấu trúc danh sách, nó có tf.data.Iterator chuyên biệt cho các cấu trúc lặp. Xem phần về các phép biến đổi AutoGraph để biết tổng quan. Ngoài ra, API tf.data có thể giúp triển khai các mẫu trình tạo:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

Tất cả các đầu ra của hàm tf. phải là giá trị trả về

Ngoại trừ tf.Variable s, một hàm tf. phải trả về tất cả các đầu ra của nó. Việc cố gắng truy cập trực tiếp vào bất kỳ tensor nào từ một hàm mà không đi qua các giá trị trả về gây ra "rò rỉ".

Ví dụ: hàm bên dưới "rò rỉ" tensor a qua Python global x :

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'Tensor' object has no attribute 'numpy'

Điều này đúng ngay cả khi giá trị bị rò rỉ cũng được trả về:

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'Tensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: Originated from a graph execution error.

The graph execution error is detected at a node built at (most recent call last):
>>>  File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main
>>>  File /usr/lib/python3.7/runpy.py, line 85, in _run_code
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once
>>>  File /usr/lib/python3.7/asyncio/events.py, line 88, in _run
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code
>>>  File /tmp/ipykernel_26244/566849597.py, line 7, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler
>>>  File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__

Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

Thông thường, những rò rỉ như thế này xảy ra khi bạn sử dụng các câu lệnh hoặc cấu trúc dữ liệu Python. Ngoài việc rò rỉ các tensor không truy cập được, các câu lệnh như vậy cũng có khả năng sai vì chúng được tính là tác dụng phụ của Python và không được đảm bảo thực thi ở mọi lệnh gọi hàm.

Các cách phổ biến để làm rò rỉ các tensors cục bộ cũng bao gồm việc thay đổi một bộ sưu tập Python bên ngoài hoặc một đối tượng:

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

Các hàm tf.funcive đệ quy không được hỗ trợ

Function đệ quy không được hỗ trợ và có thể gây ra vòng lặp vô hạn. Ví dụ,

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__
        return _abc_instancecheck(cls, instance)

    RecursionError: maximum recursion depth exceeded while calling a Python object

Ngay cả khi một Function đệ quy có vẻ hoạt động, hàm python sẽ được theo dõi nhiều lần và có thể có hàm ý về hiệu suất. Ví dụ,

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

Các vấn đề đã biết

Nếu Function của bạn đánh giá không chính xác, lỗi có thể được giải thích bởi các vấn đề đã biết này được lên kế hoạch khắc phục trong tương lai.

Tùy thuộc vào các biến toàn cầu và miễn phí của Python

Function tạo một ConcreteFunction mới khi được gọi với giá trị mới của một đối số Python. Tuy nhiên, nó không làm điều đó đối với bao đóng Python, toàn cầu hoặc không định vị của Function đó. Nếu giá trị của chúng thay đổi giữa các lần gọi đến Function , thì Function sẽ vẫn sử dụng các giá trị mà chúng đã có khi nó được truy tìm. Điều này khác với cách các hàm Python thông thường hoạt động.

Vì lý do đó, bạn nên làm theo phong cách lập trình hàm sử dụng các đối số thay vì đóng trên các tên bên ngoài.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

Một cách khác để cập nhật giá trị toàn cục là đặt nó thành tf.Variable và thay vào đó sử dụng phương thức Variable.assign .

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

Tùy thuộc vào các đối tượng Python

Khuyến nghị chuyển các đối tượng Python dưới dạng đối số vào tf.function có một số vấn đề đã biết, dự kiến ​​sẽ được khắc phục trong tương lai. Nói chung, bạn có thể dựa vào khả năng theo dõi nhất quán nếu bạn sử dụng cấu trúc nguyên thủy Python hoặc cấu trúc tương thích tf.nest làm đối số hoặc chuyển trong một phiên bản khác của đối tượng vào một Function . Tuy nhiên, Function sẽ không tạo một dấu vết mới khi bạn truyền cùng một đối tượng và chỉ thay đổi các thuộc tính của nó .

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

Việc sử dụng cùng một Function để đánh giá phiên bản cập nhật của mô hình sẽ có lỗi vì mô hình được cập nhật có khóa bộ nhớ cache giống với mô hình ban đầu.

Vì lý do đó, bạn nên viết Function của mình để tránh phụ thuộc vào thuộc tính đối tượng có thể thay đổi hoặc tạo đối tượng mới.

Nếu không thể, một cách giải quyết là tạo các Function mới mỗi khi bạn sửa đổi đối tượng của mình để buộc thực hiện lại:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

Vì việc thử lại có thể tốn kém , bạn có thể sử dụng tf.Variable s làm thuộc tính đối tượng, có thể bị thay đổi (nhưng không được thay đổi, cẩn thận!) Để có hiệu ứng tương tự mà không cần truy xuất lại.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

Tạo tf.Variables

Function chỉ hỗ trợ singleton tf.Variable được tạo một lần trong lần gọi đầu tiên và được sử dụng lại trong các lần gọi hàm tiếp theo. Đoạn mã bên dưới sẽ tạo một tf.Variable mới trong mọi lệnh gọi hàm, dẫn đến ngoại lệ ValueError .

Thí dụ:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

Một mẫu phổ biến được sử dụng để khắc phục hạn chế này là bắt đầu bằng giá trị Không có trong Python, sau đó tạo tf.Variable có điều kiện nếu giá trị là Không:

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Sử dụng với nhiều trình tối ưu hóa Keras

Bạn có thể gặp phải ValueError: tf.function only supports singleton tf.Variables created on the first call. khi sử dụng nhiều hơn một trình tối ưu hóa Keras có tf.function . Lỗi này xảy ra do trình tối ưu hóa tạo tf.Variables bên trong khi chúng áp dụng gradient lần đầu tiên.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients  **
        self._create_all_weights(var_list)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights
        _ = self.iterations
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight
        aggregation=aggregation)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable
        shape=variable_shape if variable_shape else None)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

Nếu bạn cần thay đổi trình tối ưu hóa trong quá trình đào tạo, cách giải quyết là tạo một Function mới cho mỗi trình tối ưu hóa, gọi trực tiếp ConcreteFunction .

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

Sử dụng với nhiều mô hình Keras

Bạn cũng có thể gặp phải ValueError: tf.function only supports singleton tf.Variables created on the first call. khi chuyển các thể hiện mô hình khác nhau cho cùng một Function .

Lỗi này xảy ra do các mô hình Keras ( không có hình dạng đầu vào được xác định ) và các lớp Keras tạo tf.Variables khi chúng được gọi lần đầu tiên. Bạn có thể đang cố gắng khởi tạo các biến đó bên trong một Function , đã được gọi. Để tránh lỗi này, hãy thử gọi model.build(input_shape) để khởi tạo tất cả các trọng số trước khi huấn luyện mô hình.

đọc thêm

Để tìm hiểu về cách xuất và tải một Function , hãy xem hướng dẫn SavedModel . Để tìm hiểu thêm về các tối ưu hóa biểu đồ được thực hiện sau khi theo dõi, hãy xem hướng dẫn Grappler . Để tìm hiểu cách tối ưu hóa đường ống dữ liệu và lập hồ sơ cho mô hình của bạn, hãy xem hướng dẫn về Hồ sơ.