ประสิทธิภาพที่ดีขึ้นด้วย tf.function

จัดทุกอย่างให้เป็นระเบียบอยู่เสมอด้วยคอลเล็กชัน บันทึกและจัดหมวดหมู่เนื้อหาตามค่ากำหนดของคุณ

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHubดาวน์โหลดโน๊ตบุ๊ค

ใน TensorFlow 2 การดำเนินการอย่างกระตือรือร้น จะเปิดไว้โดยค่าเริ่มต้น อินเทอร์เฟซผู้ใช้ใช้งานง่ายและยืดหยุ่น (การดำเนินการครั้งเดียวทำได้ง่ายกว่าและเร็วกว่ามาก) แต่สิ่งนี้อาจทำให้เสียประสิทธิภาพและความสามารถในการปรับใช้

คุณสามารถใช้ tf.function เพื่อสร้างกราฟจากโปรแกรมของคุณ เป็นเครื่องมือการแปลงที่สร้างกราฟการไหลของข้อมูลที่ไม่ขึ้นกับ Python จากโค้ด Python ของคุณ สิ่งนี้จะช่วยคุณสร้างโมเดลที่มีประสิทธิภาพและพกพาได้ และจำเป็นต้องใช้ SavedModel

คู่มือนี้จะช่วยให้คุณกำหนดแนวคิดได้ว่า tf.function ทำงานอย่างไรภายใต้ประทุน เพื่อให้คุณสามารถใช้งานได้อย่างมีประสิทธิภาพ

ประเด็นหลักและคำแนะนำคือ:

  • ดีบักในโหมดกระตือรือร้น จากนั้นตกแต่งด้วย @tf.function
  • อย่าพึ่งพาผลข้างเคียงของ Python เช่น การกลายพันธุ์ของวัตถุหรือการผนวกรายการ
  • tf.function ทำงานได้ดีที่สุดกับ TensorFlow ops; การเรียก NumPy และ Python จะถูกแปลงเป็นค่าคงที่

ติดตั้ง

import tensorflow as tf

กำหนดฟังก์ชันตัวช่วยเพื่อสาธิตประเภทของข้อผิดพลาดที่คุณอาจพบ:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

พื้นฐาน

การใช้งาน

Function ที่คุณกำหนด (เช่น โดยใช้ @tf.function decorator) ก็เหมือนกับการทำงานหลักของ TensorFlow: คุณสามารถดำเนินการอย่างกระตือรือร้น คุณสามารถคำนวณการไล่ระดับสี และอื่นๆ

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

คุณสามารถใช้ Function s ภายใน Function อื่นๆ ได้

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function สามารถเร็วกว่าโค้ดที่ต้องการได้ โดยเฉพาะอย่างยิ่งสำหรับกราฟที่มี ops เล็กๆ จำนวนมาก แต่สำหรับกราฟที่มีค่า ops ราคาแพง (เช่น การบิด) คุณอาจไม่เห็นการเร่งความเร็วมากนัก

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735
Function conv: 0.005791576000774512
Note how there's not much difference in performance for convolutions

ติดตาม

ส่วนนี้จะอธิบายวิธีการทำงานของ Function ภายใต้ประทุน รวมถึงรายละเอียดการใช้งาน ที่อาจเปลี่ยนแปลงได้ในอนาคต อย่างไรก็ตาม เมื่อคุณเข้าใจสาเหตุและเมื่อการติดตามเกิดขึ้นแล้ว การใช้ tf.function อย่างมีประสิทธิภาพจะง่ายขึ้นมาก!

"การติดตาม" คืออะไร?

Function เรียกใช้โปรแกรมของคุณใน กราฟ TensorFlow อย่างไรก็ตาม tf.Graph ไม่สามารถแสดงทุกสิ่งที่คุณเขียนในโปรแกรม TensorFlow ที่กระตือรือร้น ตัวอย่างเช่น Python รองรับความหลากหลาย แต่ tf.Graph ต้องการให้อินพุตมีประเภทข้อมูลและขนาดที่ระบุ หรือคุณอาจทำงานข้างเคียง เช่น การอ่านอาร์กิวเมนต์บรรทัดคำสั่ง การทำให้เกิดข้อผิดพลาด หรือการทำงานกับอ็อบเจ็กต์ Python ที่ซับซ้อนยิ่งขึ้น สิ่งเหล่านี้ไม่สามารถทำงานใน tf.Graph ได้

Function เชื่อมช่องว่างนี้โดยแยกโค้ดของคุณออกเป็นสองขั้นตอน:

1) ในระยะแรกเรียกว่า " การติดตาม " Function จะสร้าง tf.Graph ใหม่ โค้ด Python ทำงานตามปกติ แต่การดำเนินการ TensorFlow ทั้งหมด (เช่น การเพิ่ม Tensor สองตัว) จะถูก เลื่อนออกไป : tf.Graph จับได้และจะไม่ทำงาน

2) ในระยะที่สอง tf.Graph ซึ่งมีทุกอย่างที่ถูกเลื่อนออกไปในระยะแรกจะถูกรัน ขั้นตอนนี้เร็วกว่าระยะการติดตามมาก

