Lưu ngày! Google I / O hoạt động trở lại từ ngày 18 đến 20 tháng 5 Đăng ký ngay
Trang này được dịch bởi Cloud Translation API.
Switch to English

Hiệu suất tốt hơn với tf. Chức năng

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Trong TensorFlow 2, thực thi háo hức được bật theo mặc định. Giao diện người dùng trực quan và linh hoạt (chạy các hoạt động một lần dễ dàng hơn và nhanh hơn nhiều), nhưng điều này có thể phải trả giá bằng hiệu suất và khả năng triển khai.

Bạn có thể sử dụng tf.function để tạo đồ thị từ các chương trình của mình. Nó là một công cụ chuyển đổi tạo ra các biểu đồ luồng dữ liệu độc lập với Python từ mã Python của bạn. Điều này sẽ giúp bạn tạo các mô hình hiệu quả và di động, và bắt buộc phải sử dụng SavedModel .

Hướng dẫn này sẽ giúp bạn khái niệm về cách hoạt động của tf.function dưới mui xe để bạn có thể sử dụng nó một cách hiệu quả.

Các điểm rút ra và khuyến nghị chính là:

  • Gỡ lỗi ở chế độ háo hức, sau đó trang trí với @tf.function . @tf.function .
  • Đừng dựa vào các tác dụng phụ của Python như đột biến đối tượng hoặc nối danh sách.
  • tf.function hoạt động tốt nhất với hoạt động TensorFlow; Các cuộc gọi NumPy và Python được chuyển đổi thành hằng số.

Thiết lập

import tensorflow as tf

Xác định một hàm trợ giúp để chứng minh các loại lỗi bạn có thể gặp phải:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

Khái niệm cơ bản

Sử dụng

Một Function bạn xác định (ví dụ bằng cách áp dụng trình trang trí @tf.function ) giống như một hoạt động TensorFlow cốt lõi: Bạn có thể thực thi nó một cách hăng hái; bạn có thể tính toán độ dốc; và như thế.

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Bạn có thể sử dụng các Function bên trong các Function khác.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function s có thể nhanh hơn mã háo hức, đặc biệt đối với đồ thị có nhiều hoạt động nhỏ. Nhưng đối với các đồ thị có một vài hoạt động đắt tiền (như chập), bạn có thể không thấy tốc độ tăng nhiều.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.0035502629999655255
Function conv: 0.004116348000025027
Note how there's not much difference in performance for convolutions

Truy tìm

Phần này trình bày cách Function hoạt động ẩn, bao gồm các chi tiết triển khai có thể thay đổi trong tương lai . Tuy nhiên, một khi bạn hiểu tại sao và khi nào việc truy tìm xảy ra, thì việc sử dụng tf.function một cách hiệu quả sẽ dễ dàng hơn nhiều!

"Truy tìm" là gì?

Một Function chạy chương trình của bạn trong một Đồ thị TensorFlow . Tuy nhiên, một tf.Graph không thể đại diện cho tất cả những thứ bạn viết trong một chương trình TensorFlow háo hức. Ví dụ: Python hỗ trợ tính đa hình, nhưng tf.Graph yêu cầu đầu vào của nó phải có kiểu dữ liệu và thứ nguyên được chỉ định. Hoặc bạn có thể thực hiện các tác vụ phụ như đọc các đối số dòng lệnh, nêu lỗi hoặc làm việc với một đối tượng Python phức tạp hơn; không có thứ nào trong số này có thể chạy trong tf.Graph .

Function thu hẹp khoảng cách này bằng cách tách mã của bạn thành hai giai đoạn:

1) Trong giai đoạn đầu tiên, được gọi là " truy tìm ", Function tạo một tf.Graph mới. Mã Python chạy bình thường, nhưng tất cả các hoạt động TensorFlow (như thêm hai Tensor) đều bị hoãn lại : chúng được tf.Graph nắm bắt và không chạy.

2) Trong giai đoạn thứ hai, một tf.Graph chứa mọi thứ đã bị trì hoãn trong giai đoạn đầu tiên được chạy. Giai đoạn này nhanh hơn nhiều so với giai đoạn truy tìm.

