ช่วยปกป้อง Great Barrier Reef กับ TensorFlow บน Kaggle เข้าร่วมท้าทาย

ประสิทธิภาพที่ดีขึ้นด้วย tf.function

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHubดาวน์โหลดโน๊ตบุ๊ค

ใน TensorFlow 2 การดำเนินการกระตือรือร้นที่ เปิดอยู่โดยค่าเริ่มต้น อินเทอร์เฟซผู้ใช้ใช้งานง่ายและยืดหยุ่น (การดำเนินการครั้งเดียวทำได้ง่ายกว่าและเร็วกว่ามาก) แต่สิ่งนี้อาจทำให้เสียประสิทธิภาพและความสามารถในการปรับใช้

คุณสามารถใช้ tf.function กราฟให้ออกจากโปรแกรมของคุณ เป็นเครื่องมือการแปลงที่สร้างกราฟการไหลของข้อมูลที่ไม่ขึ้นกับ Python จากโค้ด Python ของคุณ นี้จะช่วยให้คุณสร้างแบบจำลอง performant และแบบพกพาและมันจะต้องใช้ SavedModel

คู่มือนี้จะช่วยให้คุณสร้างกรอบความคิดวิธี tf.function ทำงานภายใต้ฝากระโปรงเพื่อให้คุณสามารถใช้งานได้อย่างมีประสิทธิภาพ

ประเด็นหลักและคำแนะนำคือ:

  • ตรวจแก้จุดบกพร่องในโหมดความกระตือรือร้นที่แล้วตกแต่งด้วย @tf.function
  • อย่าพึ่งพาผลข้างเคียงของ Python เช่น การกลายพันธุ์ของวัตถุหรือการผนวกรายการ
  • tf.function ทำงานได้ดีที่สุดกับ Ops TensorFlow; การเรียก NumPy และ Python จะถูกแปลงเป็นค่าคงที่

ติดตั้ง

import tensorflow as tf

กำหนดฟังก์ชันตัวช่วยเพื่อสาธิตประเภทข้อผิดพลาดที่คุณอาจพบ:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

พื้นฐาน

การใช้งาน

Function ที่คุณกำหนด (เช่นโดยการใช้ @tf.function มัณฑนากร) เป็นเช่นเดียวกับการดำเนินการหลัก TensorFlow: คุณสามารถรันมันกระหาย; คุณสามารถคำนวณการไล่ระดับสี และอื่นๆ

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

คุณสามารถใช้ Function s ภายในอื่น ๆ Function s

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function ที่สามารถจะเร็วกว่ารหัสกระตือรือร้นโดยเฉพาะอย่างยิ่งสำหรับกราฟกับปฏิบัติการขนาดเล็กจำนวนมาก แต่สำหรับกราฟที่มี ops ราคาแพง (เช่น การบิด) คุณอาจไม่เห็นการเร่งความเร็วมากนัก

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735
Function conv: 0.005791576000774512
Note how there's not much difference in performance for convolutions

ติดตาม

ในส่วนนี้จะหมายความว่า Function การทำงานภายใต้ประทุนรวมทั้งรายละเอียดการดำเนินงานซึ่งอาจมีการเปลี่ยนแปลงในอนาคต แต่เมื่อคุณเข้าใจว่าทำไมและเมื่อติดตามที่เกิดขึ้นมันเป็นเรื่องง่ายมากที่จะใช้ tf.function ได้อย่างมีประสิทธิภาพ!

"การติดตาม" คืออะไร?

Function รันโปรแกรมของคุณใน TensorFlow กราฟ อย่างไรก็ตาม tf.Graph ไม่สามารถเป็นตัวแทนทุกสิ่งที่คุณต้องการเขียนในโปรแกรม TensorFlow กระตือรือร้น ยกตัวอย่างเช่นงูหลามสนับสนุน polymorphism แต่ tf.Graph ต้องใช้ปัจจัยการผลิตที่จะมีประเภทของข้อมูลที่ระบุและมิติ หรือคุณอาจทำงานข้างเคียง เช่น การอ่านอาร์กิวเมนต์บรรทัดคำสั่ง การทำให้เกิดข้อผิดพลาด หรือการทำงานกับอ็อบเจ็กต์ Python ที่ซับซ้อนยิ่งขึ้น ไม่มีสิ่งเหล่านี้สามารถทำงานใน tf.Graph

Function เชื่อมโยงช่องว่างนี้โดยการแยกรหัสของคุณในสองขั้นตอน:

1) ในขั้นแรกจะเรียกว่า "ติดตาม" Function สร้างใหม่ tf.Graph รหัสงูใหญ่วิ่งได้ตามปกติ แต่การดำเนินงานทั้งหมด TensorFlow (เช่นการเพิ่มสองเทนเซอร์) จะรอการตัดบัญชี: พวกเขาจะถูกจับโดย tf.Graph และไม่ทำงาน

2) ในขั้นตอนที่สอง tf.Graph ซึ่งมีทุกอย่างที่ได้รับการรอการตัดบัญชีในขั้นตอนแรกคือการทำงาน ขั้นตอนนี้เร็วกว่าระยะการติดตามมาก