ขึ้นอยู่กับอินพุตของ Function จะไม่เรียกใช้สเตจแรกเสมอเมื่อมีการเรียกใช้ ดู "กฎของการติดตาม" ด้านล่างเพื่อให้เข้าใจถึงวิธีการตัดสินใจได้ดีขึ้น การข้ามขั้นตอนแรกและดำเนินการเฉพาะขั้นตอนที่สองคือสิ่งที่ให้ประสิทธิภาพสูงของ TensorFlow

เมื่อ Function ตัดสินใจติดตาม ขั้นตอนการติดตามจะถูกตามด้วยขั้นตอนที่สองทันที ดังนั้นการเรียก Function จะสร้างและรัน tf.Graph หลังจากนั้น คุณจะเห็นว่าคุณสามารถรันเฉพาะระยะการติดตามด้วย get_concrete_function ได้อย่างไร

เมื่อคุณส่งผ่านอาร์กิวเมนต์ประเภทต่างๆ ไปยัง Function ทั้งสองขั้นตอนจะถูกรัน:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

โปรดทราบว่าหากคุณเรียกใช้ Function ที่มีประเภทอาร์กิวเมนต์เดียวกันซ้ำๆ TensorFlow จะข้ามขั้นตอนการติดตามและนำกราฟที่ติดตามก่อนหน้านี้มาใช้ใหม่ เนื่องจากกราฟที่สร้างขึ้นจะเหมือนกัน

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

คุณสามารถใช้ pretty_printed_concrete_signatures() เพื่อดูร่องรอยที่มีอยู่ทั้งหมดได้:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

จนถึงตอนนี้ คุณได้เห็นแล้วว่า tf.function สร้างเลเยอร์การจัดส่งแบบไดนามิกที่แคชไว้บนลอจิกการติดตามกราฟของ TensorFlow เพื่อให้เฉพาะเจาะจงมากขึ้นเกี่ยวกับคำศัพท์:

  • tf.Graph คือการแสดงข้อมูลดิบที่ไม่เชื่อเรื่องภาษา และสามารถเคลื่อนย้ายได้ของการคำนวณ TensorFlow
  • ConcreteFunction ล้อม tf.Graph
  • Function จัดการแคชของ ConcreteFunction และเลือกอันที่เหมาะสมสำหรับอินพุตของคุณ
  • tf.function ล้อมฟังก์ชัน Python ส่งคืนอ็อบเจ็กต์ Function
  • การติดตาม สร้าง tf.Graph และรวมไว้ใน ConcreteFunction หรือที่เรียกว่าการ ติดตาม

กฎการติดตาม

Function กำหนดว่าจะนำ ConcreteFunction ที่ติดตามกลับมาใช้ใหม่หรือไม่โดยคำนวณ คีย์แคช จาก args และ kwargs ของอินพุต คีย์แคช คือคีย์ที่ระบุ ConcreteFunction โดยอิงจากอินพุต args และ kwargs ของการเรียก Function ตามกฎต่อไปนี้ (ซึ่งอาจเปลี่ยนแปลงได้):

  • คีย์ที่สร้างขึ้นสำหรับ tf.Tensor คือรูปร่างและ dtype
  • คีย์ที่สร้างขึ้นสำหรับ tf.Variable คือ id ตัวแปรที่ไม่ซ้ำกัน
  • คีย์ที่สร้างขึ้นสำหรับ Python ดั้งเดิม (เช่น int , float , str ) คือค่าของมัน
  • คีย์ที่สร้างขึ้นสำหรับ dict ที่ซ้อนกัน, list s, tuple s, namedtuple s และ attr s คือ tuple แบบแบนของ leaf-keys (ดู nest.flatten ) (เนื่องจากการแบนราบนี้ การเรียกฟังก์ชันที่เป็นรูปธรรมด้วยโครงสร้างการซ้อนที่แตกต่างจากที่ใช้ระหว่างการติดตามจะส่งผลให้เกิด TypeError)
  • สำหรับประเภท Python อื่นๆ ทั้งหมด คีย์จะไม่ซ้ำกับอ็อบเจ็กต์ วิธีนี้จะมีการติดตามฟังก์ชันหรือเมธอดอย่างอิสระสำหรับแต่ละอินสแตนซ์ที่เรียกใช้ด้วย

การควบคุมการถอยกลับ

การย้อนรอย ซึ่งเกิดขึ้นเมื่อ Function ของคุณสร้างการติดตามมากกว่าหนึ่งรายการ ช่วยให้แน่ใจว่า TensorFlow สร้างกราฟที่ถูกต้องสำหรับอินพุตแต่ละชุด อย่างไรก็ตาม การติดตามเป็นการดำเนินการที่มีราคาแพง! หาก Function ของคุณย้อนกราฟใหม่ทุกครั้งที่มีการเรียก คุณจะพบว่าโค้ดของคุณทำงานช้ากว่าถ้าคุณไม่ได้ใช้ tf.function

