عملکرد بهتر با عملکرد tf

مشاهده در TensorFlow.org در Google Colab اجرا شود مشاهده منبع در GitHubدانلود دفترچه یادداشت

در TensorFlow 2، اجرای مشتاق به طور پیش فرض روشن است. رابط کاربری بصری و منعطف است (اجرای عملیات یکباره بسیار آسان‌تر و سریع‌تر است)، اما این می‌تواند به هزینه عملکرد و قابلیت استقرار تمام شود.

می توانید از tf.function برای ایجاد نمودار از برنامه های خود استفاده کنید. این یک ابزار تبدیل است که نمودارهای جریان داده مستقل از پایتون را از کد پایتون شما ایجاد می کند. این به شما کمک می‌کند تا مدل‌های قابل حمل و عملکردی ایجاد کنید، و لازم است از SavedModel استفاده کنید.

این راهنما به شما کمک می کند تا نحوه عملکرد tf.function در زیر کاپوت را تصور کنید، بنابراین می توانید از آن به طور موثر استفاده کنید.

نکات و توصیه های اصلی عبارتند از:

  • در حالت مشتاق اشکال زدایی کنید، سپس با @tf.function تزئین کنید.
  • به عوارض جانبی پایتون مانند جهش شی یا ضمیمه لیست اعتماد نکنید.
  • tf.function بهترین عملکرد را با تنظیمات TensorFlow دارد. فراخوانی های NumPy و Python به ثابت تبدیل می شوند.

برپایی

import tensorflow as tf

برای نشان دادن انواع خطاهایی که ممکن است با آن مواجه شوید، یک تابع کمکی تعریف کنید:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

مبانی

استفاده

Function که تعریف می‌کنید (مثلاً با اعمال دکوراتور @tf.function ) درست مانند یک عملیات هسته‌ای TensorFlow است: می‌توانید آن را با اشتیاق اجرا کنید. شما می توانید گرادیان ها را محاسبه کنید. و غیره

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

می توانید از Function s در داخل Function های دیگر استفاده کنید.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function s می تواند سریعتر از کد مشتاق باشد، به خصوص برای نمودارهایی با تعداد زیادی عملیات کوچک. اما برای نمودارهایی با چند عملیات گران قیمت (مانند کانولوشن)، ممکن است سرعت زیادی مشاهده نکنید.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735
Function conv: 0.005791576000774512
Note how there's not much difference in performance for convolutions

ردیابی

این بخش نحوه Function در زیر هود را نشان می دهد، از جمله جزئیات پیاده سازی که ممکن است در آینده تغییر کند . با این حال، هنگامی که متوجه شدید که چرا و چه زمانی ردیابی اتفاق می افتد، استفاده موثر از tf.function بسیار آسان تر است!

"ردیابی" چیست؟

یک Function برنامه شما را در یک نمودار TensorFlow اجرا می کند. با این حال، یک tf.Graph نمی تواند همه چیزهایی را که در یک برنامه مشتاق TensorFlow بنویسید، نشان دهد. به عنوان مثال، پایتون از چندشکلی پشتیبانی می کند، اما tf.Graph نیاز دارد که ورودی های آن دارای یک نوع داده و بعد مشخص باشد. یا ممکن است کارهای جانبی مانند خواندن آرگومان های خط فرمان، ایجاد خطا، یا کار با یک شیء پیچیده تر پایتون را انجام دهید. هیچ یک از این چیزها نمی توانند در یک tf.Graph اجرا شوند.

Function این شکاف را با جداسازی کد شما در دو مرحله پر می کند:

1) در مرحله اول که به آن " ردیابی " گفته می شود، Function یک tf.Graph جدید ایجاد می کند. کد پایتون به طور معمول اجرا می شود، اما تمام عملیات TensorFlow (مانند افزودن دو تنسور) به تعویق افتاده است: آنها توسط tf.Graph ضبط می شوند و اجرا نمی شوند.

2) در مرحله دوم یک tf.Graph که شامل هر آنچه در مرحله اول به تعویق افتاده بود اجرا می شود. این مرحله بسیار سریعتر از مرحله ردیابی است.