ทั้งนี้ขึ้นอยู่กับปัจจัยการผลิตของ Function จะไม่ทำงานในขั้นตอนแรกเสมอเมื่อมันถูกเรียกว่า ดู "กฎการติดตาม" ด้านล่างเพื่อรับความรู้สึกที่ดีขึ้นของวิธีการที่จะทำให้ความมุ่งมั่นว่า การข้ามขั้นตอนแรกและดำเนินการเฉพาะขั้นตอนที่สองคือสิ่งที่ให้ประสิทธิภาพสูงของ TensorFlow

เมื่อ Function ไม่ตัดสินใจที่จะติดตามขั้นตอนการติดตามจะทันทีตามขั้นตอนที่สองเพื่อการเรียก Function ทั้งสร้างและทำงาน tf.Graph หลังจากนั้นคุณจะเห็นว่าคุณสามารถทำงานได้เพียงขั้นตอนการติดตามกับ get_concrete_function

เมื่อคุณผ่านการขัดแย้งของประเภทที่แตกต่างกันเป็น Function ทั้งสองขั้นตอนจะดำเนินการ:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

โปรดทราบว่าถ้าคุณซ้ำ ๆ เรียก Function ที่มีประเภทอาร์กิวเมนต์เดียวกัน TensorFlow จะข้ามขั้นตอนการติดตามและนำมาใช้กราฟตรวจสอบก่อนหน้านี้เป็นรูปแบบของกราฟที่สร้างขึ้นจะเป็นเหมือนกัน

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

คุณสามารถใช้ pretty_printed_concrete_signatures() เพื่อดูร่องรอยที่มีอยู่:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

เพื่อให้ห่างไกลที่คุณเห็นว่า tf.function สร้างแคชชั้นจัดส่งแบบไดนามิกมากกว่าตรรกะกราฟติดตาม TensorFlow ของ เพื่อให้เฉพาะเจาะจงมากขึ้นเกี่ยวกับคำศัพท์:

  • tf.Graph เป็นดิบภาษาไม่เชื่อเรื่องพระเจ้าเป็นตัวแทนแบบพกพาของการคำนวณ TensorFlow
  • ConcreteFunction wraps tf.Graph
  • Function การบริหารจัดการแคชของ ConcreteFunction และหยิบหนึ่งที่เหมาะสมสำหรับปัจจัยการผลิตของคุณ
  • tf.function ตัดฟังก์ชั่นหลามกลับ Function วัตถุ
  • ติดตามจะสร้าง tf.Graph และตัดไว้ใน ConcreteFunction ยังเป็นที่รู้จักร่องรอย

กฎการติดตาม

Function กำหนดว่าใช้ซ้ำสืบ ConcreteFunction โดยการคำนวณคีย์แคชจากการป้อนข้อมูลของ args และ kwargs คีย์แคชเป็นกุญแจสำคัญที่ระบุ ConcreteFunction บนพื้นฐานของการป้อนข้อมูลและ args kwargs ของ Function การโทรตามกฎต่อไปนี้ (ซึ่งอาจมีการเปลี่ยนแปลง):

  • กุญแจสำคัญในการสร้างสำหรับ tf.Tensor เป็นรูปร่างและ dtype ของมัน
  • กุญแจสำคัญในการสร้างสำหรับ tf.Variable เป็นรหัสตัวแปรที่ไม่ซ้ำกัน
  • กุญแจสำคัญในการสร้างขึ้นสำหรับหลามดั้งเดิม (เช่น int , float , str ) เป็นความคุ้มค่า
  • กุญแจสำคัญในการสร้างขึ้นสำหรับซ้อนกัน dict s, list s, tuple s, namedtuple s และ attr s เป็น tuple บี้ใบกุญแจ (ดู nest.flatten ) (เนื่องจากการแบนราบนี้ การเรียกฟังก์ชันที่เป็นรูปธรรมด้วยโครงสร้างการซ้อนที่แตกต่างจากที่ใช้ระหว่างการติดตามจะส่งผลให้เกิด TypeError)
  • สำหรับประเภท Python อื่นๆ ทั้งหมด คีย์จะไม่ซ้ำกับอ็อบเจ็กต์ วิธีนี้จะมีการติดตามฟังก์ชันหรือเมธอดอย่างอิสระสำหรับแต่ละอินสแตนซ์ที่เรียกใช้ด้วย

การควบคุมการถอยกลับ

ย้อนกลับซึ่งเมื่อคุณ Function สร้างมากกว่าหนึ่งร่องรอยจะช่วยให้มั่นใจได้ว่า TensorFlow สร้างกราฟที่ถูกต้องสำหรับการตั้งค่าของปัจจัยการผลิตแต่ละ อย่างไรก็ตาม การติดตามเป็นการดำเนินการที่มีราคาแพง! ถ้าคุณ Function retraces กราฟใหม่สำหรับทุกสาย, คุณจะพบว่ารันรหัสของคุณช้ากว่าถ้าคุณไม่ได้ใช้ tf.function