ในการควบคุมลักษณะการติดตาม คุณสามารถใช้เทคนิคต่อไปนี้:

  • ระบุ input_signature ใน tf.function เพื่อจำกัดการติดตาม
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
  • ระบุมิติข้อมูล [ไม่มี] ใน tf.TensorSpec เพื่อให้มีความยืดหยุ่นในการใช้การติดตามซ้ำ

    เนื่องจาก TensorFlow จับคู่เทนเซอร์ตามรูปร่าง การใช้มิติ None เป็นสัญลักษณ์แทนจะช่วยให้ Function นำการติดตามกลับมาใช้ใหม่สำหรับอินพุตที่มีขนาดแปรผันได้ อินพุตขนาดแปรผันอาจเกิดขึ้นได้หากคุณมีลำดับความยาวต่างกัน หรือมีรูปภาพที่มีขนาดต่างกันสำหรับแต่ละชุดงาน (ดูตัวอย่างบทช่วยสอน Transformer และ Deep Dream )

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • ส่งอาร์กิวเมนต์ Python ไปยัง Tensors เพื่อลดการย้อนกลับ

    บ่อยครั้ง อาร์กิวเมนต์ Python ถูกใช้เพื่อควบคุมไฮเปอร์พารามิเตอร์และการสร้างกราฟ - ตัวอย่างเช่น num_layers=10 หรือ training=True หรือ nonlinearity='relu' ดังนั้น หากอาร์กิวเมนต์ Python เปลี่ยนไป คุณจะต้องย้อนกราฟ

    อย่างไรก็ตาม อาจไม่ได้ใช้อาร์กิวเมนต์ Python เพื่อควบคุมการสร้างกราฟ ในกรณีเหล่านี้ การเปลี่ยนแปลงในค่า Python สามารถทริกเกอร์การย้อนกลับโดยไม่จำเป็น ตัวอย่างเช่น วนรอบการฝึกนี้ ซึ่ง AutoGraph จะคลี่คลายแบบไดนามิก แม้ว่าจะมีการติดตามหลายรายการ แต่กราฟที่สร้างขึ้นนั้นจริง ๆ แล้วเหมือนกัน ดังนั้นการย้อนรอยจึงไม่จำเป็น

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20
ตัวยึดตำแหน่ง22

หากคุณต้องการบังคับการถอยกลับ ให้สร้าง Function ใหม่ ออบเจ็กต์ Function ที่แยกจากกันรับประกันว่าจะไม่แชร์การติดตาม

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

การรับฟังก์ชั่นที่เป็นรูปธรรม

ทุกครั้งที่มีการติดตามฟังก์ชัน จะมีการสร้างฟังก์ชันที่เป็นรูปธรรมขึ้นใหม่ คุณสามารถรับฟังก์ชันที่เป็นรูปธรรมได้โดยตรงโดยใช้ get_concrete_function

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

การพิมพ์ ConcreteFunction จะแสดงสรุปอาร์กิวเมนต์อินพุต (พร้อมประเภท) และประเภทเอาต์พุต

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

คุณยังสามารถดึงข้อมูลลายเซ็นของฟังก์ชันที่เป็นรูปธรรมได้โดยตรง

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)
ตัวยึดตำแหน่ง32

การใช้การติดตามที่เป็นรูปธรรมกับประเภทที่เข้ากันไม่ได้จะทำให้เกิดข้อผิดพลาด

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

คุณอาจสังเกตเห็นว่าอาร์กิวเมนต์ Python ได้รับการดูแลเป็นพิเศษในลายเซ็นอินพุตของฟังก์ชันที่เป็นรูปธรรม ก่อนหน้า TensorFlow 2.3 อาร์กิวเมนต์ Python จะถูกลบออกจากลายเซ็นของฟังก์ชันที่เป็นรูปธรรม เริ่มต้นด้วย TensorFlow 2.3 อาร์กิวเมนต์ Python จะยังคงอยู่ในลายเซ็น แต่ถูกจำกัดให้ใช้ค่าที่ตั้งไว้ระหว่างการติดตาม

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.

รับกราฟ

แต่ละฟังก์ชันที่เป็นรูปธรรมคือเครื่องห่อหุ้มที่เรียกได้รอบๆ tf.Graph แม้ว่าการเรียกวัตถุ tf.Graph จริงจะไม่ใช่สิ่งที่คุณต้องทำตามปกติ แต่คุณสามารถรับวัตถุนั้นได้อย่างง่ายดายจากฟังก์ชันที่เป็นรูปธรรม

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

แก้จุดบกพร่อง

โดยทั่วไปแล้ว การดีบักโค้ดจะง่ายกว่าในโหมดกระตือรือร้นมากกว่าภายใน tf.function คุณควรตรวจสอบให้แน่ใจว่าโค้ดของคุณทำงานโดยไม่มีข้อผิดพลาดในโหมดกระตือรือร้นก่อนที่จะตกแต่งด้วย tf.function เพื่อช่วยในกระบวนการดีบัก คุณสามารถเรียก tf.config.run_functions_eagerly(True) เพื่อปิดการใช้งานและเปิดใช้งาน tf.function อีกครั้งทั่วโลก