بسته به ورودی هایش، Function همیشه اولین مرحله را هنگام فراخوانی اجرا نمی کند. "قوانین ردیابی" را در زیر ببینید تا درک بهتری از نحوه تعیین این امر داشته باشید. رد شدن از مرحله اول و تنها اجرای مرحله دوم چیزی است که عملکرد بالای TensorFlow را به شما می دهد.

وقتی Function تصمیم به ردیابی می‌گیرد، مرحله ردیابی بلافاصله با مرحله دوم دنبال می‌شود، بنابراین فراخوانی Function هم tf.Graph را ایجاد و اجرا می‌کند. بعداً خواهید دید که چگونه می توانید فقط مرحله ردیابی را با get_concrete_function کنید.

وقتی آرگومان های انواع مختلف را به یک Function ارسال می کنید، هر دو مرحله اجرا می شوند:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

توجه داشته باشید که اگر به طور مکرر Function را با همان نوع آرگومان فراخوانی کنید، TensorFlow مرحله ردیابی را رد می کند و از گراف ردیابی شده قبلی مجددا استفاده می کند، زیرا نمودار تولید شده یکسان است.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

می‌توانید از pretty_printed_concrete_signatures() برای دیدن همه ردیابی‌های موجود استفاده کنید:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

تاکنون مشاهده کرده‌اید که tf.function یک لایه انتقال پویا و ذخیره‌سازی شده روی منطق ردیابی نمودار TensorFlow ایجاد می‌کند. برای توضیح بیشتر در مورد اصطلاحات:

  • یک tf.Graph یک نمایش خام، زبان شناس و قابل حمل از یک محاسبه TensorFlow است.
  • یک ConcreteFunction یک tf.Graph را می پیچد.
  • یک Function یک حافظه پنهان از ConcreteFunction ها را مدیریت می کند و گزینه مناسب را برای ورودی های شما انتخاب می کند.
  • tf.function یک تابع پایتون را می پیچد و یک شی Function را برمی گرداند.
  • Tracing یک tf.Graph ایجاد می‌کند و آن را در یک ConcreteFunction می‌پیچد که به عنوان ردیابی نیز شناخته می‌شود.

قوانین ردیابی

یک Function تعیین می کند که آیا از یک ConcreteFunction ردیابی شده مجدداً استفاده شود یا خیر. یک کلید حافظه پنهان کلیدی است که یک ConcreteFunction را بر اساس آرگ های ورودی و کوارگ های فراخوانی Function ، طبق قوانین زیر (که ممکن است تغییر کند) شناسایی می کند:

  • کلید تولید شده برای tf.Tensor شکل و نوع d آن است.
  • کلید تولید شده برای tf.Variable یک شناسه متغیر منحصر به فرد است.
  • کلید تولید شده برای یک پایتون اولیه (مانند int ، float ، str ) مقدار آن است.
  • کلید تولید شده برای namedtuple تودرتو، listtuple s، s dict و attr s مجموعه مسطح کلیدهای برگ است (نگاه کنید به nest.flatten ). (در نتیجه این مسطح کردن، فراخوانی یک تابع بتن با ساختار تودرتو متفاوت از آنچه در طول ردیابی استفاده می شود منجر به TypeError می شود).
  • برای همه انواع دیگر پایتون، کلید منحصر به شی است. به این ترتیب یک تابع یا متد برای هر نمونه ای که با آن فراخوانی می شود به طور مستقل ردیابی می شود.

کنترل ردیابی مجدد

Retracing، زمانی است که Function شما بیش از یک ردیابی ایجاد می‌کند، به شما کمک می‌کند تا اطمینان حاصل شود که TensorFlow نمودارهای درستی برای هر مجموعه ورودی ایجاد می‌کند. با این حال، ردیابی یک عملیات گران است! اگر Function شما نمودار جدیدی را برای هر تماس دوباره دنبال کند، متوجه خواهید شد که کد شما کندتر از زمانی که از tf.function استفاده نکرده باشید، اجرا می شود.