ในการควบคุมลักษณะการติดตาม คุณสามารถใช้เทคนิคต่อไปนี้:

  • ระบุ input_signature ใน tf.function ถึงขีด จำกัด การติดตาม
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
  • ระบุ [ไม่มี] มิติใน tf.TensorSpec เพื่อให้มีความยืดหยุ่นในการติดตามนำมาใช้ใหม่

    ตั้งแต่ TensorFlow ตรง tensors ขึ้นอยู่กับรูปร่างของพวกเขาโดยใช้ None มิติเป็นสัญลักษณ์แทนจะช่วยให้ Function ที่จะนำมาใช้ใหม่ร่องรอยสำหรับการป้อนข้อมูลแตกขนาด การป้อนข้อมูลแตกขนาดสามารถเกิดขึ้นได้ถ้าคุณมีลำดับของความยาวที่แตกต่างกันหรือภาพที่มีขนาดแตกต่างกันสำหรับแต่ละชุด (ดู หม้อแปลง และ ลึกฝัน บทเรียนตัวอย่าง)

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • ส่งอาร์กิวเมนต์ Python ไปยัง Tensors เพื่อลดการย้อนกลับ

    มักจะมีปากเสียงงูใหญ่จะใช้ในการควบคุมและการ hyperparameters ก่อสร้างกราฟ - ตัวอย่างเช่น num_layers=10 หรือ training=True หรือ nonlinearity='relu' ดังนั้น หากอาร์กิวเมนต์ Python เปลี่ยนไป คุณจะต้องย้อนกราฟ

    อย่างไรก็ตาม อาจไม่ได้ใช้อาร์กิวเมนต์ Python เพื่อควบคุมการสร้างกราฟ ในกรณีเหล่านี้ การเปลี่ยนแปลงในค่า Python สามารถทริกเกอร์การย้อนกลับโดยไม่จำเป็น ตัวอย่างเช่น วนรอบการฝึกนี้ ซึ่ง AutoGraph จะคลี่คลายแบบไดนามิก แม้ว่าจะมีการติดตามหลายรายการ แต่กราฟที่สร้างขึ้นนั้นเหมือนกันทุกประการ ดังนั้นการย้อนรอยจึงไม่จำเป็น

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

หากคุณจำเป็นต้องมีผลบังคับใช้ย้อนกลับสร้างใหม่ Function เฉพาะกิจ Function วัตถุที่มีการรับประกันไม่ร่องรอยหุ้น

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

การรับฟังก์ชั่นที่เป็นรูปธรรม

ทุกครั้งที่มีการติดตามฟังก์ชัน จะมีการสร้างฟังก์ชันที่เป็นรูปธรรมขึ้นใหม่ คุณโดยตรงสามารถขอรับฟังก์ชั่นที่เป็นรูปธรรมโดยใช้ get_concrete_function

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

พิมพ์ ConcreteFunction แสดงบทสรุปของข้อโต้แย้งปัจจัยการผลิต (ประเภท) และประเภทเอาท์พุท

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

คุณยังสามารถดึงข้อมูลลายเซ็นของฟังก์ชันที่เป็นรูปธรรมได้โดยตรง

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

การใช้การติดตามที่เป็นรูปธรรมกับประเภทที่เข้ากันไม่ได้จะทำให้เกิดข้อผิดพลาด

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

คุณอาจสังเกตเห็นว่าอาร์กิวเมนต์ Python ได้รับการดูแลเป็นพิเศษในลายเซ็นอินพุตของฟังก์ชันที่เป็นรูปธรรม ก่อนหน้า TensorFlow 2.3 อาร์กิวเมนต์ Python จะถูกลบออกจากลายเซ็นของฟังก์ชันที่เป็นรูปธรรม เริ่มต้นด้วย TensorFlow 2.3 อาร์กิวเมนต์ Python จะยังคงอยู่ในลายเซ็น แต่ถูกจำกัดให้รับค่าที่ตั้งไว้ระหว่างการติดตาม

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.

รับกราฟ

ฟังก์ชั่นที่เป็นรูปธรรมแต่ละเป็นเสื้อคลุม callable รอบ tf.Graph แม้ว่าการเรียกที่เกิดขึ้นจริง tf.Graph วัตถุไม่ใช่สิ่งที่คุณตามปกติจะต้องทำคุณสามารถได้รับมันได้อย่างง่ายดายจากการทำงานที่เป็นรูปธรรมใด ๆ

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

แก้จุดบกพร่อง

โดยทั่วไปรหัสแก้จุดบกพร่องเป็นเรื่องง่ายในโหมดความกระตือรือร้นกว่าภายใน tf.function คุณควรแน่ใจว่ารหัสของคุณรันปราศจากข้อผิดพลาดในโหมดความกระตือรือร้นที่ก่อนที่จะตกแต่งด้วย tf.function เพื่อช่วยในกระบวนการแก้จุดบกพร่องคุณสามารถเรียก tf.config.run_functions_eagerly(True) ที่จะปิดการใช้งานทั่วโลกและ reenable tf.function