Tùy thuộc vào đầu vào của nó, Function sẽ không luôn chạy ở giai đoạn đầu tiên khi nó được gọi. Xem "Quy tắc theo dõi" bên dưới để hiểu rõ hơn về cách nó đưa ra quyết định đó. Bỏ qua giai đoạn đầu tiên và chỉ thực hiện giai đoạn thứ hai là những gì mang lại cho bạn hiệu suất cao của TensorFlow.

Khi Function quyết định theo dõi, giai đoạn theo dõi ngay lập tức được theo sau bởi giai đoạn thứ hai, vì vậy việc gọi Function vừa tạo và chạy tf.Graph . Sau đó, bạn sẽ thấy cách bạn có thể chạy chỉ giai đoạn theo dõi với get_concrete_function .

Khi chúng ta truyền các đối số của các kiểu khác nhau vào một Function , cả hai giai đoạn đều được chạy:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

Lưu ý rằng nếu bạn gọi nhiều lần một Function có cùng loại đối số, TensorFlow sẽ bỏ qua giai đoạn theo dõi và sử dụng lại một biểu đồ đã theo dõi trước đó, vì biểu đồ được tạo sẽ giống hệt nhau.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

Bạn có thể sử dụng pretty_printed_concrete_signatures() để xem tất cả các dấu vết có sẵn:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

Cho đến nay, bạn đã thấy rằng tf.function tạo ra một lớp điều phối động, được lưu trong bộ nhớ cache trên logic theo dõi đồ thị của TensorFlow. Để cụ thể hơn về thuật ngữ:

  • Một tf.Graph là bản trình bày thô, không thể ngôn ngữ, di động của một phép tính TensorFlow.
  • Một ConcreteFunction kết thúc một tf.Graph .
  • Function quản lý bộ nhớ cache của các ConcreteFunction và chọn một bộ nhớ đệm phù hợp cho đầu vào của bạn.
  • tf.function kết thúc một hàm Python, trả về một đối tượng Function .
  • Tracing tạo ra một tf.Graph và bao bọc nó trong một ConcreteFunction , còn được gọi là một dấu vết.

Quy tắc truy tìm

Một Function xác định xem có nên sử dụng lại ConcreteFunction theo dõi hay không bằng cách tính toán khóa bộ nhớ cache từ các args và kwargs của đầu vào. Khóa bộ nhớ cache là khóa xác định ConcreteFunction dựa trên các args và kwargs đầu vào của lệnh gọi Function , theo các quy tắc sau (có thể thay đổi):

  • Chìa khóa được tạo ra cho tf.Tensor là hình dạng và kiểu của nó.
  • Khóa được tạo cho tf.Variable là một id biến duy nhất.
  • Khóa được tạo cho một nguyên thủy Python (như int , float , str ) là giá trị của nó.
  • Khóa được tạo ra cho các dict , list s, tuple s, namedtupleattr được lồng vào nhau là bộ khóa lá được làm phẳng (xemnest.flatten ). (Do kết quả của việc làm phẳng này, việc gọi một hàm cụ thể có cấu trúc lồng khác với cấu trúc được sử dụng trong quá trình truy tìm sẽ dẫn đến Lỗi loại).
  • Đối với tất cả các kiểu Python khác, khóa là duy nhất cho đối tượng. Bằng cách này, một hàm hoặc phương thức được truy tìm độc lập cho mỗi trường hợp mà nó được gọi.

Kiểm soát việc kiểm tra lại

Retracing, đó là khi Function của bạn tạo nhiều hơn một dấu vết, giúp đảm bảo rằng TensorFlow tạo ra các đồ thị chính xác cho từng nhóm đầu vào. Tuy nhiên, truy tìm là một hoạt động tốn kém! Nếu Function của bạn truy xuất lại một đồ thị mới cho mỗi lần gọi, bạn sẽ thấy rằng mã của bạn thực thi chậm hơn so với khi bạn không sử dụng tf.function . tf.function .