เมื่อติดตามปัญหาที่ปรากฏเฉพาะใน tf.function ต่อไปนี้คือเคล็ดลับบางประการ:

  • การเรียกใช้ print Python แบบเก่าจะดำเนินการเฉพาะระหว่างการติดตาม ช่วยให้คุณติดตามเมื่อฟังก์ชันของคุณได้รับการติดตาม (อีกครั้ง)
  • การเรียก tf.print จะดำเนินการทุกครั้ง และสามารถช่วยคุณติดตามค่ากลางระหว่างการดำเนินการได้
  • tf.debugging.enable_check_numerics เป็นวิธีที่ง่ายในการติดตามตำแหน่งที่สร้าง NaN และ Inf
  • pdb ( ดีบักเกอร์ Python ) สามารถช่วยให้คุณเข้าใจสิ่งที่เกิดขึ้นระหว่างการติดตาม (คำเตือน: pdb จะนำคุณเข้าสู่ซอร์สโค้ดที่แปลง AutoGraph)

การแปลงกราฟอัตโนมัติ

AutoGraph เป็นไลบรารีที่เปิดใช้งานโดยค่าเริ่มต้นใน tf.function และแปลงชุดย่อยของโค้ด Python ที่กระตือรือร้นเป็น TensorFlow ops ที่เข้ากันได้กับกราฟ ซึ่งรวมถึงโฟลว์การควบคุม เช่น if for while

TensorFlow ops เช่น tf.cond และ tf.while_loop ยังคงทำงานต่อไป แต่โฟลว์การควบคุมมักจะเขียนและเข้าใจได้ง่ายกว่าเมื่อเขียนด้วย Python

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753]
[0.582645297 0.613145649 0.619306684 0.319202513 0.182036072]
[0.524585426 0.546337605 0.550645113 0.308785647 0.18005164]
[0.481231302 0.497770309 0.501003504 0.299331933 0.178130865]
[0.447229207 0.460361809 0.462906033 0.290701121 0.176270396]
[0.419618756 0.430379033 0.432449728 0.282779962 0.174467146]
[0.396609187 0.405638 0.407366514 0.275476 0.172718227]
[0.377043903 0.384762734 0.386234313 0.268712848 0.17102097]
[0.360137492 0.366836458 0.368109286 0.262426734 0.169372901]
[0.345335096 0.351221472 0.352336824 0.256563932 0.167771652]
[0.332231969 0.337458342 0.338446289 0.251078814 0.166215062]
[0.320524871 0.325206399 0.326089561 0.24593246 0.164701089]
[0.309981436 0.314206958 0.31500268 0.241091311 0.163227797]
[0.300420195 0.304259449 0.304981351 0.236526251 0.161793426]
[0.291697085 0.295205742 0.295864582 0.232211992 0.160396278]
[0.283696055 0.286919087 0.287523568 0.228126258 0.159034774]
[0.276322395 0.279296666 0.27985391 0.224249557 0.157707423]
[0.269497961 0.272254 0.272769839 0.220564634 0.15641281]
[0.263157606 0.265720904 0.266200244 0.21705614 0.155149609]
[0.257246554 0.259638608 0.260085613 0.213710397 0.153916568]
[0.251718313 0.25395745 0.254375577 0.210515186 0.152712509]
[0.246533215 0.248635098 0.249027327 0.207459539 0.151536316]
[0.241657034 0.243635193 0.244004101 0.204533577 0.15038693]
[0.237060249 0.238926381 0.239274174 0.201728329 0.149263337]
[0.232717097 0.234481394 0.234810054 0.199035719 0.148164615]
[0.228605017 0.230276451 0.230587661 0.196448416 0.147089839]
[0.224704206 0.226290658 0.22658591 0.193959698 0.14603813]
[0.220997125 0.222505584 0.222786173 0.191563457 0.145008713]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077],
      dtype=float32)>
ตัวยึดตำแหน่ง42

หากคุณสงสัย คุณสามารถตรวจสอบรหัสที่สร้างลายเซ็นได้

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

เงื่อนไข

AutoGraph จะแปลงบางคำสั่ง if <condition> เป็นการเรียก tf.cond ที่เทียบเท่า การแทนที่นี้ถูกสร้างขึ้นหาก <condition> เป็นเทนเซอร์ มิฉะนั้น คำสั่ง if จะถูกดำเนินการตามเงื่อนไขของ Python

Python Conditional ดำเนินการระหว่างการติดตาม ดังนั้นหนึ่งสาขาของเงื่อนไขจะถูกเพิ่มลงในกราฟ หากไม่มี AutoGraph กราฟที่ติดตามนี้จะไม่สามารถใช้สาขาสำรองได้หากมีโฟลว์การควบคุมที่ขึ้นกับข้อมูล

tf.cond ติดตามและเพิ่มกิ่งของเงื่อนไขทั้งสองลงในกราฟ โดยเลือกสาขาแบบไดนามิกในขณะดำเนินการ การติดตามอาจมีผลข้างเคียงที่ไม่ได้ตั้งใจ ตรวจสอบ เอฟเฟกต์การติดตาม AutoGraph สำหรับข้อมูลเพิ่มเติม

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz
ตัวยึดตำแหน่ง46

ดู เอกสารอ้างอิง สำหรับข้อจำกัดเพิ่มเติมเกี่ยวกับคำสั่ง if ที่แปลง AutoGraph