برای کنترل رفتار ردیابی، می توانید از تکنیک های زیر استفاده کنید:

  • برای محدود کردن ردیابی، tf.function را در input_signature مشخص کنید.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
  • یک بعد [None] را در tf.TensorSpec کنید تا در استفاده مجدد از ردیابی انعطاف پذیر باشد.

    از آنجایی که TensorFlow تانسورها را بر اساس شکل آنها منطبق می‌کند، استفاده از یک بعد None به عنوان علامت عام به Function s اجازه می‌دهد تا از ردیابی‌ها برای ورودی با اندازه متغیر استفاده کند. اگر دنباله هایی با طول های مختلف یا تصاویری با اندازه های مختلف برای هر دسته داشته باشید، ورودی با اندازه متغیر می تواند رخ دهد (برای مثال به آموزش های ترانسفورماتور و Deep Dream مراجعه کنید).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • برای کاهش ردیابی مجدد، آرگومان های پایتون را به تنسورها ارسال کنید.

    اغلب، آرگومان‌های پایتون برای کنترل هایپرپارامترها و ساختارهای گراف استفاده می‌شوند - برای مثال num_layers=10 یا training=True یا nonlinearity='relu' . بنابراین، اگر آرگومان پایتون تغییر کند، منطقی است که باید نمودار را دوباره دنبال کنید.

    با این حال، ممکن است از آرگومان پایتون برای کنترل ساخت گراف استفاده نشود. در این موارد، تغییر در مقدار پایتون می‌تواند باعث ردیابی مجدد بی‌ضروری شود. به عنوان مثال، این حلقه آموزشی را در نظر بگیرید، که AutoGraph به صورت پویا باز می شود. با وجود ردیابی های متعدد، نمودار تولید شده در واقع یکسان است، بنابراین ردیابی مجدد غیر ضروری است.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

اگر نیاز به ردیابی مجدد دارید، یک Function جدید ایجاد کنید. اشیاء Function جداگانه تضمین می شود که ردیابی را به اشتراک نگذارند.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

به دست آوردن توابع بتن

هر بار که یک تابع ردیابی می شود، یک تابع بتن جدید ایجاد می شود. با استفاده از get_concrete_function می توانید مستقیماً یک تابع بتن به دست آورید.

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

چاپ یک ConcreteFunction خلاصه ای از آرگومان های ورودی (با انواع) و نوع خروجی آن را نمایش می دهد.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

شما همچنین می توانید به طور مستقیم امضای یک تابع بتن را بازیابی کنید.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

استفاده از ردیابی بتن با انواع ناسازگار باعث ایجاد خطا می شود

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

ممکن است متوجه شوید که آرگومان‌های پایتون در امضای ورودی یک تابع مشخص رفتار خاصی دارند. قبل از TensorFlow 2.3، آرگومان های پایتون به سادگی از امضای تابع بتن حذف شدند. با شروع TensorFlow 2.3، آرگومان‌های پایتون در امضا باقی می‌مانند، اما برای گرفتن مقدار تنظیم شده در طول ردیابی محدود می‌شوند.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.

به دست آوردن نمودارها

هر تابع بتن یک لفاف قابل فراخوانی در اطراف یک tf.Graph است. اگرچه بازیابی شی tf.Graph واقعی چیزی نیست که معمولاً باید انجام دهید، می توانید آن را به راحتی از هر تابع مشخصی بدست آورید.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

اشکال زدایی

به طور کلی، اشکال زدایی کد در حالت مشتاق آسان تر از داخل tf.function است. قبل از تزئین با tf.function باید مطمئن شوید که کد شما بدون خطا در حالت مشتاق اجرا می شود. برای کمک به فرآیند اشکال‌زدایی، می‌توانید با tf.config.run_functions_eagerly(True) تماس بگیرید تا tf.function را به صورت سراسری غیرفعال و دوباره فعال کنید.

هنگام ردیابی مشکلاتی که فقط در tf.function ظاهر می شوند، در اینجا چند نکته وجود دارد:

  • تماس‌های print ساده پایتون فقط در حین ردیابی اجرا می‌شوند و به شما کمک می‌کنند تا زمانی که عملکرد شما (دوباره) ردیابی می‌شود، ردیابی کنید.
  • تماس‌های tf.print هر بار اجرا می‌شوند و می‌توانند به شما در ردیابی مقادیر میانی در طول اجرا کمک کنند.
  • tf.debugging.enable_check_numerics یک راه آسان برای ردیابی محل ایجاد NaNs و Inf است.
  • pdb ( اشکال‌زدای پایتون ) می‌تواند به شما در درک آنچه در طول ردیابی می‌گذرد کمک کند. (اخطار: pdf شما را وارد کد منبع تبدیل شده توسط pdb می کند.)

