עזרה להגן על שונית המחסום הגדולה עם TensorFlow על Kaggle הצטרפו אתגר

ביצועים טובים יותר עם פונקצית tf

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHubהורד מחברת

בשנת TensorFlow 2, ביצוע להוט מופעל כברירת מחדל. ממשק המשתמש אינטואיטיבי וגמיש (הפעלת פעולות חד פעמיות היא הרבה יותר קלה ומהירה), אבל זה יכול לבוא על חשבון הביצועים והפריסה.

אתה יכול להשתמש tf.function לעשות גרפים מתוך התוכניות שלך. זהו כלי טרנספורמציה שיוצר גרפי זרימת נתונים עצמאיים לפייתון מתוך קוד הפייתון שלך. זה יעזור לך ליצור מודלים performant ונייד, והוא נדרש להשתמש SavedModel .

מדריך זה יעזור לך להמשיג איך tf.function פועל מאחורי הקלעים, כך שאתה יכול להשתמש בו ביעילות.

הטייק אווי וההמלצות העיקריות הן:

  • Debug במצב להוט, ואז לקשט עם @tf.function .
  • אל תסתמך על תופעות לוואי של Python כמו מוטציה של אובייקט או הוספת רשימה.
  • tf.function עובד הכי טוב עם ops TensorFlow; קריאות NumPy ו- ​​Python מומרות לקבועים.

להכין

import tensorflow as tf

הגדר פונקציית עוזר כדי להדגים את סוגי השגיאות שאתה עלול להיתקל בהן:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

יסודות

נוֹהָג

Function שאתה מגדיר (למשל על ידי יישום @tf.function המעצב) הוא בדיוק כמו מבצע TensorFlow ליבה: אתה יכול לבצע אותו בשקיקה; אתה יכול לחשב מעברי צבע; וכן הלאה.

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

אתה יכול להשתמש Function בתוך הים אחר Function של.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function של יכול להיות מהיר יותר מאשר קוד להוט, במיוחד עבור גרפים עם ops קטן רב. אבל עבור גרפים עם כמה פעולות יקרות (כמו פיתולים), ייתכן שלא תראה מהירות רבה.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735
Function conv: 0.005791576000774512
Note how there's not much difference in performance for convolutions

מַעֲקָב

סעיף זה חושף כיצד Function פועל מאחורי הקלעים, כולל פרטים יישום אשר עשוי להשתנות בעתיד. עם זאת, ברגע שאתה מבין למה ומתי התחקות קורה, זה הרבה יותר קל להשתמש tf.function ביעילות!

מה זה "מעקב"?

Function מפעילה תוכנית שלך גרף TensorFlow . עם זאת, tf.Graph לא יכול לייצג את כל הדברים שאתה רוצה לכתוב בתוכנית TensorFlow להוט. למשל, Python תומך פולימורפיזם, אבל tf.Graph דורש תשומות שלה כדי להיות בעלת סוג נתונים שצוינו מימד. או שאתה יכול לבצע משימות צד כמו קריאת ארגומנטים של שורת הפקודה, העלאת שגיאה או עבודה עם אובייקט Python מורכב יותר; אף אחד מהדברים האלה יכולים לרוץ בתוך tf.Graph .

Function מגשר על הפער הזה על ידי הפרדת הקוד שלך בשני שלבים:

1) בשלב הראשון, המכונה "התחקות", Function יוצרת חדש tf.Graph . קוד Python פועל בדרך כלל, אבל כל הפעולות TensorFlow (כמו הוספת שתי tensors) נדחות: הם נלכדים על ידי tf.Graph ולא לרוץ.

2) בשלב השני, tf.Graph אשר מכיל את כול מה נדחה בשלב הראשון מנוהל. שלב זה מהיר בהרבה משלב האיתור.

בהתאם תשומותיו, Function תמיד לא תפעל בשלב הראשון כאשר הוא נקרא. ראה "כללים של התחקות" למטה כדי לקבל תחושה טובה יותר של איך זה עושה את זה נחישות. דילוג על השלב הראשון וביצוע השלב השני בלבד הוא מה שנותן לך את הביצועים הגבוהים של TensorFlow.