ลูป

AutoGraph จะแปลงคำสั่ง for และ while บางส่วนให้เป็น ops วนซ้ำของ TensorFlow ที่เทียบเท่า เช่น tf.while_loop หากไม่ได้แปลง ลูป for หรือ while จะถูกดำเนินการเป็นลูป Python

การทดแทนนี้ทำขึ้นในสถานการณ์ต่อไปนี้:

  • for x in y : ถ้า y เป็นเทนเซอร์ ให้แปลงเป็น tf.while_loop ในกรณีพิเศษที่ y เป็น tf.data.Dataset การรวมกันของ tf.data.Dataset ops จะถูกสร้างขึ้น
  • while <condition> : ถ้า <condition> เป็นเทนเซอร์ ให้แปลงเป็น tf.while_loop

ลูป Python ทำงานระหว่างการติดตาม เพิ่ม ops เพิ่มเติมให้กับ tf.Graph สำหรับการวนซ้ำทุกครั้ง

ลูป TensorFlow ติดตามเนื้อหาของลูป และเลือกจำนวนการวนซ้ำแบบไดนามิกที่จะรันในเวลาดำเนินการ เนื้อหาลูปปรากฏขึ้นเพียงครั้งเดียวใน tf.Graph ที่สร้างขึ้น

ดู เอกสารอ้างอิง สำหรับข้อจำกัดเพิ่มเติมเกี่ยวกับคำสั่ง AutoGraph ที่แปลง for และ while

วนรอบข้อมูล Python

ข้อผิดพลาดทั่วไปคือการวนซ้ำข้อมูล Python/NumPy ภายใน tf.function การวนซ้ำนี้จะดำเนินการในระหว่างกระบวนการติดตาม โดยเพิ่มสำเนาของแบบจำลองของคุณไปที่ tf.Graph สำหรับการวนซ้ำแต่ละครั้งของลูป

หากคุณต้องการรวมวงจรการฝึกทั้งหมดใน tf.function วิธีที่ปลอดภัยที่สุดในการทำเช่นนี้คือการห่อข้อมูลของคุณเป็น tf.data.Dataset เพื่อให้ AutoGraph จะคลายการวนรอบการฝึกแบบไดนามิก

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph

เมื่อตัดข้อมูล Python/NumPy ในชุดข้อมูล ให้คำนึงถึง tf.data.Dataset.from_generator กับ tf.data.Dataset.from_tensors อดีตจะเก็บข้อมูลใน Python และดึงข้อมูลผ่าน tf.py_function ซึ่งอาจมีผลกระทบต่อประสิทธิภาพในขณะที่หลังจะรวมสำเนาของข้อมูลเป็นโหนด tf.constant() ขนาดใหญ่หนึ่งโหนดในกราฟ ซึ่งอาจมีผลกระทบต่อหน่วยความจำ

การอ่านข้อมูลจากไฟล์ผ่าน TFRecordDataset , CsvDataset ฯลฯ เป็นวิธีที่มีประสิทธิภาพสูงสุดในการใช้ข้อมูล เนื่องจาก TensorFlow เองสามารถจัดการการโหลดและการดึงข้อมูลล่วงหน้าแบบอะซิงโครนัส โดยไม่ต้องเกี่ยวข้องกับ Python หากต้องการเรียนรู้เพิ่มเติม โปรดดูที่ tf.data : สร้างคำแนะนำไปป์ไลน์อินพุต TensorFlow

สะสมค่าเป็นวง

รูปแบบทั่วไปคือการสะสมค่ากลางจากลูป โดยปกติ ทำได้โดยผนวกรายการ Python หรือเพิ่มรายการลงในพจนานุกรม Python อย่างไรก็ตาม เนื่องจากสิ่งเหล่านี้เป็นผลข้างเคียงของ Python พวกมันจะไม่ทำงานตามที่คาดไว้ในลูปที่คลี่ออกแบบไดนามิก ใช้ tf.TensorArray เพื่อรวบรวมผลลัพธ์จากการวนซ้ำที่คลี่คลายแบบไดนามิก

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216],
        [0.44997275, 1.9107027 , 1.0716251 , 0.717237  ],
        [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]],

       [[0.04946005, 0.69127274, 0.56848884, 0.22406638],
        [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ],
        [0.9178308 , 1.320889  , 0.989761  , 2.0120025 ]]], dtype=float32)>

ข้อจำกัด

Function TensorFlow มีข้อจำกัดเล็กน้อยจากการออกแบบที่คุณควรทราบเมื่อแปลงฟังก์ชัน Python เป็น Function

การดำเนินการผลข้างเคียงของ Python

ผลข้างเคียง เช่น การพิมพ์ การต่อท้ายรายการ และการกลายพันธุ์ globals สามารถทำงานโดยไม่คาดคิดภายใน Function ซึ่งบางครั้งดำเนินการสองครั้งหรือไม่ทั้งหมด จะเกิดขึ้นในครั้งแรกที่คุณเรียกใช้ Function ด้วยชุดอินพุตเท่านั้น หลังจากนั้น tf.Graph ที่ติดตามจะถูกดำเนินการอีกครั้ง โดยไม่ต้องรันโค้ด Python