تبدیلات خودکار

AutoGraph یک کتابخانه است که به طور پیش‌فرض در tf.function است و زیرمجموعه‌ای از کد مشتاق پایتون را به عملیات TensorFlow سازگار با گراف تبدیل می‌کند. این شامل جریان کنترل می شود مانند if , for , while .

عملیات های TensorFlow مانند tf.cond و tf.while_loop همچنان به کار خود ادامه می دهند، اما نوشتن و درک جریان کنترل اغلب زمانی که در پایتون نوشته می شود آسان تر است.

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753]
[0.582645297 0.613145649 0.619306684 0.319202513 0.182036072]
[0.524585426 0.546337605 0.550645113 0.308785647 0.18005164]
[0.481231302 0.497770309 0.501003504 0.299331933 0.178130865]
[0.447229207 0.460361809 0.462906033 0.290701121 0.176270396]
[0.419618756 0.430379033 0.432449728 0.282779962 0.174467146]
[0.396609187 0.405638 0.407366514 0.275476 0.172718227]
[0.377043903 0.384762734 0.386234313 0.268712848 0.17102097]
[0.360137492 0.366836458 0.368109286 0.262426734 0.169372901]
[0.345335096 0.351221472 0.352336824 0.256563932 0.167771652]
[0.332231969 0.337458342 0.338446289 0.251078814 0.166215062]
[0.320524871 0.325206399 0.326089561 0.24593246 0.164701089]
[0.309981436 0.314206958 0.31500268 0.241091311 0.163227797]
[0.300420195 0.304259449 0.304981351 0.236526251 0.161793426]
[0.291697085 0.295205742 0.295864582 0.232211992 0.160396278]
[0.283696055 0.286919087 0.287523568 0.228126258 0.159034774]
[0.276322395 0.279296666 0.27985391 0.224249557 0.157707423]
[0.269497961 0.272254 0.272769839 0.220564634 0.15641281]
[0.263157606 0.265720904 0.266200244 0.21705614 0.155149609]
[0.257246554 0.259638608 0.260085613 0.213710397 0.153916568]
[0.251718313 0.25395745 0.254375577 0.210515186 0.152712509]
[0.246533215 0.248635098 0.249027327 0.207459539 0.151536316]
[0.241657034 0.243635193 0.244004101 0.204533577 0.15038693]
[0.237060249 0.238926381 0.239274174 0.201728329 0.149263337]
[0.232717097 0.234481394 0.234810054 0.199035719 0.148164615]
[0.228605017 0.230276451 0.230587661 0.196448416 0.147089839]
[0.224704206 0.226290658 0.22658591 0.193959698 0.14603813]
[0.220997125 0.222505584 0.222786173 0.191563457 0.145008713]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077],
      dtype=float32)>

اگر کنجکاو هستید، می توانید خودکار کد تولید شده را بررسی کنید.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

شرایط

AutoGraph برخی از دستورات if <condition> را به فراخوانی های tf.cond معادل تبدیل می کند. این جایگزینی در صورتی انجام می شود که <condition> یک تانسور باشد. در غیر این if به صورت شرطی پایتون اجرا می شود.

یک شرطی پایتون در حین ردیابی اجرا می شود، بنابراین دقیقاً یک شاخه از شرطی به گراف اضافه می شود. بدون AutoGraph، اگر جریان کنترل وابسته به داده وجود داشته باشد، این نمودار ردیابی شده نمی تواند شاخه جایگزین را بگیرد.

tf.cond هر دو شاخه شرطی را ردیابی کرده و به گراف اضافه می کند و به صورت پویا یک شاخه را در زمان اجرا انتخاب می کند. ردیابی می تواند عوارض جانبی ناخواسته ای داشته باشد. برای اطلاعات بیشتر جلوه‌های ردیابی AutoGraph را بررسی کنید.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