เมื่อติดตามการลงปัญหาที่จะปรากฏใน tf.function นี่คือเคล็ดลับบางอย่าง

  • ธรรมดาเก่าหลาม print โทรเพียงดำเนินการในระหว่างการติดตามช่วยให้คุณติดตามลงเมื่อฟังก์ชั่นของคุณได้รับ (อีกครั้ง) ตรวจสอบ
  • tf.print โทรจะดำเนินการทุกครั้งและสามารถช่วยให้คุณติดตามค่ากลางระหว่างการดำเนินการ
  • tf.debugging.enable_check_numerics เป็นวิธีที่ง่ายในการติดตามที่แก่นแก้วและ Inf จะถูกสร้างขึ้น
  • pdb (คน ดีบักงูใหญ่ ) สามารถช่วยให้คุณเข้าใจสิ่งที่เกิดขึ้นในระหว่างการติดตาม (ข้อแม้: pdb จะวางคุณลงในซอร์สโค้ด Autograph-เปลี่ยน.)

การแปลงกราฟอัตโนมัติ

ลายเซ็นเป็นห้องสมุดที่ตามค่าเริ่มต้นใน tf.function และแปลงย่อยของงูหลามรหัสกระตือรือร้นเข้าไปในกราฟได้ Ops TensorFlow ซึ่งรวมถึงการควบคุมการไหลเช่น if , for , while

Ops TensorFlow เช่น tf.cond และ tf.while_loop ยังคงทำงาน แต่การควบคุมการไหลมักจะง่ายต่อการเขียนและเข้าใจเมื่อเขียนในหลาม

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753]
[0.582645297 0.613145649 0.619306684 0.319202513 0.182036072]
[0.524585426 0.546337605 0.550645113 0.308785647 0.18005164]
[0.481231302 0.497770309 0.501003504 0.299331933 0.178130865]
[0.447229207 0.460361809 0.462906033 0.290701121 0.176270396]
[0.419618756 0.430379033 0.432449728 0.282779962 0.174467146]
[0.396609187 0.405638 0.407366514 0.275476 0.172718227]
[0.377043903 0.384762734 0.386234313 0.268712848 0.17102097]
[0.360137492 0.366836458 0.368109286 0.262426734 0.169372901]
[0.345335096 0.351221472 0.352336824 0.256563932 0.167771652]
[0.332231969 0.337458342 0.338446289 0.251078814 0.166215062]
[0.320524871 0.325206399 0.326089561 0.24593246 0.164701089]
[0.309981436 0.314206958 0.31500268 0.241091311 0.163227797]
[0.300420195 0.304259449 0.304981351 0.236526251 0.161793426]
[0.291697085 0.295205742 0.295864582 0.232211992 0.160396278]
[0.283696055 0.286919087 0.287523568 0.228126258 0.159034774]
[0.276322395 0.279296666 0.27985391 0.224249557 0.157707423]
[0.269497961 0.272254 0.272769839 0.220564634 0.15641281]
[0.263157606 0.265720904 0.266200244 0.21705614 0.155149609]
[0.257246554 0.259638608 0.260085613 0.213710397 0.153916568]
[0.251718313 0.25395745 0.254375577 0.210515186 0.152712509]
[0.246533215 0.248635098 0.249027327 0.207459539 0.151536316]
[0.241657034 0.243635193 0.244004101 0.204533577 0.15038693]
[0.237060249 0.238926381 0.239274174 0.201728329 0.149263337]
[0.232717097 0.234481394 0.234810054 0.199035719 0.148164615]
[0.228605017 0.230276451 0.230587661 0.196448416 0.147089839]
[0.224704206 0.226290658 0.22658591 0.193959698 0.14603813]
[0.220997125 0.222505584 0.222786173 0.191563457 0.145008713]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077],
      dtype=float32)>

หากคุณสงสัย คุณสามารถตรวจสอบรหัสลายเซ็นที่สร้างขึ้นได้

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

เงื่อนไข

ลายเซ็นจะแปลงบางอย่าง if <condition> งบเข้าไปในเทียบเท่า tf.cond โทร เปลี่ยนตัวนี้จะทำถ้า <condition> เป็น Tensor มิฉะนั้น if คำสั่งจะถูกดำเนินการเป็นเงื่อนไขหลาม

Python Conditional ดำเนินการระหว่างการติดตาม ดังนั้นหนึ่งสาขาของเงื่อนไขจะถูกเพิ่มลงในกราฟ หากไม่มี AutoGraph กราฟที่ติดตามนี้จะไม่สามารถใช้สาขาสำรองได้หากมีโฟลว์การควบคุมที่ขึ้นกับข้อมูล

tf.cond ร่องรอยและเพิ่มทั้งสาขาของเงื่อนไขกราฟแบบไดนามิกการเลือกสาขาที่เวลาดำเนินการ การติดตามอาจมีผลข้างเคียงที่ไม่ได้ตั้งใจ ตรวจสอบ ผลกระทบที่ติดตามลายเซ็น สำหรับข้อมูลเพิ่มเติม

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