หลักการทั่วไปคือการหลีกเลี่ยงการพึ่งพา Python ผลข้างเคียงในตรรกะของคุณ และใช้เฉพาะเพื่อดีบักการติดตามของคุณ มิฉะนั้น TensorFlow API เช่น tf.data , tf.print , tf.summary , tf.Variable.assign และ tf.TensorArray เป็นวิธีที่ดีที่สุดเพื่อให้แน่ใจว่าโค้ดของคุณจะถูกเรียกใช้โดยรันไทม์ TensorFlow ในแต่ละครั้ง

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2
ตัวยึดตำแหน่ง52

หากคุณต้องการรันโค้ด Python ในระหว่างการเรียกใช้ Function แต่ละครั้ง tf.py_function จะเป็นทางออก ข้อเสียของ tf.py_function คือไม่ใช่แบบพกพาหรือมีประสิทธิภาพเป็นพิเศษ ไม่สามารถบันทึกด้วย SavedModel ได้ และทำงานได้ไม่ดีในการตั้งค่าแบบกระจาย (multi-GPU, TPU) นอกจากนี้ เนื่องจากต้องเชื่อมต่อ tf.py_function เข้ากับกราฟ จึงส่งอินพุต/เอาต์พุตทั้งหมดไปยังเทนเซอร์

การเปลี่ยนตัวแปรโกลบอลและตัวแปรฟรีของ Python

การเปลี่ยนตัวแปรส่วนกลางและ ตัวแปรอิสระ ของ Python ถือเป็นผลข้างเคียงของ Python ดังนั้นจะเกิดขึ้นระหว่างการติดตามเท่านั้น

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

บางครั้งพฤติกรรมที่ไม่คาดคิดก็สังเกตได้ยาก ในตัวอย่างด้านล่าง ตัว counter มีไว้เพื่อป้องกันการเพิ่มขึ้นของตัวแปร อย่างไรก็ตาม เนื่องจากเป็นจำนวนเต็มหลามและไม่ใช่อ็อบเจ็กต์ TensorFlow ค่าจึงถูกจับระหว่างการติดตามครั้งแรก เมื่อใช้ tf.function assign_add จะถูกบันทึกอย่างไม่มีเงื่อนไขในกราฟพื้นฐาน ดังนั้น v จะเพิ่มขึ้น 1 ทุกครั้งที่มีการเรียก tf.function ปัญหานี้พบได้บ่อยในหมู่ผู้ใช้ที่พยายามย้ายโค้ด Tensorflow ในโหมด Grpah ไปยัง Tensorflow 2 โดยใช้ตัวตกแต่ง tf.function เมื่อใช้ผลข้างเคียงของ python (ตัว counter ในตัวอย่าง) เพื่อกำหนดว่า ops ใดที่จะเรียกใช้ ( assign_add ในตัวอย่าง ). โดยปกติ ผู้ใช้จะทราบสิ่งนี้ก็ต่อเมื่อเห็นผลตัวเลขที่น่าสงสัยเท่านั้น หรือประสิทธิภาพต่ำกว่าที่คาดไว้อย่างมาก (เช่น หากการดำเนินการที่มีการป้องกันนั้นมีค่าใช้จ่ายสูง)

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

วิธีแก้ปัญหาเพื่อให้ได้พฤติกรรมที่คาดไว้คือการใช้ tf.init_scope เพื่อยกการดำเนินการนอกกราฟฟังก์ชัน เพื่อให้แน่ใจว่าการเพิ่มตัวแปรทำได้เพียงครั้งเดียวในช่วงเวลาการติดตาม ควรสังเกตว่า init_scope มีผลข้างเคียงอื่น ๆ รวมถึงการเคลียร์โฟลว์ควบคุมและเทปเกรเดียนต์ บางครั้งการใช้ init_scope อาจซับซ้อนเกินกว่าจะจัดการได้อย่างสมจริง

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

โดยสรุป ตามหลักการทั่วไป คุณควรหลีกเลี่ยงการกลายพันธุ์อ็อบเจ็กต์ python เช่น จำนวนเต็มหรือคอนเทนเนอร์ เช่น รายการที่อาศัยอยู่นอก Function ให้ใช้อาร์กิวเมนต์และวัตถุ TF แทน ตัวอย่างเช่น ส่วน "การสะสมค่าในลูป" มีตัวอย่างหนึ่งตัวอย่างเกี่ยวกับวิธีการดำเนินการที่เหมือนรายการ

ในบางกรณี คุณสามารถจับภาพและจัดการสถานะได้หากเป็น tf.Variable นี่คือวิธีอัปเดตน้ำหนักของโมเดล Keras ด้วยการเรียกซ้ำไปยัง ConcreteFunction เดียวกัน

การใช้ตัววนซ้ำและตัวสร้าง Python