כאשר Function מחליטה להתחקות, בשלב התחקות ומיד אחריו את השלב השני, כך לקרוא Function הן יוצרת ומפעילה את tf.Graph . בהמשך תוכלו לראות איך אתה יכול להריץ רק בשלב התחקות עם get_concrete_function .

כאשר אתה עובר טיעונים מסוגים שונים לתוך Function , הוא בשלבים מנוהלים:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

שים לב שאם אתה שוב ושוב להתקשר Function עם סוג הטיעון זהה, TensorFlow ידלג על השלב התחקות ושימוש חוזר גרף לייחס בעבר, כמו בגרף שנוצר יהיה זהה.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

אתה יכול להשתמש pretty_printed_concrete_signatures() כדי לראות את כול עקבות הזמינות:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

עד כה, ראיתם כי tf.function יוצר במטמון, שכבת דינאמית שלוח מעל הגרף של TensorFlow התחקות ההיגיון. כדי להיות יותר ספציפי לגבי הטרמינולוגיה:

  • tf.Graph הוא הגלם, אגנוסטי בשפה, הייצוג הנייד של חישוב TensorFlow.
  • ConcreteFunction עוטפת tf.Graph .
  • Function מנהלת מטמון של ConcreteFunction ים ומרימה את האדם הנכון עבור התשומות שלך.
  • tf.function עוטפת פונקצית Python, חזרת Function אובייקט.
  • איתור יוצר tf.Graph ועוטף אותו ConcreteFunction , הידוע גם בשם זכר.

כללי איתור

Function קובעת אם לעשות שימוש חוזר לייחס ConcreteFunction ידי מחשוב מפתח מטמון מתוך ארגומנטים ו kwargs של קלט. מפתח מטמון הוא מפתח מזהה ConcreteFunction מבוסס על קלט ארגומנטים ו kwargs של Function השיחה, על פי הכללים הבאים (היכולה להשתנות):

  • המפתח שנוצר עבור tf.Tensor הוא הצורה ואת dtype.
  • המפתח שנוצר עבור tf.Variable הוא מזהה משתנה ייחודי.
  • המפתח שנוצר עבור פרימיטיבי Python (כמו int , float , str ) הערך שלה.
  • המפתח שנוצר עבור מקוננות dict ים, list של, tuple ים, namedtuple ים, ו attr s הוא tuple משוטח עלה-מפתחות (ראה nest.flatten ). (כתוצאה מהשטחה זו, קריאה לפונקציית בטון בעלת מבנה קינון שונה מזה המשמש במהלך המעקב תגרום ל-TypeError).
  • עבור כל סוגי הפייתון האחרים המפתח הוא ייחודי לאובייקט. בדרך זו עוקבים אחר פונקציה או שיטה באופן עצמאי עבור כל מופע שאיתו הוא נקרא.

שליטה בחזרה

Retracing, וזה כאשר שלך Function יוצרת יותר עקבות אחד, מסייעת מבטיח כי TensorFlow יוצרת גרפים נכונים עבור כול קבוצה של תשומות. עם זאת, איתור הוא פעולה יקרה! אם שלך Function משחזרת גרף חדש עבור כול שיחה, אתה תמצא כי מבצע את הקוד שלך יותר לאט מאשר אם לא השתמש tf.function .

כדי לשלוט בהתנהגות המעקב, תוכל להשתמש בטכניקות הבאות:

  • ציין input_signature ב tf.function כדי מעקב גבול.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
  • ציין מימד [אף] ב tf.TensorSpec כדי לאפשר גמישות חוזר עקבות.

    מאז TensorFlow תואם tensors מבוסס על הצורה שלהם, באמצעות None ממד בתור תו יאפשר Function של עד עקבות חוזרות עבור קלט variably בגודל. קלט variably בגודל יכול להתרחש אם יש לך רצפים באורך שונה, או תמונות בגדלים שונים עבור כל אצווה (עיין Transformer ו- Deep חלום הדרכות למשל).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • העבר ארגומנטים של Python לטנסורים כדי להפחית את החזרה.

    לעתים קרובות, ויכוחי Python משמשים hyperparameters המלא ומבנים גרף - למשל, num_layers=10 או training=True או nonlinearity='relu' . לכן, אם הארגומנט של Python משתנה, הגיוני שתצטרך לחזור על הגרף.

    עם זאת, ייתכן שלא נעשה שימוש בארגומנט Python כדי לשלוט בבניית הגרפים. במקרים אלה, שינוי בערך Python יכול להפעיל מעקב מיותר. קחו, למשל, את לולאת האימון הזו, ש-AutoGraph יפרש באופן דינמי. למרות העקבות המרובות, הגרף שנוצר למעשה זהה, כך שהחזרה לאחור היא מיותרת.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