Để kiểm soát hành vi theo dõi, bạn có thể sử dụng các kỹ thuật sau:

  • Chỉ định input_signature trong tf.function để hạn chế việc theo dõi.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# We specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-20f544b8adbf>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-20f544b8adbf>", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
  • Chỉ định thứ nguyên [Không có] trong tf.TensorSpec để cho phép sử dụng lại dấu vết một cách linh hoạt.

    Vì TensorFlow đối sánh các tensor dựa trên hình dạng của chúng, việc sử dụng thứ nguyên None làm ký tự đại diện sẽ cho phép Function s sử dụng lại dấu vết cho đầu vào có kích thước thay đổi. Đầu vào có kích thước khác nhau có thể xảy ra nếu bạn có các chuỗi có độ dài khác nhau hoặc hình ảnh có kích thước khác nhau cho mỗi lô (Ví dụ: Xem hướng dẫn về TransformerDeep Dream ).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • Truyền các đối số Python tới Tensors để giảm việc rút lại.

    Thông thường, các đối số Python được sử dụng để kiểm soát các siêu tham số và cấu trúc đồ thị - ví dụ: num_layers=10 hoặc training=True hoặc nonlinearity='relu' . Vì vậy, nếu đối số Python thay đổi, có nghĩa là bạn phải truy xuất lại biểu đồ.

    Tuy nhiên, có thể một đối số Python không được sử dụng để kiểm soát việc xây dựng đồ thị. Trong những trường hợp này, một sự thay đổi trong giá trị Python có thể kích hoạt việc kiểm tra lại không cần thiết. Lấy ví dụ, vòng lặp đào tạo này, AutoGraph sẽ tự động hủy cuộn. Mặc dù có nhiều dấu vết, nhưng biểu đồ được tạo thực sự giống hệt nhau, vì vậy việc kiểm tra lại là không cần thiết.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Nếu bạn cần buộc phải rút lại, hãy tạo một Function mới. Các đối tượng Function riêng biệt được đảm bảo không chia sẻ dấu vết.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

Đạt được các chức năng cụ thể

Mỗi khi một chức năng được truy tìm, một chức năng cụ thể mới được tạo ra. Bạn có thể lấy trực tiếp một hàm cụ thể bằng cách sử dụng get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'cc', shape=(), dtype=string)

Việc in ConcreteFunction hiển thị một bản tóm tắt các đối số đầu vào của nó (với các kiểu) và kiểu đầu ra của nó.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Bạn cũng có thể lấy trực tiếp chữ ký của một hàm cụ thể.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

Sử dụng dấu vết cụ thể với các loại không tương thích sẽ gây ra lỗi

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-e4e2860a4364>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

Bạn có thể nhận thấy rằng các đối số Python được xử lý đặc biệt trong chữ ký đầu vào của một hàm cụ thể. Trước TensorFlow 2.3, các đối số trong Python chỉ đơn giản là bị xóa khỏi chữ ký của hàm cụ thể. Bắt đầu với TensorFlow 2.3, các đối số Python vẫn còn trong chữ ký, nhưng bị hạn chế lấy giá trị được đặt trong quá trình theo dõi.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1683, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1728, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d163f3d206cb>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3

Lấy đồ thị

Mỗi hàm cụ thể là một trình bao bọc có thể gọi xung quanh một tf.Graph . Mặc dù việc truy xuất đối tượng tf.Graph thực tế không phải là điều bạn thường cần làm, nhưng bạn có thể lấy nó dễ dàng từ bất kỳ hàm cụ thể nào.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

Gỡ lỗi

Nói chung, mã gỡ lỗi trong chế độ háo hức dễ dàng hơn so với bên trong tf.function . tf.function . Bạn nên đảm bảo rằng mã của bạn thực thi không có lỗi ở chế độ háo hức trước khi trang trí bằng tf.function . tf.function . Để hỗ trợ quá trình gỡ lỗi, bạn có thể gọi tf.config.run_functions_eagerly(True) để tắt và kích hoạt lại tf.function . tf.function trên toàn cầu.

Khi theo dõi các vấn đề chỉ xuất hiện trong tf.function , đây là một số mẹo:

  • Các print gọi print Python cũ đơn thuần chỉ thực thi trong quá trình truy tìm, giúp bạn theo dõi khi hàm của bạn được truy tìm (lại).
  • tf.print gọi tf.print sẽ thực hiện mọi lúc và có thể giúp bạn theo dõi các giá trị trung gian trong quá trình thực thi.
  • tf.debugging.enable_check_numerics là một cách dễ dàng để theo dõi nơi tạo NaN và Inf.
  • pdb có thể giúp bạn hiểu những gì đang diễn ra trong quá trình theo dõi. (Lưu ý: PDB sẽ đưa bạn vào mã nguồn được chuyển đổi AutoGraph.)