คุณลักษณะของ Python จำนวนมาก เช่น ตัวสร้างและตัววนซ้ำ อาศัยรันไทม์ของ Python เพื่อติดตามสถานะ โดยทั่วไป แม้ว่าโครงสร้างเหล่านี้จะทำงานตามที่คาดไว้ในโหมดกระตือรือร้น แต่ก็เป็นตัวอย่างของผลข้างเคียงของ Python ดังนั้นจึงเกิดขึ้นในระหว่างการติดตามเท่านั้น

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1
ตัวยึดตำแหน่ง60

เช่นเดียวกับที่ TensorFlow มี tf.TensorArray เฉพาะสำหรับการสร้างรายการ แต่ก็มี tf.data.Iterator เฉพาะสำหรับการสร้างการวนซ้ำ ดูส่วน การแปลง AutoGraph สำหรับภาพรวม นอกจากนี้ tf.data API ยังสามารถช่วยในการนำรูปแบบตัวสร้างไปใช้:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

เอาต์พุตทั้งหมดของ tf.function ต้องเป็นค่าส่งคืน

ยกเว้น tf.Variable s tf.function ต้องส่งคืนเอาต์พุตทั้งหมด การพยายามเข้าถึงเทนเซอร์จากฟังก์ชันโดยตรงโดยไม่ผ่านค่าที่ส่งกลับทำให้เกิด "การรั่วไหล"

ตัวอย่างเช่น ฟังก์ชั่นด้านล่าง "รั่ว" เทนเซอร์ a ผ่าน Python global x :

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'Tensor' object has no attribute 'numpy'

สิ่งนี้เป็นจริงแม้ว่าจะส่งคืนค่าที่รั่วไหลออกมาด้วย:

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'Tensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: Originated from a graph execution error.

The graph execution error is detected at a node built at (most recent call last):
>>>  File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main
>>>  File /usr/lib/python3.7/runpy.py, line 85, in _run_code
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once
>>>  File /usr/lib/python3.7/asyncio/events.py, line 88, in _run
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code
>>>  File /tmp/ipykernel_26244/566849597.py, line 7, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler
>>>  File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__

Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

โดยปกติ การรั่วไหลเช่นนี้จะเกิดขึ้นเมื่อคุณใช้คำสั่ง Python หรือโครงสร้างข้อมูล นอกเหนือจากการรั่วของเทนเซอร์ที่ไม่สามารถเข้าถึงได้ คำสั่งดังกล่าวยังมีแนวโน้มว่าผิดเพราะนับเป็นผลข้างเคียงของ Python และไม่รับประกันว่าจะดำเนินการทุกครั้งที่เรียกใช้ฟังก์ชัน

วิธีทั่วไปในการรั่วไหลของเทนเซอร์ภายในยังรวมถึงการกลายพันธุ์คอลเล็กชัน Python ภายนอกหรืออ็อบเจ็กต์:

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

ไม่รองรับ tf.functions แบบเรียกซ้ำ

ไม่รองรับ Function แบบเรียกซ้ำและอาจทำให้เกิดการวนซ้ำไม่สิ้นสุด ตัวอย่างเช่น,

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__
        return _abc_instancecheck(cls, instance)

    RecursionError: maximum recursion depth exceeded while calling a Python object

แม้ว่า Function แบบเรียกซ้ำจะทำงานได้ แต่ฟังก์ชัน python จะถูกติดตามหลายครั้งและอาจมีผลกระทบต่อประสิทธิภาพ ตัวอย่างเช่น,

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

ปัญหาที่ทราบ

หาก Function ของคุณประเมินไม่ถูกต้อง ข้อผิดพลาดอาจอธิบายได้จากปัญหาที่ทราบเหล่านี้ซึ่งวางแผนจะแก้ไขในอนาคต

ขึ้นอยู่กับตัวแปรโกลบอลและตัวแปรอิสระของ Python

Function สร้าง ConcreteFunction ใหม่เมื่อถูกเรียกด้วยค่าใหม่ของอาร์กิวเมนต์ Python อย่างไรก็ตาม มันไม่ได้ทำอย่างนั้นสำหรับการปิด Python, globals หรือ nonlocals ของ Function นั้น หากค่าของการเปลี่ยนแปลงระหว่างการเรียกใช้ Function Function จะยังใช้ค่าที่พวกเขามีเมื่อถูกติดตาม ซึ่งแตกต่างจากวิธีการทำงานของ Python ปกติ

ด้วยเหตุผลดังกล่าว คุณควรปฏิบัติตามรูปแบบการเขียนโปรแกรมเชิงฟังก์ชันที่ใช้อาร์กิวเมนต์แทนการปิดทับชื่อภายนอก

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

อีกวิธีหนึ่งในการอัปเดตค่าส่วนกลางคือทำให้เป็น tf.Variable และใช้วิธี Variable.assign แทน

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

ขึ้นอยู่กับวัตถุ Python

คำแนะนำในการส่งอ็อบเจ็กต์ Python เป็นอาร์กิวเมนต์ไปยัง tf.function มีปัญหาที่ทราบจำนวนหนึ่ง ซึ่งคาดว่าจะได้รับการแก้ไขในอนาคต โดยทั่วไป คุณสามารถพึ่งพาการติดตามที่สอดคล้องกันได้หากคุณใช้ Python primitive หรือโครงสร้างที่เข้ากันได้กับ tf.nest เป็นอาร์กิวเมนต์ หรือส่งผ่านในอินสแตนซ์ อื่น ของอ็อบเจ็กต์ไปยัง Function อย่างไรก็ตาม Function จะ ไม่ สร้างการติดตามใหม่เมื่อคุณส่งผ่าน วัตถุเดียวกันและเปลี่ยนเฉพาะแอตทริบิวต์ของวัตถุเท่านั้น

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