אם אתה צריך לשחזר את הכח, ליצור חדשה Function . פרד Function החפצה מובטחים שלא עקבות נתח.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

השגת פונקציות קונקרטיות

בכל פעם שמתחקים אחר פונקציה, נוצרת פונקציה קונקרטית חדשה. אתה יכול להשיג פונקצית בטון ישירות, באמצעות get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

הדפסת ConcreteFunction מציג סיכום של טענות הקלט שלה (עם סוגים) ואת סוג הפלט שלה.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

אתה יכול גם לאחזר ישירות את החתימה של פונקציה קונקרטית.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

שימוש בעקבות בטון עם סוגים לא תואמים יגרום לשגיאה

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

ייתכן שתבחין כי ארגומנטים של Python מקבלים טיפול מיוחד בחתימת הקלט של פונקציה קונקרטית. לפני TensorFlow 2.3, ארגומנטים של Python פשוט הוסרו מהחתימה של הפונקציה הקונקרטית. החל מ- TensorFlow 2.3, ארגומנטים של Python נשארים בחתימה, אך הם מוגבלים לקחת את הערך שנקבע במהלך המעקב.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.

השגת גרפים

כל פונקציה בטון הוא מעטפת callable סביב tf.Graph . למרות אחזור בפועל tf.Graph האובייקט הוא לא משהו שאתה תצטרך בדרך כלל לעשות, אתה יכול להשיג אותו בקלות מכל פונקציה בטון.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

איתור באגים

באופן כללי, באגים בקוד קל במצב להוט מאשר בתוך tf.function . אתה צריך לוודא כי הקוד שלך מבצע משגיאות במצב להוט לפני לקשט עם tf.function . כדי לסייע בתהליך הניפוי, אתה יכול להתקשר tf.config.run_functions_eagerly(True) כדי גלובלית להשבית בהפעלה מחדש tf.function .

כאשר להתחקות אחר סוגיות המופיעות רק בתוך tf.function , הנה כמה טיפים:

  • זקן Python מישור print שיחות רק לבצע במהלך מעקב, עוזרות לך לעקוב אחר מטה כאשר הפונקציה שלך מקבלת (מחדש) לייחס.
  • tf.print שיחות תבצענה בכול פעם, והוא יכול לעזור לך לאתר את ערכי ביניים במהלך ביצוע.
  • tf.debugging.enable_check_numerics היא דרך קלה לעקוב אחר מטה איפה NaNs ו Inf נוצרים.
  • pdb (את הבאגים Python ) יכול לעזור לך להבין מה קורה במהלך מעקב. (אזהרה: pdb ירד לך לתוך קוד מקור-טרנספורמציה חתימה.)

טרנספורמציות של גרף אוטומטי

חתימה היא ספרייה אשר מופעלת כברירת מחדל ב tf.function , והופכת משנה של קוד להוט Python לתוך ops TensorFlow תואם גרף. זה כולל בקרת זרימה כמו if , for , while .