برای محدودیت‌های اضافی در مورد AutoGraph-Converted if به مستندات مرجع مراجعه کنید.

حلقه ها

AutoGraph برخی از دستورات for و while را به عملیات حلقه‌ای معادل TensorFlow تبدیل می‌کند، مانند tf.while_loop . اگر تبدیل نشود، حلقه for یا while به صورت حلقه پایتون اجرا می شود.

این جایگزینی در شرایط زیر انجام می شود:

  • for x in y : اگر y یک تانسور است، به tf.while_loop تبدیل کنید. در حالت خاصی که y یک tf.data.Dataset است، ترکیبی از tf.data.Dataset ops تولید می شود.
  • while <condition> : اگر <condition> یک تانسور است، به tf.while_loop تبدیل کنید.

یک حلقه پایتون در حین ردیابی اجرا می شود و برای هر تکرار حلقه، عملیات اضافی به tf.Graph اضافه می کند.

یک حلقه TensorFlow بدنه حلقه را ردیابی می کند و به صورت پویا تعداد دفعات تکرار در زمان اجرا را انتخاب می کند. بدنه حلقه فقط یک بار در tf.Graph ایجاد شده ظاهر می شود.

برای محدودیت‌های اضافی در مورد AutoGraph-Converted for and while به مستندات مرجع مراجعه کنید.

حلقه زدن روی داده های پایتون

یک دام رایج این است که روی داده های Python/NumPy در یک tf.function . این حلقه در طول فرآیند ردیابی اجرا می شود و برای هر تکرار حلقه یک کپی از مدل شما به tf.Graph می شود.

اگر می‌خواهید کل حلقه آموزشی را در tf.function ، مطمئن‌ترین راه برای انجام این کار این است که داده‌های خود را به صورت tf.data.Dataset تا AutoGraph به صورت پویا حلقه آموزشی را باز کند.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph

هنگام قرار دادن داده‌های Python/NumPy در یک مجموعه داده، به tf.data.Dataset.from_generator در مقابل tf.data.Dataset.from_tensors کنید. اولی داده‌ها را در پایتون نگه می‌دارد و آن‌ها را از طریق tf.py_function که می‌تواند پیامدهای عملکردی داشته باشد، واکشی می‌کند، در حالی که دومی یک کپی از داده‌ها را به عنوان یک گره tf.constant() بزرگ در نمودار جمع می‌کند که می‌تواند پیامدهای حافظه داشته باشد.

خواندن داده‌ها از فایل‌ها از طریق TFRecordDataset ، CsvDataset و غیره مؤثرترین راه برای مصرف داده‌ها است، زیرا خود TensorFlow می‌تواند بارگیری ناهمزمان و واکشی اولیه داده‌ها را بدون نیاز به پایتون مدیریت کند. برای کسب اطلاعات بیشتر، به راهنمای خطوط لوله ورودی tf.data : Build TensorFlow مراجعه کنید.

انباشته کردن مقادیر در یک حلقه

یک الگوی رایج جمع آوری مقادیر میانی از یک حلقه است. به طور معمول، این کار با افزودن به فهرست پایتون یا افزودن مدخل هایی به فرهنگ لغت پایتون انجام می شود. با این حال، از آنجایی که اینها عوارض جانبی پایتون هستند، آن‌طور که انتظار می‌رود در یک حلقه بازشده پویا کار نخواهند کرد. از tf.TensorArray برای جمع آوری نتایج از یک حلقه به طور پویا استفاده کنید.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216],
        [0.44997275, 1.9107027 , 1.0716251 , 0.717237  ],
        [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]],

       [[0.04946005, 0.69127274, 0.56848884, 0.22406638],
        [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ],
        [0.9178308 , 1.320889  , 0.989761  , 2.0120025 ]]], dtype=float32)>

محدودیت ها

Function TensorFlow از نظر طراحی دارای چند محدودیت است که هنگام تبدیل تابع پایتون به Function باید از آنها آگاه باشید.

اجرای عوارض جانبی پایتون