Chuyển đổi AutoGraph

AutoGraph là một thư viện được bật theo mặc định trong tf.function và chuyển đổi một tập hợp con của mã háo hức Python thành các hoạt động TensorFlow tương thích với đồ thị. Điều này bao gồm luồng điều khiển như if , for , while .

Các hoạt động TensorFlow như tf.condtf.while_loop tiếp tục hoạt động, nhưng luồng điều khiển thường dễ viết và dễ hiểu hơn khi được viết bằng Python.

# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.928048491 0.537333608 0.319427252 0.414729953 0.138620138]
[0.729682684 0.490966946 0.308988899 0.392481416 0.137739]
[0.62287122 0.454983532 0.299516946 0.373497456 0.136874482]
[0.553123951 0.425986826 0.290870458 0.357047111 0.13602607]
[0.502857924 0.401961982 0.282935768 0.342610359 0.135193244]
[0.464361787 0.381626487 0.27562 0.329805791 0.134375557]
[0.433632493 0.364119112 0.268846452 0.318346262 0.133572534]
[0.408352554 0.348837078 0.262551099 0.308010817 0.132783771]
[0.387072921 0.335343778 0.256680071 0.298626363 0.132008836]
[0.368834078 0.32331419 0.251187652 0.290055037 0.131247327]
[0.352971435 0.312500536 0.246034727 0.282185435 0.130498841]
[0.339008093 0.302710205 0.241187632 0.274926543 0.129763052]
[0.326591551 0.293790847 0.236617178 0.26820302 0.129039586]
[0.315454811 0.285620153 0.232297987 0.261951953 0.128328085]
[0.305391371 0.278098613 0.228207797 0.256120354 0.127628237]
[0.296238661 0.27114439 0.224326983 0.250663161 0.126939729]
[0.287866682 0.264689356 0.220638305 0.245541915 0.126262262]
[0.280170113 0.25867638 0.217126325 0.240723446 0.12559554]
[0.273062497 0.253057063 0.213777393 0.236178935 0.124939285]
[0.266472191 0.247790173 0.210579231 0.231883332 0.124293216]
[0.260339141 0.242840245 0.207520843 0.227814704 0.12365707]
[0.254612684 0.238176659 0.204592302 0.223953649 0.123030603]
[0.249249727 0.23377277 0.201784685 0.220283121 0.122413576]
[0.244213238 0.229605287 0.199089885 0.216787875 0.12180575]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.23947136, 0.22565375, 0.19650048, 0.21345437, 0.12120689],
      dtype=float32)>

Nếu bạn tò mò, bạn có thể kiểm tra mã tạo ra chữ ký.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

Điều kiện

AutoGraph sẽ chuyển đổi một số câu lệnh if <condition> thành các tf.cond gọi tf.cond tương đương. Sự thay thế này được thực hiện nếu <condition> là một Tensor. Nếu không, if tuyên bố được thực hiện như một điều kiện Python.

Một điều kiện Python thực thi trong quá trình theo dõi, vì vậy chính xác một nhánh của điều kiện sẽ được thêm vào biểu đồ. Nếu không có AutoGraph, biểu đồ theo dõi này sẽ không thể lấy nhánh thay thế nếu có luồng điều khiển phụ thuộc vào dữ liệu.

tf.cond theo dõi và thêm cả hai nhánh của điều kiện vào biểu đồ, chọn động một nhánh tại thời điểm thực thi. Truy tìm có thể có tác dụng phụ ngoài ý muốn; xem các hiệu ứng theo dõi AutoGraph để biết thêm.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Xem tài liệu tham khảo để biết các hạn chế bổ sung đối với các câu lệnh if được chuyển đổi tự động.

Vòng lặp

AutoGraph sẽ chuyển đổi một số câu lệnh forwhile thành các hoạt động lặp TensorFlow tương đương, như tf.while_loop . tf.while_loop . Nếu không được chuyển đổi, for hay while vòng lặp được thực hiện như một vòng lặp Python.

Sự thay thế này được thực hiện trong các trường hợp sau:

  • for x in y : nếu y là Tensor, hãy chuyển đổi thành tf.while_loop . tf.while_loop . Trong trường hợp đặc biệt khi ytf.data.Dataset , sự kết hợp của các hoạt độngtf.data.Dataset được tạo ra.
  • while <condition> : nếu <condition> là Tensor, chuyển đổi thành tf.while_loop . tf.while_loop .