Ops TensorFlow כמו tf.cond ו tf.while_loop להמשיך לעבוד, אך בקרת זרימה הוא בדרך כלל קל יותר לכתוב ולהבין כאשר כתוב Python.

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753]
[0.582645297 0.613145649 0.619306684 0.319202513 0.182036072]
[0.524585426 0.546337605 0.550645113 0.308785647 0.18005164]
[0.481231302 0.497770309 0.501003504 0.299331933 0.178130865]
[0.447229207 0.460361809 0.462906033 0.290701121 0.176270396]
[0.419618756 0.430379033 0.432449728 0.282779962 0.174467146]
[0.396609187 0.405638 0.407366514 0.275476 0.172718227]
[0.377043903 0.384762734 0.386234313 0.268712848 0.17102097]
[0.360137492 0.366836458 0.368109286 0.262426734 0.169372901]
[0.345335096 0.351221472 0.352336824 0.256563932 0.167771652]
[0.332231969 0.337458342 0.338446289 0.251078814 0.166215062]
[0.320524871 0.325206399 0.326089561 0.24593246 0.164701089]
[0.309981436 0.314206958 0.31500268 0.241091311 0.163227797]
[0.300420195 0.304259449 0.304981351 0.236526251 0.161793426]
[0.291697085 0.295205742 0.295864582 0.232211992 0.160396278]
[0.283696055 0.286919087 0.287523568 0.228126258 0.159034774]
[0.276322395 0.279296666 0.27985391 0.224249557 0.157707423]
[0.269497961 0.272254 0.272769839 0.220564634 0.15641281]
[0.263157606 0.265720904 0.266200244 0.21705614 0.155149609]
[0.257246554 0.259638608 0.260085613 0.213710397 0.153916568]
[0.251718313 0.25395745 0.254375577 0.210515186 0.152712509]
[0.246533215 0.248635098 0.249027327 0.207459539 0.151536316]
[0.241657034 0.243635193 0.244004101 0.204533577 0.15038693]
[0.237060249 0.238926381 0.239274174 0.201728329 0.149263337]
[0.232717097 0.234481394 0.234810054 0.199035719 0.148164615]
[0.228605017 0.230276451 0.230587661 0.196448416 0.147089839]
[0.224704206 0.226290658 0.22658591 0.193959698 0.14603813]
[0.220997125 0.222505584 0.222786173 0.191563457 0.145008713]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077],
      dtype=float32)>

אם אתה סקרן אתה יכול לבדוק את חתימת הקוד שנוצרת.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

תנאים

החתימה תמיר חלק if <condition> דוחות למקבילת tf.cond השיחות. החלפה זו מתבצעת אם <condition> הוא מותח. אחרת, if בהצהרה מבוצעת בתור מותנית Python.

תנאי Python מבוצע במהלך המעקב, כך שדווקא ענף אחד של התנאי יתווסף לגרף. ללא AutoGraph, הגרף המתחקה הזה לא יוכל לקחת את הענף החלופי אם יש זרימת בקרה תלוית נתונים.

tf.cond עקבות ומוסיפות שני הסניפים של מותנות הגרף, בחירת סניף דינמי בשלב ביצוע. למעקב יכולות להיות תופעות לוואי לא מכוונות; לבדוק תופעות התחקות חתימה לקבלת מידע נוסף.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

עיין בתיעוד העזר עבור מגבלות נוספות על אוטוגרף המרה אם הצהרות.

לולאות

חתימה תמיר חלק for ו while הצהרות לתוך TensorFlow המקבילה looping ops, כמו tf.while_loop . אם לא המיר את for או while הלולאה מבוצעת כמו לולאת Python.

החלפה זו מתבצעת במצבים הבאים:

  • for x in y : אם y הוא מותח, להמיר tf.while_loop . במקרה המיוחד שבו y הוא tf.data.Dataset , שילוב של tf.data.Dataset ops נוצר.
  • while <condition> : אם <condition> הוא מותח, להמיר tf.while_loop .

פייתון עוסקת בביצוע לולאה במהלך עקיבה, הוספת ops נוסף tf.Graph לכול איטרציה של הלולאה.

לולאת TensorFlow עוקבת אחר גוף הלולאה, ובוחרת באופן דינמי כמה איטרציות לרוץ בזמן הביצוע. גוף הלולאה מופיע רק פעם אחת שנוצר tf.Graph .

עיין בתיעוד העזר עבור מגבלות נוספות על אוטוגרף מרה for ו while הצהרות.

לולאה מעל נתוני Python

נופל בפח הוא לולאה מעל נתון numpy / Python בתוך tf.function . לולאה זו תבצע במהלך תהליך התחקות, הוספת עותק של המודל שלך אל tf.Graph עבור כל איטרציה של הלולאה.

אם אתה רוצה לעטוף את לולאת אימונים השלמה tf.function , הדרך הבטוחה ביותר לעשות זאת היא לעטוף את נתון שלך בתור tf.data.Dataset כך החתימה תתעדכן באופן דינמי לְהִתְגוֹלֵל לולאת אימונים.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph

כאשר גלישת נתונים Python / numpy של מערך נתונים, להיות מודע tf.data.Dataset.from_generator לעומת tf.data.Dataset.from_tensors . הראשון ישמור על נתון ב Python ותחזיר אותו באמצעות tf.py_function אשר יכולה להיות שלכות ביצועים, בעוד שהאחרון יהיה צרור עותק של הנתונים כמו גדול אחד tf.constant() צומת בגרף, אשר יכולה להיות שלכות זיכרון.

קריאת נתונים מקבצים באמצעות TFRecordDataset , CsvDataset , וכו 'היא הדרך האפקטיבית ביותר לצרוך נתונים, כמו אז TensorFlow עצמו יכול לנהל את הטעינה אסינכרוני הכנה מקדימה של נתונים, מבלי לערב Python. כדי ללמוד עוד, ראה tf.data : צינורות הזנה Build TensorFlow להנחות.

צבירת ערכים בלולאה

דפוס נפוץ הוא צבירת ערכי ביניים מלולאה. בדרך כלל, זה מושג על ידי הוספה לרשימת Python או הוספת ערכים למילון Python. עם זאת, מכיוון שמדובר בתופעות לוואי של Python, הן לא יפעלו כצפוי בלולאה מגולגלת דינמית. השתמש tf.TensorArray לצבור תוצאות לולאה גלול דינמי.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216],
        [0.44997275, 1.9107027 , 1.0716251 , 0.717237  ],
        [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]],

       [[0.04946005, 0.69127274, 0.56848884, 0.22406638],
        [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ],
        [0.9178308 , 1.320889  , 0.989761  , 2.0120025 ]]], dtype=float32)>

מגבלות

TensorFlow Function יש מספר מגבלות על ידי עיצוב שאתה צריך להיות מודע בעת המרת פונקצית פיתון על Function .

ביצוע תופעות לוואי של Python

תופעות לוואי, כמו הדפסה, צירוף לרשימות, ו מוטציה globals, יכול להתנהג באופן בלתי צפוי בתוך Function , לפעמים ביצוע פעמיים או לא בכלל. הם רק לקרות בפעם הראשונה שאתם קוראים Function עם סט של כניסות. לאחר מכן, לייחס tf.Graph הוא reexecuted, ללא צורך בהפעלת קוד פיתון.

כלל האצבע הכללי הוא להימנע מהסתמכות על תופעות לוואי של Python בלוגיקה שלך ולהשתמש בהן רק כדי לנפות באגים שלך. אחרת, APIs TensorFlow כמו tf.data , tf.print , tf.summary , tf.Variable.assign , ו tf.TensorArray הם הדרך הטובה ביותר להבטיח את הקוד שלך יבוצע על ידי ריצה TensorFlow עם כל שיחה.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

אם אתה רוצה לבצע קוד פיתון במהלך כל שבע של Function , tf.py_function מהווה צוהר יציאה. החיסרון של tf.py_function היא שזה לא נייד או במיוחד performant, לא ניתן לשמור עם SavedModel, ואינו פועל היטב מופץ הגדרות (-GPU רב, TPU). כמו כן, מאז tf.py_function יש חוטית לתוך הגרף, זה מטיל את כל כניסות / יציאות כדי tensors.

שינוי משתנים גלובליים וחופשיים של Python

שינוי הגלובלי Python משתנה חינם ספירה כתופעת לוואי Python, אז זה קורה רק במהלך מעקב.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

לפעמים קשה מאוד לשים לב להתנהגויות בלתי צפויות. בדוגמא להלן, counter נועד להגן על התוספת של משתנה. אולם מכיוון שהוא מספר שלם של פיתון ולא אובייקט TensorFlow, הערך שלו נקלט במהלך המעקב הראשון. כאשר tf.function משמש, assign_add יירשם ללא תנאי בגרף הבסיסי. לכן v יגדל ב 1, בכל פעם tf.function נקרא. בעיה זו נפוצה בקרב משתמשים שמנסים להעביר קוד Grpah-mode Tensorflow שלהם Tensorflow 2 באמצעות tf.function מעצבי פנים, כאשר python-תופעות לוואי (את counter בדוגמה) משמשים כדי לקבוע מה ops כדי ריצה ( assign_add בדוגמה ). בדרך כלל, משתמשים מבינים זאת רק לאחר שראו תוצאות מספריות חשודות, או ביצועים נמוכים משמעותית מהצפוי (למשל אם הפעולה המוגנת היא יקרה מאוד).

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