عوارض جانبی، مانند چاپ، الحاق به فهرست‌ها، و جهش جهانی‌ها، می‌توانند به‌طور غیرمنتظره‌ای در داخل یک Function رفتار کنند، گاهی اوقات دو بار یا نه همه اجرا می‌شوند. آنها فقط اولین باری که یک Function را با مجموعه ای از ورودی ها فراخوانی می کنید اتفاق می افتد. سپس، tf.Graph ردیابی شده مجدداً بدون اجرای کد پایتون اجرا می شود.

قاعده کلی این است که از تکیه بر عوارض جانبی پایتون در منطق خود اجتناب کنید و فقط از آنها برای رفع اشکال ردیابی خود استفاده کنید. در غیر این صورت، API های TensorFlow مانند tf.data ، tf.print ، tf.summary ، tf.Variable.assign ، و tf.TensorArray بهترین راه برای اطمینان از اجرای کد شما توسط زمان اجرای TensorFlow با هر تماس هستند.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

اگر می خواهید کد پایتون را در طول هر فراخوانی یک Function اجرا کنید، tf.py_function یک دریچه خروجی است. اشکال tf.py_function این است که قابل حمل یا عملکرد خاصی نیست، نمی توان آن را با SavedModel ذخیره کرد و در تنظیمات توزیع شده (چند GPU، TPU) به خوبی کار نمی کند. همچنین، از آنجایی که tf.py_function باید به گراف متصل شود، تمام ورودی ها/خروجی ها را به تانسورها ارسال می کند.

تغییر متغیرهای جهانی و رایگان پایتون

تغییر متغیرهای جهانی و رایگان پایتون به عنوان یک عارضه جانبی پایتون به حساب می آید، بنابراین فقط در حین ردیابی اتفاق می افتد.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

گاهی اوقات تشخیص رفتارهای غیرمنتظره بسیار سخت است. در مثال زیر، counter برای محافظت از افزایش یک متغیر در نظر گرفته شده است. اما چون یک عدد صحیح پایتون است و نه یک شی TensorFlow، مقدار آن در اولین ردیابی ثبت می شود. زمانی که از tf.function استفاده می شود، assign_add بدون قید و شرط در نمودار زیرین ثبت می شود. بنابراین، هر بار که tf.function . فراخوانی شود، v به میزان 1 افزایش می یابد. این مشکل در میان کاربرانی که سعی می‌کنند کد Tensorflow حالت Grpah خود را با استفاده از دکوراتورهای tf.function به Tensorflow 2 منتقل کنند رایج است، زمانی که از عوارض جانبی پایتون (مشتری در مثال) برای تعیین اینکه چه assign_add counter مثال). ). معمولاً، کاربران تنها پس از مشاهده نتایج عددی مشکوک، یا عملکرد بسیار کمتر از حد انتظار (مثلاً اگر عملیات محافظت شده بسیار پرهزینه باشد) متوجه این موضوع می‌شوند.

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

یک راه حل برای دستیابی به رفتار مورد انتظار استفاده از tf.init_scope برای برداشتن عملیات خارج از نمودار تابع است. این تضمین می کند که افزایش متغیر فقط یک بار در طول زمان ردیابی انجام می شود. لازم به ذکر است init_scope دارای عوارض جانبی دیگری از جمله جریان کنترل پاک شده و نوار گرادیان است. گاهی اوقات استفاده از init_scope ممکن است برای مدیریت واقع بینانه آنقدر پیچیده شود.

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

به طور خلاصه، به عنوان یک قانون سرانگشتی، باید از جهش اشیاء پایتون مانند اعداد صحیح یا محفظه هایی مانند لیست هایی که خارج از Function زندگی می کنند، اجتناب کنید. در عوض، از آرگومان ها و اشیاء TF استفاده کنید. به عنوان مثال، بخش "انباشت مقادیر در یک حلقه" یک مثال از نحوه اجرای عملیات لیست مانند دارد.

اگر حالت tf.Variable باشد، می‌توانید در برخی موارد، حالت را بگیرید و دستکاری کنید. اینگونه است که وزن مدل های Keras با تماس های مکرر به همان ConcreteFunction به روز می شود.