ดู เอกสารอ้างอิง สำหรับข้อ จำกัด เพิ่มเติมเกี่ยวกับลายเซ็นที่แปลงถ้างบ

ลูป

ลายเซ็นจะแปลงบางส่วน for และ while งบเข้าไปใน TensorFlow เทียบเท่าบ่วง Ops เช่น tf.while_loop ถ้าไม่ได้แปลงที่ for หรือ while วงจะถูกดำเนินการเป็นห่วงหลาม

การทดแทนนี้ทำขึ้นในสถานการณ์ต่อไปนี้:

  • for x in y : ถ้า y เป็น Tensor, แปลง tf.while_loop ในกรณีพิเศษที่ y เป็น tf.data.Dataset การรวมกันของ tf.data.Dataset Ops จะถูกสร้างขึ้น
  • while <condition> : ถ้า <condition> เป็น Tensor, แปลง tf.while_loop

หลามรันห่วงในระหว่างการติดตามการเพิ่มปฏิบัติการเพิ่มเติมเพื่อ tf.Graph สำหรับทวนของวงทุก

ลูป TensorFlow ติดตามเนื้อหาของลูป และเลือกจำนวนการวนซ้ำแบบไดนามิกที่จะรันในเวลาดำเนินการ ร่างกายห่วงจะปรากฏเฉพาะในครั้งเดียวในการสร้าง tf.Graph

ดู เอกสารอ้างอิง สำหรับข้อ จำกัด เพิ่มเติมเกี่ยวกับลายเซ็นที่แปลง for และ while งบ

วนรอบข้อมูล Python

อันตรายที่พบบ่อยคือห่วงมากกว่าข้อมูลหลาม / NumPy ภายใน tf.function วงนี้จะดำเนินการในระหว่างขั้นตอนการติดตามเพิ่มสำเนาของรูปแบบของคุณไปที่ tf.Graph สำหรับทวนของแต่ละวง

หากคุณต้องการที่จะตัดห่วงการฝึกอบรมทั้งใน tf.function วิธีที่ปลอดภัยที่สุดที่จะทำคือการห่อข้อมูลของคุณเป็น tf.data.Dataset เพื่อให้ลายเซ็นแบบไดนามิกจะเหยียดห่วงการฝึกอบรม

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph

เมื่อห่อข้อมูลหลาม / NumPy ในชุดข้อมูลที่มีสติรู้ tf.data.Dataset.from_generator เมื่อเทียบกับ tf.data.Dataset.from_tensors อดีตจะเก็บข้อมูลในหลามและเรียกมันผ่าน tf.py_function ซึ่งจะมีผลกระทบต่อผลการดำเนินงานในขณะที่หลังจะกำสำเนาของข้อมูลที่มีขนาดใหญ่หนึ่ง tf.constant() โหนดในกราฟซึ่งจะมีผลกระทบต่อหน่วยความจำ

อ่านข้อมูลจากไฟล์ผ่าน TFRecordDataset , CsvDataset ฯลฯ เป็นวิธีที่มีประสิทธิภาพมากที่สุดในการบริโภคข้อมูลที่เป็นแล้ว TensorFlow ตัวเองสามารถจัดการโหลดไม่ตรงกันและ prefetching ของข้อมูลได้โดยไม่ต้องเกี่ยวข้องกับงูหลาม ต้องการเรียนรู้เพิ่มเติมโปรดดูที่ tf.data : รูปร่าง TensorFlow ท่อป้อนข้อมูล คู่มือ

สะสมค่าเป็นวง

รูปแบบทั่วไปคือการสะสมค่ากลางจากลูป โดยปกติ ทำได้โดยการผนวกรายการ Python หรือเพิ่มรายการลงในพจนานุกรม Python อย่างไรก็ตาม เนื่องจากสิ่งเหล่านี้เป็นผลข้างเคียงของ Python พวกมันจะไม่ทำงานตามที่คาดไว้ในลูปที่คลี่คลายแบบไดนามิก ใช้ tf.TensorArray กับผลสะสมจากวงคลี่แบบไดนามิก

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216],
        [0.44997275, 1.9107027 , 1.0716251 , 0.717237  ],
        [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]],

       [[0.04946005, 0.69127274, 0.56848884, 0.22406638],
        [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ],
        [0.9178308 , 1.320889  , 0.989761  , 2.0120025 ]]], dtype=float32)>

ข้อจำกัด

TensorFlow Function มีข้อ จำกัด บางโดยการออกแบบที่คุณควรจะตระหนักถึงเมื่อมีการแปลงฟังก์ชั่นหลามกับ Function

การดำเนินการผลข้างเคียงของ Python

ผลข้างเคียงเช่นพิมพ์ผนวกกับรายการและกรรมวิธี Globals สามารถประพฤติตนไม่คาดคิดภายใน Function บางครั้งการดำเนินการสองครั้งหรือไม่ทั้งหมด พวกเขาเท่านั้นที่เกิดขึ้นครั้งแรกที่คุณเรียก Function กับชุดของปัจจัยการผลิต หลังจากนั้นสืบ tf.Graph เป็น reexecuted โดยไม่ต้องดำเนินการรหัสหลาม