Một vòng lặp Python thực thi trong quá trình truy tìm, thêm các hoạt động bổ sung vào tf.Graph cho mỗi lần lặp lại của vòng lặp.

Một vòng lặp TensorFlow theo dõi phần thân của vòng lặp và tự động chọn số lần lặp để chạy tại thời điểm thực thi. Phần thân của vòng lặp chỉ xuất hiện một lần trong tf.Graph được tạo.

Xem tài liệu tham khảo để biết các hạn chế bổ sung đối với các câu lệnh forwhile chuyển đổi tự động.

Lặp qua dữ liệu Python

Một cạm bẫy phổ biến là lặp lại dữ liệu Python / Numpy trong một tf.function . Vòng lặp này sẽ thực thi trong quá trình truy tìm, thêm một bản sao mô hình của bạn vào tf.Graph cho mỗi lần lặp lại của vòng lặp.

Nếu bạn muốn gói toàn bộ vòng huấn luyện trong tf.function , cách an toàn nhất để làm điều này là bọc dữ liệu của bạn dưới dạngtf.data.Dataset để AutoGraph sẽ tự động bỏ cuộn vòng huấn luyện.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 10 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 10 nodes in its graph

Khi gói dữ liệu Python / Numpy trong Dataset, hãy lưu ý đến tf.data.Dataset.from_generator so với tf.data.Dataset.from_tensors . Cái trước sẽ giữ dữ liệu bằng Python và tìm nạp nó qua tf.py_function năng có thể có ý nghĩa về hiệu suất, trong khi cái sau sẽ gói một bản sao của dữ liệu dưới dạng một nút tf.constant() trong biểu đồ, có thể có ý nghĩa về bộ nhớ.

Đọc dữ liệu từ các tệp qua TFRecordDataset / CsvDataset / etc. là cách hiệu quả nhất để tiêu thụ dữ liệu, vì bản thân TensorFlow có thể quản lý việc tải và tìm nạp trước dữ liệu không đồng bộ mà không cần phải liên quan đến Python. Để tìm hiểu thêm, hãy xem hướng dẫn tf.data .

Tích lũy các giá trị trong một vòng lặp

Một mô hình phổ biến là tích lũy các giá trị trung gian từ một vòng lặp. Thông thường, điều này được thực hiện bằng cách thêm vào danh sách Python hoặc thêm các mục nhập vào từ điển Python. Tuy nhiên, vì đây là những tác dụng phụ của Python nên chúng sẽ không hoạt động như mong đợi trong một vòng lặp được mở động. Sử dụng tf.TensorArray để tích lũy kết quả từ một vòng lặp không được cuộn động.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.8216245 , 0.29562855, 0.379112  , 0.49940717],
        [1.6473945 , 1.039927  , 1.3268942 , 0.5298227 ],
        [2.4393063 , 1.1283967 , 2.087479  , 1.2748951 ]],

       [[0.08016336, 0.73864746, 0.33738315, 0.4542967 ],
        [0.7459605 , 1.307698  , 1.1588445 , 0.9293362 ],
        [1.3752056 , 1.6133544 , 1.8199729 , 1.7356051 ]]], dtype=float32)>

Hạn chế

Function TensorFlow có một vài hạn chế theo thiết kế mà bạn nên biết khi chuyển đổi một hàm Python thành một Function .

Thực thi các tác dụng phụ của Python

Các hiệu ứng phụ, như in, thêm vào danh sách và thay đổi hình cầu, có thể hoạt động không mong muốn bên trong một Function , đôi khi thực thi hai lần hoặc không phải tất cả. Chúng chỉ xảy ra lần đầu tiên bạn gọi một Function với một tập hợp các đầu vào. Sau đó, tf.Graph được truy tìm được thực thi lại mà không thực thi mã Python.