استفاده از تکرار کننده ها و مولدهای پایتون

بسیاری از ویژگی‌های پایتون، مانند ژنراتورها و تکرارکننده‌ها، برای پیگیری وضعیت به زمان اجرا پایتون متکی هستند. به طور کلی، در حالی که این ساختارها در حالت مشتاق همانطور که انتظار می رود کار می کنند، نمونه هایی از عوارض جانبی پایتون هستند و بنابراین فقط در حین ردیابی اتفاق می افتند.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

درست مانند اینکه چگونه TensorFlow یک tf.TensorArray تخصصی برای ساختارهای لیست دارد، یک tf.data.Iterator تخصصی برای ساختارهای تکرار دارد. برای یک نمای کلی به بخش تبدیلات خودکار گراف مراجعه کنید. همچنین، tf.data API می تواند به پیاده سازی الگوهای مولد کمک کند:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

تمام خروجی های یک tf.function باید مقادیر بازگشتی باشند

به استثنای tf.Variable s، یک تابع tf. باید تمام خروجی های خود را برگرداند. تلاش برای دسترسی مستقیم به هر تانسور از یک تابع بدون عبور از مقادیر بازگشتی باعث "نشت" می شود.

به عنوان مثال، تابع زیر تانسور a را از طریق x جهانی پایتون "نشت" می کند:

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'Tensor' object has no attribute 'numpy'

این درست است حتی اگر مقدار لو رفته نیز برگردانده شود:

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'Tensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: Originated from a graph execution error.

The graph execution error is detected at a node built at (most recent call last):
>>>  File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main
>>>  File /usr/lib/python3.7/runpy.py, line 85, in _run_code
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once
>>>  File /usr/lib/python3.7/asyncio/events.py, line 88, in _run
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code
>>>  File /tmp/ipykernel_26244/566849597.py, line 7, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler
>>>  File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__

Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

معمولاً وقتی از دستورات پایتون یا ساختارهای داده استفاده می کنید، نشت هایی مانند این اتفاق می افتد. علاوه بر افشای تانسورهای غیرقابل دسترس، چنین عباراتی نیز احتمالاً اشتباه هستند زیرا به عنوان عوارض جانبی پایتون به حساب می آیند و اجرای آنها در هر فراخوانی تابع تضمین نمی شود.

راه های رایج برای نشت تانسورهای محلی نیز شامل جهش یک مجموعه خارجی پایتون یا یک شی است:

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

توابع tf. بازگشتی پشتیبانی نمی شوند

Function بازگشتی پشتیبانی نمی شوند و می توانند حلقه های بی نهایت ایجاد کنند. مثلا،

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__
        return _abc_instancecheck(cls, instance)

    RecursionError: maximum recursion depth exceeded while calling a Python object

حتی اگر به نظر می رسد که یک Function بازگشتی کار می کند، تابع پایتون چندین بار ردیابی می شود و می تواند پیامدهای عملکردی داشته باشد. مثلا،

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

مشکلات شناخته شده

اگر Function شما به درستی ارزیابی نمی‌شود، ممکن است خطا با این مسائل شناخته شده توضیح داده شود که در آینده برنامه‌ریزی شده‌اند برطرف شوند.

بسته به متغیرهای جهانی و رایگان پایتون

Function زمانی که با مقدار جدیدی از آرگومان پایتون فراخوانی می شود، یک ConcreteFunction جدید ایجاد می کند. با این حال، این کار را برای بسته شدن پایتون، جهانی‌ها یا غیرمحلی‌های آن Function انجام نمی‌دهد. اگر مقدار آنها در بین فراخوانی های Function تغییر کند، Function همچنان از مقادیری که هنگام ردیابی داشت استفاده می کند. این با نحوه عملکرد توابع معمولی پایتون متفاوت است.

به همین دلیل، شما باید از یک سبک برنامه نویسی تابعی پیروی کنید که به جای بستن روی نام های بیرونی، از آرگومان ها استفاده می کند.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

راه دیگر برای به روز رسانی یک مقدار سراسری، تبدیل آن به tf.Variable و استفاده از متد Variable.assign به جای آن است.

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

بسته به اشیاء پایتون