לעקיפת הבעיה כדי להשיג את ההתנהגות הצפויה משתמש tf.init_scope להרים את הפעילות מחוץ גרף הפונקציה. זה מבטיח שהתוספת המשתנה תתבצע פעם אחת בלבד במהלך זמן המעקב. יצוין init_scope יש תופעות לוואי אחרות, כולל בקרת זרימה מנוקה קלטת שיפוע. לפעמים שימוש init_scope יכול להיות מורכב מדי כדי לנהל באופן ריאליסטי.

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

לסיכום, ככלל אצבע, עליך להימנע מוטצית פיתון אובייקטים כגון מספרים שלמים או מיכל כמו רשימות כי מחוץ לחיות את Function . במקום זאת, השתמש בארגומנטים ובאובייקטי TF. לדוגמה, בפרק "צבירת ערכי בתוך לולאה" יש דוגמה אחת לאופן רשימה דמוי פעולות יכול להיות מיושם.

אתה יכול, במקרים מסוימים, ללכוד ולטפל מדינה אם מדובר tf.Variable . כך משקלות דגמי Keras מתעדכנים עם קריאות חוזרות ונשנות באותו ConcreteFunction .

שימוש באיטרטורים ומחוללים של Python

תכונות רבות של Python, כגון גנרטורים ואיטרטורים, מסתמכים על זמן הריצה של Python כדי לעקוב אחר המצב. באופן כללי, בעוד שהמבנים הללו פועלים כצפוי במצב להוט, הם מהווים דוגמאות לתופעות לוואי של Python ולכן קורות רק במהלך המעקב.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

בדיוק כמו איך TensorFlow יש לו התמחה tf.TensorArray עבור מבני רשימה, יש לו התמחה tf.data.Iterator עבור איטרציה מבנים. עיין בסעיף על טרנספורמציות חתימה עבור סקירה. כמו כן, tf.data API יכול לעזור ליישם דפוסי גנרטור:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

כל הפלטים של פונקציית tf. חייבים להיות ערכי החזרה

למעט tf.Variable s, tf.function חייב להחזיר את כל תפוקותיו. ניסיון לגשת ישירות לכל טנסור מפונקציה מבלי לעבור על ערכי החזרה גורם ל"דליפות".

לדוגמא, הפונקציה מתחת "הדלפות" מותח הדרך העולמית Python a x :

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'Tensor' object has no attribute 'numpy'

זה נכון גם אם הערך הדלף מוחזר גם:

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'Tensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: Originated from a graph execution error.

The graph execution error is detected at a node built at (most recent call last):
>>>  File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main
>>>  File /usr/lib/python3.7/runpy.py, line 85, in _run_code
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once
>>>  File /usr/lib/python3.7/asyncio/events.py, line 88, in _run
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code
>>>  File /tmp/ipykernel_26244/566849597.py, line 7, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler
>>>  File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__

Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

בדרך כלל, דליפות כאלה מתרחשות כאשר אתה משתמש בהצהרות Python או במבני נתונים. בנוסף להדלפת טנסורים בלתי נגישים, סביר להניח שהצהרות כאלה שגויות מכיוון שהן נחשבות כתופעות לוואי של Python, ואינן מובטחות שיבוצעו בכל קריאת פונקציה.

דרכים נפוצות להדלפת טנסורים מקומיים כוללות גם מוטציה של אוסף Python חיצוני, או אובייקט:

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

פונקציות tf. רקורסיביות אינן נתמכות

רקורסיבית Function של אינם נתמכים ועלולה לגרום לולאות אינסופיות. לדוגמה,

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__
        return _abc_instancecheck(cls, instance)

    RecursionError: maximum recursion depth exceeded while calling a Python object

גם אם רקורסיבית Function נראה עבודה, הפונקציה Python יהיה לייחס מספר פעמים עשויה להיות השלכה על הביצועים. לדוגמה,

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

בעיות ידועות

אם שלך Function לא מעריכים כראוי, את השגיאה יכולה להיות מוסברת על ידי בעיות ידועות אלה אשר מתוכננים להיות קבוע בעתיד.

תלוי במשתנים גלובליים וחופשיים של Python