กฎทั่วไปคือหลีกเลี่ยงการพึ่งพา Python ผลข้างเคียงในตรรกะของคุณ และใช้เฉพาะเพื่อดีบักการติดตามของคุณ มิฉะนั้น TensorFlow APIs เช่น tf.data , tf.print , tf.summary , tf.Variable.assign และ tf.TensorArray เป็นวิธีที่ดีที่สุดเพื่อให้รหัสของคุณจะได้รับการดำเนินการโดยรันไทม์ TensorFlow กับการโทรแต่ละครั้ง

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

หากคุณต้องการที่จะรันโค้ดหลามระหว่างอุทธรณ์ของแต่ละ Function , tf.py_function เป็นฟักออกจาก ข้อเสียเปรียบของ tf.py_function ก็คือว่ามันไม่ได้เป็นแบบพกพาหรือ performant โดยเฉพาะอย่างยิ่งไม่สามารถบันทึกด้วย SavedModel และไม่ได้ทำงานได้ดีในการกระจาย (multi-GPU, TPU) การตั้งค่า นอกจากนี้ตั้งแต่ tf.py_function จะต้องมีสายในกราฟก็ปลดเปลื้องทุก Input / Output เพื่อเทนเซอร์

การเปลี่ยนตัวแปรโกลบอลและตัวแปรฟรีของ Python

เปลี่ยนหลามทั่วโลกและ ตัวแปรอิสระ นับเป็นผลข้างเคียงหลามดังนั้นมันจะเกิดขึ้นในระหว่างการติดตาม

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

บางครั้งพฤติกรรมที่ไม่คาดคิดก็สังเกตได้ยาก ในตัวอย่างด้านล่างที่ counter มีวัตถุประสงค์เพื่อป้องกันการเพิ่มขึ้นของตัวแปร อย่างไรก็ตาม เนื่องจากเป็นจำนวนเต็มหลามและไม่ใช่อ็อบเจ็กต์ TensorFlow ค่าจึงถูกจับระหว่างการติดตามครั้งแรก เมื่อ tf.function ถูกนำมาใช้ใน assign_add จะถูกบันทึกไว้โดยไม่มีเงื่อนไขในกราฟพื้นฐาน ดังนั้น v จะเพิ่มขึ้นโดยที่ 1 ทุกครั้ง tf.function เรียกว่า ปัญหานี้เป็นเรื่องธรรมดาในหมู่ผู้ใช้ที่มีความพยายามที่จะโยกย้าย Grpah โหมดรหัสของพวกเขาที่จะ Tensorflow Tensorflow 2 ใช้ tf.function ตกแต่งเมื่อหลามผลข้างเคียง (คน counter ในตัวอย่าง) จะใช้ในการตรวจสอบสิ่งที่ Ops วิ่ง ( assign_add ในตัวอย่าง ). โดยปกติ ผู้ใช้จะทราบสิ่งนี้ก็ต่อเมื่อเห็นผลลัพธ์ที่เป็นตัวเลขที่น่าสงสัย หรือประสิทธิภาพต่ำกว่าที่คาดไว้อย่างมาก (เช่น หากการดำเนินการที่มีการป้องกันนั้นมีค่าใช้จ่ายสูง)

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

การแก้ปัญหาเพื่อให้บรรลุพฤติกรรมที่คาดหวังจะใช้ tf.init_scope จะยกการดำเนินงานที่อยู่ด้านนอกของกราฟฟังก์ชั่น เพื่อให้แน่ใจว่าการเพิ่มตัวแปรทำได้เพียงครั้งเดียวในช่วงเวลาการติดตาม มันควรจะสังเกต init_scope มีผลข้างเคียงอื่น ๆ รวมทั้งการควบคุมการไหลล้างและเทปการไล่ระดับสี บางครั้งการใช้งานของ init_scope สามารถกลายเป็นความซับซ้อนเกินไปที่จะจัดการแนบเนียน

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

ในการสรุปเป็นกฎของหัวแม่มือคุณควรหลีกเลี่ยงกรรมวิธีหลามวัตถุเช่นจำนวนเต็มหรือภาชนะเช่นรายการที่นอกสด Function ให้ใช้อาร์กิวเมนต์และวัตถุ TF แทน ยกตัวอย่างเช่นในส่วน "การสะสมค่าในวง" มีตัวอย่างหนึ่งของวิธีรายการเช่นการดำเนินงานสามารถดำเนินการได้

คุณสามารถในบางกรณีการจับภาพและจัดการรัฐถ้ามันเป็น tf.Variable นี่คือวิธีที่น้ำหนักของแบบจำลอง Keras ที่มีการปรับปรุงด้วยซ้ำโทรไปเหมือนกัน ConcreteFunction

การใช้ตัววนซ้ำและตัวสร้าง Python

คุณลักษณะของ Python จำนวนมาก เช่น ตัวสร้างและตัววนซ้ำ อาศัยรันไทม์ของ Python เพื่อติดตามสถานะ โดยทั่วไป แม้ว่าโครงสร้างเหล่านี้จะทำงานตามที่คาดไว้ในโหมดกระตือรือร้น แต่ก็เป็นตัวอย่างของผลข้างเคียงของ Python ดังนั้นจึงเกิดขึ้นระหว่างการติดตามเท่านั้น

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