Nguyên tắc chung là tránh dựa vào các tác dụng phụ của Python trong logic của bạn và chỉ sử dụng chúng để gỡ lỗi các dấu vết của bạn. tf.data khác, các API TensorFlow như tf.data , tf.print , tf.summary , tf.Variable.assigntf.TensorArray là cách tốt nhất để đảm bảo mã của bạn sẽ được thực thi bởi thời gian chạy TensorFlow với mỗi lần gọi.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Nếu bạn muốn thực thi mã Python trong mỗi lần gọi một Function , tf.py_function Function là một lối thoát. Hạn chế của tf.py_function là nó không di động hoặc đặc biệt hiệu quả, không thể lưu bằng SavedModel và không hoạt động tốt trong các thiết lập phân tán (đa GPU, TPU). Ngoài ra, vì tf.py_function phải được kết nối với biểu đồ, nó chuyển tất cả các đầu vào / đầu ra thành tensor.

Thay đổi các biến miễn phí và toàn cầu của Python

Việc thay đổi các biến toàn cầu và miễn phí của Python được tính là một tác dụng phụ của Python, vì vậy nó chỉ xảy ra trong quá trình truy tìm.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

Bạn nên tránh các vùng chứa thay đổi như danh sách, dicts, các đối tượng khác nằm ngoài Function . Thay vào đó, hãy sử dụng các đối số và đối tượng TF. Ví dụ: phần "Tích lũy các giá trị trong một vòng lặp" có một ví dụ về cách các hoạt động giống danh sách có thể được triển khai.

Trong một số trường hợp, bạn có thể nắm bắt và thao tác trạng thái nếu nó là tf.Variable . Đây là cách các trọng số của các mô hình Keras được cập nhật với các lệnh gọi lặp lại đến cùng một ConcreteFunction .

Sử dụng trình tạo và trình tạo Python

Nhiều tính năng của Python, chẳng hạn như trình tạo và trình vòng lặp, dựa vào thời gian chạy Python để theo dõi trạng thái. Nói chung, trong khi các cấu trúc này hoạt động như mong đợi ở chế độ háo hức, chúng là ví dụ về các tác dụng phụ của Python và do đó chỉ xảy ra trong quá trình truy tìm.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

Cũng giống như cách TensorFlow có tf.TensorArray chuyên biệt cho các cấu trúc danh sách, nó có tf.data.Iterator chuyên biệt cho các cấu trúc lặp. Xem phần về Biến đổi Tự động Đồ thị để biết tổng quan. Ngoài ra, API tf.data có thể giúp triển khai các mẫu trình tạo:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

Xóa tf.Variables giữa các lệnh gọi Function

Một lỗi khác mà bạn có thể gặp phải là biến thu thập rác. ConcreteFunction s chỉ giữ lại WeakRefs cho các biến mà chúng đóng lại, vì vậy bạn phải giữ lại một tham chiếu đến bất kỳ biến nào.

external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var

traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

# The original variable object gets garbage collected, since there are no more
# references to it.
external_var = tf.Variable(4)
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-9a93d2e07632>", line 16, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Error while reading resource variable _AnonymousVar3 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar3/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-9a93d2e07632>:4) ]]
  (1) Failed precondition:  Error while reading resource variable _AnonymousVar3 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar3/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-9a93d2e07632>:4) ]]
     [[ReadVariableOp/_2]]
0 successful operations.
0 derived errors ignored. [Op:__inference_f_782]

Function call stack:
f -> f

Các vấn đề đã biết

Nếu Function của bạn đánh giá không chính xác, lỗi có thể được giải thích bởi những vấn đề đã biết này được lên kế hoạch khắc phục trong tương lai.

Tùy thuộc vào các biến toàn cầu và miễn phí của Python

Function tạo một ConcreteFunction mới khi được gọi với giá trị mới của một đối số Python. Tuy nhiên, nó không làm điều đó đối với bao đóng Python, toàn cầu hoặc không định vị của Function đó. Nếu giá trị của chúng thay đổi giữa các lần gọi đến Function , thì Function sẽ vẫn sử dụng các giá trị mà chúng đã có khi nó được truy tìm. Điều này khác với cách hoạt động của các hàm Python thông thường.

Vì lý do đó, chúng tôi đề xuất một kiểu lập trình hàm sử dụng các đối số thay vì đóng trên các tên bên ngoài.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

Bạn có thể đóng các tên bên ngoài, miễn là bạn không cập nhật giá trị của chúng.

Tùy thuộc vào các đối tượng Python