การใช้ Function เดียวกันเพื่อประเมินอินสแตนซ์ที่อัปเดตของโมเดลจะมีปัญหาเนื่องจากรุ่นที่อัปเดตมี คีย์แคชเดียวกัน กับรุ่นดั้งเดิม

ด้วยเหตุผลดังกล่าว ขอแนะนำให้คุณเขียน Function เพื่อหลีกเลี่ยงการขึ้นอยู่กับแอตทริบิวต์ของวัตถุที่เปลี่ยนแปลงได้หรือสร้างวัตถุใหม่

หากไม่สามารถทำได้ วิธีแก้ไขชั่วคราววิธีหนึ่งคือสร้าง Function ใหม่ทุกครั้งที่คุณแก้ไขวัตถุของคุณเพื่อบังคับให้ย้อนเวลา:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

เนื่องจาก การย้อนรอยอาจมีราคาแพง คุณสามารถใช้ tf.Variable s เป็นแอตทริบิวต์ของอ็อบเจ็กต์ ซึ่งสามารถกลายพันธุ์ได้ (แต่ไม่สามารถเปลี่ยนแปลงได้ โปรดระวัง!) เพื่อให้ได้ผลที่คล้ายกันโดยไม่ต้องมีการย้อน

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

การสร้าง tf.Variables

Function รองรับเฉพาะ singleton tf.Variable ที่สร้างขึ้นเพียงครั้งเดียวในการเรียกครั้งแรก และนำมาใช้ใหม่ในการเรียกฟังก์ชันที่ตามมา ข้อมูลโค้ดด้านล่างจะสร้าง tf.Variable ใหม่ในทุกๆ การเรียกใช้ฟังก์ชัน ซึ่งส่งผลให้เกิดข้อยกเว้น ValueError

ตัวอย่าง:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
ตัวยึดตำแหน่ง93

รูปแบบทั่วไปที่ใช้ในการแก้ไขข้อจำกัดนี้คือการเริ่มต้นด้วยค่า Python None จากนั้นจึงสร้าง tf.Variable มีเงื่อนไขหากค่าเป็น None:

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

การใช้งานกับเครื่องมือเพิ่มประสิทธิภาพ Keras หลายตัว

คุณอาจพบ ValueError: tf.function only supports singleton tf.Variables created on the first call. เมื่อใช้เครื่องมือเพิ่มประสิทธิภาพ Keras มากกว่าหนึ่งตัวกับ tf.function ข้อผิดพลาดนี้เกิดขึ้นเนื่องจากเครื่องมือเพิ่มประสิทธิภาพภายในสร้าง tf.Variables เมื่อใช้การไล่ระดับสีเป็นครั้งแรก

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients  **
        self._create_all_weights(var_list)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights
        _ = self.iterations
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight
        aggregation=aggregation)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable
        shape=variable_shape if variable_shape else None)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

หากคุณต้องการเปลี่ยนตัวเพิ่มประสิทธิภาพระหว่างการฝึก วิธีแก้ปัญหาคือสร้าง Function ใหม่สำหรับตัวเพิ่มประสิทธิภาพแต่ละตัว โดยเรียก ConcreteFunction โดยตรง

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

ใช้กับ Keras หลายรุ่น

คุณอาจพบ ValueError: tf.function only supports singleton tf.Variables created on the first call. เมื่อส่งผ่านอินสแตนซ์รุ่นต่าง ๆ ไปยัง Function เดียวกัน

ข้อผิดพลาดนี้เกิดขึ้นเนื่องจากโมเดล Keras (ซึ่ง ไม่ได้กำหนดรูปร่างอินพุต ) และเลเยอร์ Keras สร้าง tf.Variables เมื่อถูกเรียกครั้งแรก คุณอาจกำลังพยายามเริ่มต้นตัวแปรเหล่านั้นภายใน Function ซึ่งถูกเรียกไปแล้ว เพื่อหลีกเลี่ยงข้อผิดพลาดนี้ ให้ลองเรียก model.build(input_shape) เพื่อเริ่มต้นน้ำหนักทั้งหมดก่อนที่จะฝึกโมเดล

อ่านเพิ่มเติม

หากต้องการเรียนรู้เกี่ยวกับวิธีการส่งออกและโหลด Function โปรดดู คู่มือ SavedModel หากต้องการเรียนรู้เพิ่มเติมเกี่ยวกับการเพิ่มประสิทธิภาพกราฟที่ดำเนินการหลังการติดตาม โปรดดู คู่มือ Grappler หากต้องการเรียนรู้วิธีเพิ่มประสิทธิภาพไปป์ไลน์ข้อมูลและโปรไฟล์โมเดลของคุณ โปรดดู คู่มือตัวสร้างโปรไฟล์