เช่นเดียวกับวิธี TensorFlow มีความเชี่ยวชาญ tf.TensorArray สำหรับสร้างรายการจะได้มีความเชี่ยวชาญ tf.data.Iterator สำหรับย้ำโครงสร้าง ดูส่วนที่เกี่ยวกับ การเปลี่ยนแปลงลายเซ็น สำหรับภาพรวม นอกจากนี้ tf.data API สามารถช่วยดำเนินการรูปแบบการกำเนิด:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

เอาต์พุตทั้งหมดของ tf.function ต้องเป็นค่าส่งคืน

มีข้อยกเว้นของ tf.Variable S, A tf.function ต้องกลับเอาท์พุททั้งหมด การพยายามเข้าถึงเทนเซอร์โดยตรงจากฟังก์ชันโดยไม่ผ่านค่าที่ส่งกลับทำให้เกิด "การรั่วไหล"

ยกตัวอย่างเช่นฟังก์ชั่นด้านล่าง "การรั่วไหลที่" เมตริกซ์ ผ่านโลกหลาม a x :

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'Tensor' object has no attribute 'numpy'

สิ่งนี้เป็นจริงแม้ว่าจะส่งคืนค่าที่รั่วไหลออกมาด้วย:

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'Tensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: Originated from a graph execution error.

The graph execution error is detected at a node built at (most recent call last):
>>>  File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main
>>>  File /usr/lib/python3.7/runpy.py, line 85, in _run_code
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once
>>>  File /usr/lib/python3.7/asyncio/events.py, line 88, in _run
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code
>>>  File /tmp/ipykernel_26244/566849597.py, line 7, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler
>>>  File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__

Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

โดยปกติ การรั่วไหลเช่นนี้จะเกิดขึ้นเมื่อคุณใช้คำสั่ง Python หรือโครงสร้างข้อมูล นอกเหนือจากการรั่วของเทนเซอร์ที่ไม่สามารถเข้าถึงได้ คำสั่งดังกล่าวยังมีแนวโน้มที่จะผิดพลาดเพราะนับเป็นผลข้างเคียงของ Python และไม่รับประกันว่าจะดำเนินการทุกครั้งที่เรียกใช้ฟังก์ชัน

วิธีทั่วไปในการรั่วไหลของเทนเซอร์ภายในยังรวมถึงการกลายพันธุ์คอลเล็กชัน Python ภายนอกหรืออ็อบเจ็กต์:

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

ไม่รองรับ tf.functions แบบเรียกซ้ำ

ซ้ำ Function ที่ไม่ได้รับการสนับสนุนและอาจก่อให้เกิดลูปไม่มีที่สิ้นสุด ตัวอย่างเช่น,

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__
        return _abc_instancecheck(cls, instance)

    RecursionError: maximum recursion depth exceeded while calling a Python object

แม้มีการเวียนเกิด Function ที่ดูเหมือนว่าจะทำงานฟังก์ชั่นหลามจะได้รับการตรวจสอบหลายครั้งและอาจจะมีความหมายผลการดำเนินงาน ตัวอย่างเช่น,

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

ปัญหาที่ทราบ

ถ้าคุณ Function ไม่ได้ประเมินผลอย่างถูกต้องข้อผิดพลาดอาจจะอธิบายได้ด้วยปัญหาที่พบเหล่านี้ซึ่งมีการวางแผนที่จะได้รับการแก้ไขในอนาคต

ขึ้นอยู่กับตัวแปรโกลบอลและตัวแปรอิสระของ Python

Function สร้างใหม่ ConcreteFunction เมื่อเรียกว่ามีค่าใหม่การโต้เถียงหลาม แต่ก็ไม่ได้ทำเพื่อการปิดหลาม Globals หรือ nonlocals ที่ Function ถ้าค่าของพวกเขามีการเปลี่ยนแปลงที่อยู่ในระหว่างการเรียกไปยัง Function ที่ Function จะยังคงใช้ค่าที่พวกเขามีเมื่อมันถูกตรวจสอบ ซึ่งแตกต่างจากวิธีการทำงานของ Python ปกติ

ด้วยเหตุผลดังกล่าว คุณควรปฏิบัติตามรูปแบบการเขียนโปรแกรมเชิงฟังก์ชันที่ใช้อาร์กิวเมนต์แทนการปิดทับชื่อภายนอก

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

วิธีการปรับปรุงค่าระดับโลกอีกคือการทำให้มันเป็น tf.Variable และใช้ Variable.assign วิธีการแทน

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

ขึ้นอยู่กับวัตถุ Python