Function יוצרת חדשה ConcreteFunction כאשר נקרא עם ערך חדש של טיעון Python. עם זאת, זה לא עושה את זה בשביל סגירת Python, globals, או nonlocals של אותה Function . אם ערכם משתנה בין שיחות Function , את Function עדיין תשתמש הערכים הם היו כשזה אותר. זה שונה מהאופן שבו פועלות פונקציות Python רגילות.

מסיבה זו, עליך לעקוב אחר סגנון תכנות פונקציונלי המשתמש בארגומנטים במקום לסגור מעל שמות חיצוניים.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

דרך נוספת לעדכן ערך גלובלי, היא להפוך אותו tf.Variable ולהשתמש Variable.assign שיטת במקום.

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

תלוי באובייקטים של Python

ההמלצה להעביר חפצים Python כטיעונים לתוך tf.function יש מספר בעיות ידועות, אשר צפויים להיות קבוע בעתיד. באופן כללי, אתה יכול לסמוך על מעקב עקבי אם אתה משתמש Python פרימיטיבי או tf.nest מבנה -compatible כטיעון או לעבור במופע אחר של אובייקט לתוך Function . עם זאת, Function לא תיצור עקבות חדשות כאשר אתה עובר את אותו אובייקט ורק לשנות התכונות שלה.

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

שימוש באותה Function להעריך את המופע המעודכן של המודל יהיה מרכבה מאז המודל המעודכן יש את אותו מפתח המטמון כמו הדגם המקורי.

מסיבה זו, אתה מומלץ לכתוב את Function כדי למנוע תלוי תכונות האובייקט משתנה או ליצור אובייקטים חדשים.

אם זה לא אפשרי, מעקף אחד היא להפוך חדשה Function s בכול פעם שתשנה האובייקט שלך כדי לשחזר את הכח:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

כפי retracing יכול להיות יקר , אתה יכול להשתמש tf.Variable ים כמו תכונות אובייקט, אשר יכול להיות מוטציה (אך לא השתנו,! זהיר) עבור אפקט דומה ללא צורך retrace.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

יצירת tf.Variables

Function תומכת רק סינגלטון tf.Variable s נוצר פעם על השיחה הראשונה, ולעשות בהם שימוש חוזר ברחבי שיחות פונקציה עוקבות. קטע הקוד מתחת תיצור חדש tf.Variable בכל קריאה לפונקציה, שתוצאתה ValueError החריג.

דוגמא:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

דפוס נפוץ המשמש לעקוף מגבלה זו היא להתחיל עם ערך אף Python, אז באופן מותנה ליצור את tf.Variable אם הערך הוא None:

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

שימוש עם מספר אופטימיזציית Keras

אתה עלול להיתקל ValueError: tf.function only supports singleton tf.Variables created on the first call. כשמשתמשים יותר האופטימיזציה אחד Keras עם tf.function . שגיאה זו מתרחשת משום אופטימיזציה פנימית ליצור tf.Variables כשהם חלים הדרגתיים בפעם הראשונה.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients  **
        self._create_all_weights(var_list)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights
        _ = self.iterations
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight
        aggregation=aggregation)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable
        shape=variable_shape if variable_shape else None)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

אם אתה צריך לשנות את האופטימיזציה במהלך האימון, ניתן לעקוף את הבעיה על מנת ליצור חדש Function עבור כל האופטימיזציה, לקרוא ConcreteFunction ישירות.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

שימוש עם מספר דגמי Keras

ייתכן גם מפגש ValueError: tf.function only supports singleton tf.Variables created on the first call. כאשר עוברים מקרי מודל שונים לאותה Function .

שגיאה זו מתרחשת משום מודלי Keras (אשר אין צורת הקלט שלהם מוגדרת ) ושכבות Keras ליצור tf.Variables ים כשהם נקראים ראשונים. ייתכן שאתה מנסה לאתחל אלה משתנים בתוך Function , אשר כבר נקרא. כדי למנוע שגיאה זו, נסה להתקשר model.build(input_shape) כדי לאתחל את כל המשקולות לפני אימון המודל.

לקריאה נוספת

כדי ללמוד על איך לייצא לטעון Function , לראות את מדריך SavedModel . כדי ללמוד עוד על אופטימיזציות גרף המבוצעות לאחר התחקות, לראות את מדריך grappler . כדי ללמוד כיצד לייעל צינור הנתונים שלך ופרופיל המודל שלך, לראות את מדריך Profiler .