Đề xuất chuyển các đối tượng Python dưới dạng đối số vào tf.function có một số vấn đề đã biết, dự kiến ​​sẽ được khắc phục trong tương lai. Nói chung, bạn có thể dựa vào khả năng theo dõi nhất quán nếu bạn sử dụng cấu trúc nguyên thủy Python hoặc cấu trúc tương thích tf.nest làm đối số hoặc chuyển trong một phiên bản khác của một đối tượng vào một Function . Tuy nhiên, Function sẽ không tạo dấu vết mới khi bạn truyền cùng một đối tượng và chỉ thay đổi các thuộc tính của nó .

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

Việc sử dụng cùng một Function để đánh giá phiên bản cập nhật của mô hình sẽ có lỗi vì mô hình được cập nhật có khóa bộ nhớ cache giống với mô hình gốc.

Vì lý do đó, chúng tôi khuyên bạn nên viết Function của mình để tránh phụ thuộc vào thuộc tính đối tượng có thể thay đổi hoặc tạo đối tượng mới.

Nếu điều đó không thể thực hiện được, một giải pháp là tạo các Function mới mỗi khi bạn sửa đổi đối tượng của mình để buộc thực hiện lại:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

Vì việc thử lại có thể tốn kém , bạn có thể sử dụng tf.Variable s làm thuộc tính đối tượng, có thể bị thay đổi (nhưng không được thay đổi, cẩn thận!) Để có hiệu ứng tương tự mà không cần truy xuất lại.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

Tạo tf.Variables

Function chỉ hỗ trợ tạo biến một lần, khi được gọi lần đầu và sau đó sử dụng lại chúng. Bạn không thể tạo tf.Variables trong các dấu vết mới. Việc tạo các biến mới trong các cuộc gọi tiếp theo hiện không được phép, nhưng sẽ có trong tương lai.

Thí dụ:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-8a0913e250e0>", line 7, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-1-8a0913e250e0>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:731 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

Bạn có thể tạo các biến bên trong một Function miễn là các biến đó chỉ được tạo trong lần đầu tiên hàm được thực thi.

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Sử dụng với nhiều trình tối ưu hóa Keras

Bạn có thể gặp phải ValueError: tf.function-decorated function tried to create variables on non-first call. khi sử dụng nhiều hơn một trình tối ưu hóa Keras có tf.function . Lỗi này xảy ra do trình tối ưu hóa tạo tf.Variables bên trong khi chúng áp dụng gradient lần đầu tiên.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d3d3937dbf1a>", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    <ipython-input-1-d3d3937dbf1a>:9 train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:604 apply_gradients  **
        self._create_all_weights(var_list)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:781 _create_all_weights
        _ = self.iterations
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:788 __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:926 iterations
        aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:1132 add_weight
        aggregation=aggregation)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py:810 _add_variable_with_custom_getter
        **kwargs_for_getter)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_utils.py:142 make_variable
        shape=variable_shape if variable_shape else None)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:260 __call__
        return cls._variable_v1_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:221 _variable_v1_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:731 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

Nếu bạn cần thay đổi trình tối ưu hóa trong quá trình đào tạo, cách giải quyết là tạo một Function mới cho từng trình tối ưu hóa, gọi trực tiếp ConcreteFunction .

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

Sử dụng với nhiều mô hình Keras

Bạn cũng có thể gặp phải ValueError: tf.function-decorated function tried to create variables on non-first call. khi chuyển các thể hiện mô hình khác nhau cho cùng một Function .

Lỗi này xảy ra do các mô hình Keras ( không có hình dạng đầu vào được xác định ) và các lớp tf.Variables tạo các tf.Variables khi chúng được gọi lần đầu tiên. Bạn có thể đang cố gắng khởi tạo các biến đó bên trong một Function , đã được gọi. Để tránh lỗi này, hãy thử gọi model.build(input_shape) để khởi tạo tất cả các trọng số trước khi huấn luyện mô hình.

đọc thêm

Để tìm hiểu về cách xuất và tải một Function , hãy xem hướng dẫn SavedModel . Để tìm hiểu thêm về tối ưu hóa đồ thị được thực hiện sau khi theo dõi, hãy xem hướng dẫn Grappler . Để tìm hiểu cách tối ưu hóa đường ống dữ liệu và lập hồ sơ cho mô hình của bạn, hãy xem hướng dẫn về Hồ sơ .