توصیه برای ارسال اشیاء پایتون به عنوان آرگومان در tf.function دارای تعدادی مشکلات شناخته شده است که انتظار می رود در آینده برطرف شوند. به طور کلی، اگر از ساختار اولیه پایتون یا سازگار با tf.nest به عنوان آرگومان استفاده کنید یا نمونه‌ای متفاوت از یک شی را به یک Function ارسال کنید، می‌توانید به ردیابی ثابت تکیه کنید. با این حال، هنگام عبور از یک شی ، Function ردیابی جدیدی ایجاد نمی کند و فقط ویژگی های آن را تغییر می دهد.

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

استفاده از همان Function برای ارزیابی نمونه به روز شده مدل باگ خواهد بود زیرا مدل به روز شده دارای کلید حافظه پنهان مشابه مدل اصلی است.

به همین دلیل، به شما توصیه می شود که Function خود را بنویسید تا از وابستگی به ویژگی های شیء قابل تغییر یا ایجاد اشیاء جدید اجتناب کنید.

اگر این امکان پذیر نیست، یک راه حل این است که هر بار که شیء خود را برای ردیابی مجدد اجباری تغییر می دهید، یک Function جدید ایجاد کنید:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

از آنجایی که ردیابی مجدد می تواند گران باشد ، می توانید از tf.Variable s به عنوان ویژگی های شی استفاده کنید، که می تواند جهش یابد (اما تغییر نمی کند، مراقب باشید!) برای یک اثر مشابه بدون نیاز به ردیابی مجدد.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

ایجاد tf.Variables

Function فقط از singleton tf.Variable پشتیبانی می کند که یک بار در اولین تماس ایجاد شده و در فراخوانی های تابع بعدی مجددا استفاده می شود. قطعه کد زیر یک tf.Variable جدید در هر فراخوانی تابع ایجاد می کند که منجر به یک استثنا ValueError می شود.

مثال:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

یک الگوی رایج برای حل این محدودیت این است که با یک مقدار Python None شروع کنید، سپس اگر مقدار None باشد، tf.Variable را به صورت مشروط ایجاد کنید:

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

استفاده با چندین بهینه ساز Keras

ممکن است با ValueError: tf.function only supports singleton tf.Variables created on the first call. هنگام استفاده از بیش از یک بهینه ساز Keras با یک tf.function . این خطا به این دلیل رخ می دهد که بهینه سازها وقتی برای اولین بار گرادیان ها را اعمال می کنند، tf.Variables را به صورت داخلی ایجاد می کنند.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients  **
        self._create_all_weights(var_list)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights
        _ = self.iterations
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight
        aggregation=aggregation)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable
        shape=variable_shape if variable_shape else None)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

اگر در طول آموزش نیاز به تغییر بهینه ساز دارید، یک راه حل این است که برای هر بهینه ساز یک Function جدید ایجاد کنید و مستقیما ConcreteFunction را فراخوانی کنید.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

استفاده با چندین مدل Keras

همچنین ممکن است با ValueError: tf.function only supports singleton tf.Variables created on the first call. هنگام ارسال نمونه های مختلف مدل به یک Function .

این خطا به این دلیل رخ می‌دهد که مدل‌های Keras (که شکل ورودی‌شان تعریف نشده است ) و لایه‌های Keras در اولین فراخوانی، tf.Variables ایجاد می‌کنند. ممکن است سعی کنید آن متغیرها را در داخل یک Function ، که قبلا فراخوانی شده است، مقداردهی اولیه کنید. برای جلوگیری از این خطا، سعی کنید مدل. model.build(input_shape) را فراخوانی کنید تا قبل از آموزش مدل، تمام وزن‌ها را مقداردهی اولیه کنید.

خواندن بیشتر

برای آشنایی با نحوه صادرات و بارگذاری یک Function ، به راهنمای SavedModel مراجعه کنید. برای کسب اطلاعات بیشتر در مورد بهینه سازی نمودار که پس از ردیابی انجام می شود، به راهنمای Grappler مراجعه کنید. برای یادگیری نحوه بهینه سازی خط لوله داده و نمایه مدل خود، به راهنمای Profiler مراجعه کنید.