ข้อเสนอแนะที่จะผ่านวัตถุหลามเป็นข้อโต้แย้งเข้า tf.function มีจำนวนของปัญหาที่ทราบที่คาดว่าจะได้รับการแก้ไขในอนาคต โดยทั่วไปแล้วคุณสามารถพึ่งพาการติดตามที่สอดคล้องกันถ้าคุณใช้งูหลามดั้งเดิมหรือ tf.nest โครงสร้างใช้ได้กับระบบปฏิบัติการเป็นอาร์กิวเมนต์หรือผ่านในอินสแตนซ์ที่แตกต่างกันของวัตถุเป็น Function แต่ Function จะไม่สร้างร่องรอยใหม่เมื่อคุณผ่านวัตถุเดียวกันและเพียงเปลี่ยนแอตทริบิวต์ของมัน

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

ใช้เดียวกัน Function ในการประเมินเช่นการปรับปรุงรูปแบบจะเป็นรถตั้งแต่รุ่นที่ปรับปรุงมี คีย์แคชเดียวกัน เป็นรูปแบบเดิม

สำหรับเหตุผลที่คุณแนะนำในการเขียนของคุณ Function เพื่อหลีกเลี่ยงการขึ้นอยู่กับแอตทริบิวต์วัตถุไม่แน่นอนหรือสร้างวัตถุใหม่

หากที่เป็นไปไม่ได้หนึ่งในการแก้ปัญหาคือการทำให้ใหม่ Function ในแต่ละครั้งที่คุณแก้ไขวัตถุของคุณที่จะมีผลบังคับใช้ย้อนรอย:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

ในฐานะที่เป็น ย้อนอาจมีราคาแพง คุณสามารถใช้ tf.Variable ฐานะแอตทริบิวต์วัตถุซึ่งสามารถกลายพันธุ์ ( แต่ไม่ได้เปลี่ยนระวัง!) สำหรับผลที่คล้ายกันโดยไม่จำเป็นต้องหวน

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

การสร้าง tf.Variables

Function สนับสนุนเฉพาะเดี่ยว tf.Variable s สร้างขึ้นครั้งเดียวในสายแรกและนำกลับมาใช้ข้ามสายงานที่ตามมา โค้ดด้านล่างจะสร้างใหม่ tf.Variable ในการเรียกฟังก์ชั่นทุกซึ่งผลลัพธ์ใน ValueError ข้อยกเว้น

ตัวอย่าง:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

รูปแบบทั่วไปที่ใช้เพื่อหลีกเลี่ยงข้อ จำกัด นี้คือการเริ่มต้นด้วยการไม่มีค่าหลามแล้วเงื่อนไขสร้าง tf.Variable ถ้าค่าไม่มี:

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

การใช้งานกับเครื่องมือเพิ่มประสิทธิภาพ Keras หลายตัว

คุณอาจพบ ValueError: tf.function only supports singleton tf.Variables created on the first call. เมื่อมีการใช้เพิ่มประสิทธิภาพ Keras มากกว่าหนึ่งกับ tf.function ข้อผิดพลาดนี้เกิดขึ้นเนื่องจากการเพิ่มประสิทธิภาพภายในสร้าง tf.Variables เมื่อพวกเขาใช้การไล่ระดับสีเป็นครั้งแรก

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients  **
        self._create_all_weights(var_list)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights
        _ = self.iterations
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight
        aggregation=aggregation)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable
        shape=variable_shape if variable_shape else None)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

หากคุณจำเป็นต้องเปลี่ยนเพิ่มประสิทธิภาพระหว่างการฝึกอบรมวิธีแก้ปัญหาคือการสร้างใหม่ Function สำหรับแต่ละเพิ่มประสิทธิภาพเรียก ConcreteFunction โดยตรง

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

ใช้กับ Keras หลายรุ่น

นอกจากนี้คุณยังอาจพบ ValueError: tf.function only supports singleton tf.Variables created on the first call. เมื่อผ่านอินสแตนซ์รูปแบบที่แตกต่างกันไปเหมือนกัน Function

ข้อผิดพลาดนี้เกิดขึ้นเนื่องจากรุ่น Keras (ซึ่ง ไม่ได้มีรูปร่างที่ใส่ของพวกเขาที่กำหนดไว้ ) และชั้น Keras สร้าง tf.Variables เมื่อพวกเขาเรียกว่าเป็นครั้งแรก คุณอาจจะพยายามที่จะเริ่มต้นตัวแปรเหล่านั้นภายใน Function ซึ่งได้รับการเรียกแล้ว เพื่อหลีกเลี่ยงข้อผิดพลาดนี้ลองโทร model.build(input_shape) ในการเริ่มต้นน้ำหนักทั้งหมดก่อนการฝึกอบรมรุ่น

อ่านเพิ่มเติม

เรียนรู้เกี่ยวกับวิธีการส่งออกและโหลด Function ให้ดูที่ คู่มือ SavedModel ต้องการเรียนรู้เพิ่มเติมเกี่ยวกับการเพิ่มประสิทธิภาพกราฟที่จะดำเนินการหลังจากที่ติดตามดู คู่มือ Grappler เพื่อเรียนรู้วิธีการเพิ่มประสิทธิภาพท่อส่งข้อมูลของคุณและรายละเอียดรูปแบบของคุณให้ดูที่ คู่มือ Profiler