לשמור את התאריך! קלט / פלט של Google חוזר 18-20 במאי הירשם עכשיו
דף זה תורגם על ידי Cloud Translation API.
Switch to English

ביצועים טובים יותר עם פונקצית tf

צפה ב- TensorFlow.org הפעל בגוגל קולאב צפה במקור ב- GitHub הורד מחברת

ב- TensorFlow 2, ביצוע להוט מופעל כברירת מחדל. ממשק המשתמש הוא אינטואיטיבי וגמיש (הפעלת פעולות חד פעמיות קלה ומהירה בהרבה), אך הדבר יכול לבוא על חשבון ביצועים ופריסה.

אתה יכול להשתמש ב- tf.function כדי ליצור גרפים מהתוכניות שלך. זהו כלי טרנספורמציה שיוצר גרפי זרימת נתונים שאינם תלויים בפייתון מתוך קוד הפייתון שלך. זה יעזור לך ליצור מודלים ביצועים וניידים, וזה נדרש להשתמש ב- SavedModel .

מדריך זה יעזור לך tf.function כיצד tf.function עובדת מתחת למכסה המנוע, כך שתוכל להשתמש בה בצורה יעילה.

ההמלצות העיקריות וההמלצות הן:

  • ניפוי שגיאות במצב להוט, ואז קישוט עם @tf.function .
  • אל תסמוך על תופעות לוואי של פייתון כמו מוטציה של אובייקטים או רשימות.
  • tf.function עובד הכי טוב עם TensorFlow ops; שיחות NumPy ו- ​​Python מומרות לקבועים.

להכין

import tensorflow as tf

הגדר פונקציית עוזר כדי להדגים את סוגי השגיאות שאתה עלול להיתקל בהן:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

יסודות

נוֹהָג

Function שאתה מגדיר (למשל על ידי יישום @tf.function ) היא בדיוק כמו פעולת ליבה של TensorFlow: אתה יכול לבצע אותה בשקיקה; אתה יכול לחשב שיפועים; וכולי.

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

אתה יכול להשתמש Function בתוך Function אחרות.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function s יכולות להיות מהירות יותר מקוד להוט, במיוחד עבור גרפים עם אופים קטנים רבים. אבל עבור גרפים עם כמה אופציות יקרות (כמו התפתלות), ייתכן שלא תראה מהירות רבה.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.0035502629999655255
Function conv: 0.004116348000025027
Note how there's not much difference in performance for convolutions

מַעֲקָב

סעיף זה חושף כיצד פועלת Function מתחת למכסה המנוע, כולל פרטי יישום שעשויים להשתנות בעתיד . עם זאת, ברגע שאתה מבין מדוע ומתי מתרחש מעקב, הרבה יותר קל להשתמש ב- tf.function בצורה יעילה!

מה זה "מעקב"?

Function מריצה את התוכנית שלך בגרף TensorFlow . עם זאת, tf.Graph אינו יכול לייצג את כל הדברים שאתה כותב בתוכנית TensorFlow להוטה. לדוגמה, Python תומך בפולימורפיזם, אך tf.Graph דורש שהתשומות שלו יהיו בעלות סוג נתונים ומימד מוגדרים. לחלופין, תוכל לבצע משימות לוואי כמו קריאת טיעוני שורת פקודה, העלאת שגיאה או עבודה עם אובייקט פייתון מורכב יותר; אף אחד מהדברים האלה לא יכול לרוץ tf.Graph .

Function מגשרת על פער זה על ידי הפרדת הקוד שלך בשני שלבים:

1) בשלב הראשון, המכונה " מעקב ", Function יוצרת tf.Graph חדש. קוד פייתון פועל כרגיל, אך כל פעולות TensorFlow (כמו הוספת שני טנזורים) נדחות : הן tf.Graph על ידי tf.Graph ולא פועלות.

2) בשלב השני tf.Graph המכיל את כל מה tf.Graph בשלב הראשון. שלב זה מהיר בהרבה משלב האיתור.

בהתאם לקלטים שלה, Function לא תמיד תריץ את השלב הראשון כאשר היא נקראת. ראה "כללי מעקב" להלן כדי להבין טוב יותר כיצד היא קובעת את אותה קביעה. דילוג על השלב הראשון ורק ביצוע השלב השני הוא זה שנותן לך את הביצועים הגבוהים של TensorFlow.

כאשר Function אכן מחליטה להתחקות אחר שלב המעקב מיד אחריו השלב השני, לכן קריאה Function יוצרת ומפעילה את ה- tf.Graph . בהמשך תוכלו לראות כיצד תוכלו להריץ רק את שלב העקיבה באמצעות get_concrete_function .

כאשר אנו מעבירים טיעונים מסוגים שונים Function , שני השלבים מנוהלים:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

שים לב שאם אתה מתקשר שוב ושוב Function עם אותו סוג ארגומנט, TensorFlow ידלג על שלב העקיבה וישתמש מחדש בגרף שנחקק בעבר, מכיוון שהגרף שנוצר יהיה זהה.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

אתה יכול להשתמש pretty_printed_concrete_signatures() כדי לראות את כל העקבות הזמינות:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

עד כה ראית ש- tf.function יוצר שכבת משלוח דינמית במטמון מעל לוגיקת האיתור הגרפי של TensorFlow. ליתר דיוק לגבי המינוח:

  • tf.Graph הוא tf.Graph גולמי, אגנוסטי לשפה, ונייד של חישוב TensorFlow.
  • tf.Graph ConcreteFunction עוטפת תמונה tf.Graph .
  • Function מנהלת מטמון של ConcreteFunction s ובוחרת את המתאים עבור הקלטים שלך.
  • tf.function עוטף פונקציית פייתון, ומחזיר אובייקט Function .
  • מעקב יוצר tf.Graph ועוטף אותו ב- ConcreteFunction , הידוע גם בשם עקבות.

כללי התחקות

Function קובעת אם לעשות שימוש חוזר ב- ConcreteFunction שעוקב אחריה על ידי חישוב מפתח מטמון מתוך הארגונים והקווארגים של קלט. מפתח מטמון הוא מפתח המזהה Function ConcreteFunction בהתבסס על טיעוני הקלט והקווארוגים של קריאת ה- Function , על פי הכללים הבאים (אשר עשויים להשתנות):

  • המפתח שנוצר עבור tf.Tensor הוא צורתו וסוגו.
  • המפתח שנוצר עבור tf.Variable הוא מזהה משתנה ייחודי.
  • המפתח שנוצר לפרימיטיבי של פייתון (כמו int , float , str ) הוא הערך שלו.
  • המפתח שנוצר עבור dict מקוננים, list s, tuple s, namedtuple tuple s ו- attr s הוא namedtuple השטוח של מקשי העלה (ראהnest.flatten ). (כתוצאה מהשטחה זו, קריאה לפונקציית בטון עם מבנה קינון שונה מזה המשמש במהלך העקיבה תביא ל- TypeError).
  • עבור כל שאר סוגי הפיתון המפתח הוא ייחודי לאובייקט. בדרך זו ניתן לעקוב אחר פונקציה או שיטה באופן עצמאי עבור כל מקרה שהוא נקרא באמצעותו.

שליטה חוזרת

מעקב אחר המעקב, כאשר Function שלך יוצרת יותר ממעקב אחד, מסייע להבטיח ש- TensorFlow מייצר גרפים נכונים עבור כל קבוצת תשומות. עם זאת, מעקב הוא פעולה יקרה! אם Function שלך משחזרת גרף חדש לכל שיחה, תגלה שהקוד שלך פועל לאט יותר מאשר אם לא tf.function ב- tf.function .

כדי לשלוט בהתנהגות האיתור, ניתן להשתמש בטכניקות הבאות:

  • ציין את input_signature ב- tf.function כדי להגביל את העקיבה.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# We specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-20f544b8adbf>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-20f544b8adbf>", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
  • ציין ממד [ללא] ב- tf.TensorSpec כדי לאפשר גמישות בשימוש חוזר אחר.

    מכיוון ש- TensorFlow תואם טנזורים על סמך צורתם, שימוש בממד None בתור תו כללי יאפשר Function s לעשות שימוש חוזר במעקב לצורך קלט בגודל משתנה. קלט בגודל משתנה יכול להתרחש אם יש לך רצפים באורך שונה, או תמונות בגדלים שונים עבור כל אצווה (ראה למשל מדריכי רובוטריקים ו Deep Dream ).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • העבר טיעונים של פייתון ל Tensors כדי להפחית את ההחזרה.

    לעיתים קרובות, ארגומנטים של פייתון משמשים לשליטה num_layers=10 ובבניית גרפים - למשל, num_layers=10 או training=True או nonlinearity='relu' . אז אם הטיעון של פייתון ישתנה, הגיוני שתצטרך לעקוב אחר הגרף מחדש.

    עם זאת, יתכן שלא משתמשים בטיעון של פייתון לבקרת בניית גרפים. במקרים אלה, שינוי בערך הפיתון יכול לעורר חזרה מיותרת. קחו, למשל, את לולאת האימון הזו, ש- AutoGraph תפרוש באופן דינמי. למרות העקבות המרובים, הגרף שנוצר הוא זהה למעשה, ולכן חזרה מיותרת.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

אם אתה זקוק לאילוח חוזר בכוח, צור Function חדשה. מובטח כי אובייקטי Function נפרדים אינם חולקים עקבות.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

השגת פונקציות קונקרטיות

בכל פעם שמעקב אחר פונקציה נוצר פונקציה קונקרטית חדשה. באפשרותך להשיג ישירות פונקציית בטון באמצעות get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'cc', shape=(), dtype=string)

הדפסת ConcreteFunction מציגה סיכום של ארגומנטים הקלט שלו (עם סוגים) וסוג הפלט שלו.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

ניתן גם לאחזר ישירות את חתימת הפונקציה הקונקרטית.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

שימוש במעקב בטון עם סוגים שאינם תואמים יזרוק שגיאה

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-e4e2860a4364>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

יתכן שתבחין כי לטיעוני Python ניתן טיפול מיוחד בחתימת הקלט של פונקציה קונקרטית. לפני TensorFlow 2.3, טיעוני פיתון פשוט הוסרו מחתימת הפונקציה הקונקרטית. החל מ- TensorFlow 2.3, ארגומנטים של Python נותרים בחתימה, אך מוגבלים לקחת את הערך שנקבע במהלך המעקב.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1683, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1728, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d163f3d206cb>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3

השגת גרפים

כל פונקציה קונקרטית היא עטיפה tf.Graph סביב tf.Graph . למרות tf.Graph האובייקט tf.Graph בפועל אינו דבר שבדרך כלל תצטרך לעשות, אתה יכול להשיג אותו בקלות מכל פונקציה קונקרטית.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

ניפוי באגים

באופן כללי, קוד איתור באגים קל יותר במצב להוט מאשר tf.function . עליך לוודא שהקוד שלך מבוצע ללא שגיאות במצב להוט לפני שאתה מקשט עם tf.function . כדי לסייע בתהליך איתור באגים, אתה יכול להתקשר אל tf.config.run_functions_eagerly(True) כדי להשבית ולהפעיל מחדש את tf.function .

בעת מעקב אחר בעיות המופיעות רק ב- tf.function , הנה כמה טיפים:

  • שיחות print ישנות של פייתון print מתבצעות רק במהלך האיתור, ועוזרות לך להתחקות אחר המעקב אחר הפונקציה שלך.
  • שיחות tf.print יבוצעו בכל פעם ויכולות לעזור לך לאתר ערכי ביניים במהלך הביצוע.
  • tf.debugging.enable_check_numerics היא דרך קלה לאתר היכן נוצרים NaNs ו- Inf.
  • pdb יכול לעזור לך להבין מה קורה במהלך האיתור. (אזהרה: ה- PDB יפיל אותך לקוד מקור שהופך באמצעות AutoGraph.)

שינויים באוטוגרף

AutoGraph היא ספרייה tf.function כברירת מחדל tf.function , והופכת תת-קבוצה של קוד להוט של פייתון ל TensorFlow אופציות תואמות גרף. זה כולל זרימת בקרה if , for while .

אופציות TensorFlow כמו tf.cond ו- tf.cond . tf.while_loop ממשיכות לעבוד, אך לעתים קרובות קל יותר לכתוב ולהבין את זרימת השליטה tf.while_loop בפייתון.

# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.928048491 0.537333608 0.319427252 0.414729953 0.138620138]
[0.729682684 0.490966946 0.308988899 0.392481416 0.137739]
[0.62287122 0.454983532 0.299516946 0.373497456 0.136874482]
[0.553123951 0.425986826 0.290870458 0.357047111 0.13602607]
[0.502857924 0.401961982 0.282935768 0.342610359 0.135193244]
[0.464361787 0.381626487 0.27562 0.329805791 0.134375557]
[0.433632493 0.364119112 0.268846452 0.318346262 0.133572534]
[0.408352554 0.348837078 0.262551099 0.308010817 0.132783771]
[0.387072921 0.335343778 0.256680071 0.298626363 0.132008836]
[0.368834078 0.32331419 0.251187652 0.290055037 0.131247327]
[0.352971435 0.312500536 0.246034727 0.282185435 0.130498841]
[0.339008093 0.302710205 0.241187632 0.274926543 0.129763052]
[0.326591551 0.293790847 0.236617178 0.26820302 0.129039586]
[0.315454811 0.285620153 0.232297987 0.261951953 0.128328085]
[0.305391371 0.278098613 0.228207797 0.256120354 0.127628237]
[0.296238661 0.27114439 0.224326983 0.250663161 0.126939729]
[0.287866682 0.264689356 0.220638305 0.245541915 0.126262262]
[0.280170113 0.25867638 0.217126325 0.240723446 0.12559554]
[0.273062497 0.253057063 0.213777393 0.236178935 0.124939285]
[0.266472191 0.247790173 0.210579231 0.231883332 0.124293216]
[0.260339141 0.242840245 0.207520843 0.227814704 0.12365707]
[0.254612684 0.238176659 0.204592302 0.223953649 0.123030603]
[0.249249727 0.23377277 0.201784685 0.220283121 0.122413576]
[0.244213238 0.229605287 0.199089885 0.216787875 0.12180575]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.23947136, 0.22565375, 0.19650048, 0.21345437, 0.12120689],
      dtype=float32)>

אם אתה סקרן אתה יכול לבדוק את קוד החתימה שמייצר.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

תנאים

AutoGraph ימיר כמה הצהרות if <condition> לשיחות tf.cond המקבילות. החלפה זו נעשית אם <condition> הוא טנסור. אחרת, if בהצהרה מבוצעת בתור מותנית Python.

תנאי פייתון מבוצע במהלך המעקב, כך שענף אחד של התנאי יתווסף לגרף. ללא AutoGraph, גרף מעקב זה לא יוכל לקחת את הענף החלופי אם יש זרימת בקרה תלויה בנתונים.

tf.cond עוקב ומוסיף את שני ענפי התנאי לגרף, תוך בחירה דינמית של ענף בזמן הביצוע. למעקב יכולות להיות תופעות לוואי לא מכוונות; ראה אפקטים למעקב אחר AutoGraph לקבלת מידע נוסף.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

עיין בתיעוד ההפניה להגבלות נוספות על הצהרות המרות AutoGraph אם.

לולאות

חתימה תמיר חלק for ו while הצהרות לתוך TensorFlow המקבילה looping ops, כמו tf.while_loop . אם לא מומרים, הלולאה for או while מבוצעת כלולאת פייתון.

החלפה זו נעשית במצבים הבאים:

  • for x in y : אם y הוא טנסור, המירו ל- tf.while_loop . tf.while_loop . במקרה המיוחד שבו y הואtf.data.Dataset , נוצר שילוב שלtf.data.Dataset ops.
  • while <condition> : אם <condition> הוא טנסור, המירו ל- tf.while_loop . tf.while_loop .

לולאת פייתון tf.Graph במהלך המעקב, ומוסיפה אופס נוספים ל- tf.Graph עבור כל איטרציה של הלולאה.

לולאת TensorFlow עוקבת אחר גוף הלולאה, ובוחרת באופן דינמי כמה איטרציות לרוץ בזמן הביצוע. גוף הלולאה מופיע רק פעם אחת ב- tf.Graph שנוצר.

עיין בתיעוד ההפניה למגבלות נוספות בהמרת AutoGraph for הצהרות while .

מדלג על נתוני פייתון

מלכודת נפוצה היא לולאה על נתונים של פייתון / Numpy בתוך tf.function . לולאה זו תבוצע במהלך תהליך האיתור, tf.Graph עותק של המודל שלך ל- tf.Graph עבור כל איטרציה של הלולאה.

אם ברצונך לעטוף את כל לולאת האימונים tf.function , הדרך הבטוחה ביותר לעשות זאת היא לעטוף את הנתונים שלך כ-tf.data.Dataset כך שאוטוגרף יבטל את לולאת האימון באופן דינמי.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 10 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 10 nodes in its graph

כאשר עוטפים נתונים של Python / Numpy במערך נתונים, יש לשים לב ל tf.data.Dataset.from_generator לעומת tf.data.Dataset.from_tensors . הראשון ישמור את הנתונים בפייתון tf.py_function אותם באמצעות tf.py_function שיכולה להיות השלכות ביצועים, ואילו האחרון tf.constant() עותק של הנתונים tf.constant() אחד גדול בגרף, שיכול להיות בעל השלכות זיכרון.

קריאת נתונים מקבצים באמצעות TFRecordDataset / CsvDataset / וכו '. היא הדרך היעילה ביותר לצרוך נתונים, שכן אז TensorFlow עצמו יכול לנהל את הטעינה והאסיפה האסינכרוניים של נתונים, מבלי שיהיה צורך לערב את Python. למידע נוסף, עיין במדריך tf.data .

צבירת ערכים בלולאה

דפוס נפוץ הוא צבירת ערכי ביניים מלולאה. בדרך כלל, הדבר נעשה על ידי הצטרפות לרשימת פייתון או הוספת ערכים למילון פייתון. עם זאת, מכיוון שמדובר בתופעות לוואי של פייתון, הן לא יעבדו כצפוי בלולאה הנגללת באופן דינמי. השתמש ב- tf.TensorArray כדי לצבור תוצאות tf.TensorArray באופן דינמי.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.8216245 , 0.29562855, 0.379112  , 0.49940717],
        [1.6473945 , 1.039927  , 1.3268942 , 0.5298227 ],
        [2.4393063 , 1.1283967 , 2.087479  , 1.2748951 ]],

       [[0.08016336, 0.73864746, 0.33738315, 0.4542967 ],
        [0.7459605 , 1.307698  , 1.1588445 , 0.9293362 ],
        [1.3752056 , 1.6133544 , 1.8199729 , 1.7356051 ]]], dtype=float32)>

מגבלות

TensorFlow Function יש מספר מגבלות על ידי עיצוב שאתה צריך להיות מודע בעת המרת פונקצית פיתון על Function .

ביצוע תופעות לוואי של פייתון

תופעות לוואי, כמו הדפסה, הוספה לרשימות ומוטציות גלובליות, יכולות להתנהג באופן בלתי צפוי בתוך Function , ולעיתים לבצע פעמיים או לא את כולן. הם מתרחשים רק בפעם הראשונה שאתה מתקשר Function עם סט כניסות. לאחר מכן, tf.Graph מבוצע מחדש, מבלי לבצע את קוד הפייתון.

כלל האצבע הכללי הוא להימנע מהסתמכות על תופעות לוואי של פייתון בהיגיון שלך ולהשתמש בהן רק לצורך ניפוי העקבות שלך. אחרת, ממשקי API של TensorFlow כמו tf.data , tf.print , tf.summary , tf.Variable.assign ו- tf.TensorArray הם הדרך הטובה ביותר להבטיח שהקוד שלך יבוצע על ידי זמן הריצה של TensorFlow בכל שיחה.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

אם ברצונך לבצע קוד פייתון במהלך כל הפעלת Function , tf.py_function הוא פתח יציאה. החיסרון של tf.py_function הוא שהוא לא נייד או ביצועי במיוחד, לא ניתן לשמור עם SavedModel ולא עובד טוב בהגדרות מבוזרות (מרובות GPU, TPU). כמו כן, מכיוון שיש לחבר את tf.py_function הוא tf.py_function את כל הקלטים / הפלטים לטנזורים.

שינוי משתנים גלובליים וחינמיים של Python

שינוי הגלובלי Python משתנה חינם ספירה כתופעת לוואי Python, אז זה קורה רק במהלך מעקב.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

עליכם להימנע ממתן מכולות כמו רשימות, דיקטים, אובייקטים אחרים שחיים מחוץ Function . במקום זאת, השתמש בארגומנטים ובאובייקטים של TF. לדוגמא, בסעיף "צבירת ערכים בלולאה" יש דוגמה אחת כיצד ניתן ליישם פעולות דומות לרשימה.

אתה יכול, במקרים מסוימים, ללכוד ולתפעל את המצב אם זה tf.Variable . tf.Variable . כך המשקלים של דגמי Keras מתעדכנים בשיחות חוזרות לאותה ConcreteFunction .

באמצעות איטרטורים ומחוללים של פייתון

תכונות רבות של פייתון, כגון גנרטורים ואיטרטורים, מסתמכות על זמן הריצה של פייתון כדי לעקוב אחר המצב. באופן כללי, בעוד שמבנים אלה פועלים כמצופה במצב להוט, הם דוגמאות לתופעות לוואי של פייתון ולכן הם מתרחשים רק במהלך העקיבה.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

בדיוק כמו איך TensorFlow יש לו התמחה tf.TensorArray עבור מבני רשימה, יש לו התמחה tf.data.Iterator עבור איטרציה מבנים. ראה סעיף על התמרות אוטומטיות של הגרף לקבלת סקירה כללית. כמו כן, ממשק ה- API של tf.data יכול לעזור ביישום דפוסי גנרטורים:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

מחיקת משתנה בין שיחות Function

שגיאה נוספת שתיתקל בה היא משתנה שנאסף אשפה. ConcreteFunction שומרת רק על WeakRefs למשתנים שהם סוגרים עליהם, לכן עליך לשמור הפניה לכל משתנה.

external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var

traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

# The original variable object gets garbage collected, since there are no more
# references to it.
external_var = tf.Variable(4)
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-9a93d2e07632>", line 16, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Error while reading resource variable _AnonymousVar3 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar3/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-9a93d2e07632>:4) ]]
  (1) Failed precondition:  Error while reading resource variable _AnonymousVar3 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar3/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-1-9a93d2e07632>:4) ]]
     [[ReadVariableOp/_2]]
0 successful operations.
0 derived errors ignored. [Op:__inference_f_782]

Function call stack:
f -> f

בעיות ידועות

אם Function שלך לא מעריכה כראוי, יתכן והסבר על השגיאה נוצר על ידי בעיות ידועות אלה שתוכננו לתקן בעתיד.

תלוי במשתנים גלובליים וחינמיים של Python

Function יוצרת Function ConcreteFunction חדשה כאשר היא מתקשרת עם ערך חדש של ארגומנט Python. עם זאת, זה לא עושה זאת עבור סגירת הפיתון, הגלובלים או הלא-לוקלים של אותה Function . אם הערך שלהם משתנה בין שיחות Function , Function עדיין תשתמש בערכים שהיו להם בעת עקיבתה. זה שונה מאופן הפעולה של פונקציות פייתון רגילות.

מסיבה זו, אנו ממליצים על סגנון תכנות פונקציונלי המשתמש בארגומנטים במקום לסגור מעל שמות חיצוניים.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

אתה יכול לסגור מעל שמות חיצוניים, כל עוד לא תעדכן את הערכים שלהם.

תלוי באובייקטים של פייתון

ההמלצה להעביר אובייקטים של פייתון כטיעונים tf.function יש מספר בעיות ידועות שצפויות לתקן בעתיד. באופן כללי, אתה יכול להסתמך על מעקב עקבי אם אתה משתמש במבנה tf.nest Python או tf.nest או מעביר במקרה אחר של אובייקט Function . עם זאת, Function לא תיצור עקבות חדשים כאשר תעביר את אותו אובייקט ותשנה רק את תכונותיו .

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

השימוש באותה Function כדי להעריך את המופע המעודכן של המודל יהיה טעון מכיוון שלמודל המעודכן יש אותו מפתח מטמון כמו המודל המקורי.

מסיבה זו, אנו ממליצים לך לכתוב את Function שלך כדי להימנע מהתלות בתכונות האובייקט הניתנות לשינוי או ליצור אובייקטים חדשים.

אם זה לא אפשרי, דרך לעקיפת הבעיה היא ליצור Function חדשות בכל פעם שאתה משנה את האובייקט שלך לכוח חוזר:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

מכיוון שהחזרה יכולה להיות יקרה , אתה יכול להשתמש ב- tf.Variable s tf.Variable אובייקט, שאותן ניתן לשנות (אך לא לשנות, זהיר!) להשפעה דומה מבלי להזדקק למעקב אחר.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

יצירת tf. משתנים

Function תומכת ביצירת משתנים רק פעם אחת, כאשר נקרא לראשונה, ולאחר מכן שימוש חוזר בהם. אינך יכול ליצור tf.Variables . tf.Variables חדש. כרגע אסור ליצור משתנים חדשים בשיחות הבאות, אך יהיה בעתיד.

דוגמא:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-8a0913e250e0>", line 7, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-1-8a0913e250e0>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:731 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

ניתן ליצור משתנים בתוך Function כל עוד משתנים אלה נוצרים רק בפעם הראשונה שהפונקציה מבוצעת.

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

שימוש עם מספר אופטימיזציה של Keras

אתה עלול להיתקל ב- ValueError: tf.function-decorated function tried to create variables on non-first call. בעת שימוש ביותר ממייעל Keras אחד עם tf.function . שגיאה זו מתרחשת מכיוון tf.Variables יוצרת באופן פנימי tf.Variables כאשר הם מחילים שיפועים בפעם הראשונה.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d3d3937dbf1a>", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    <ipython-input-1-d3d3937dbf1a>:9 train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:604 apply_gradients  **
        self._create_all_weights(var_list)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:781 _create_all_weights
        _ = self.iterations
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:788 __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:926 iterations
        aggregation=tf_variables.VariableAggregation.ONLY_FIRST_REPLICA)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:1132 add_weight
        aggregation=aggregation)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py:810 _add_variable_with_custom_getter
        **kwargs_for_getter)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_utils.py:142 make_variable
        shape=variable_shape if variable_shape else None)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:260 __call__
        return cls._variable_v1_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:221 _variable_v1_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:731 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.

אם אתה צריך לשנות את האופטימיזציה במהלך האימון, דרך לעקיפת הבעיה היא ליצור Function חדשה עבור כל אופטימיזציה, להתקשר ישירות ל- ConcreteFunction .

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

שימוש עם מספר דגמי Keras

יתכן שתיתקל גם ב- ValueError: tf.function-decorated function tried to create variables on non-first call. כאשר מעבירים מופעי מודל שונים לאותה Function .

שגיאה זו מתרחשת מכיוון שמודלים של Keras (אשר לא מוגדרת צורת הקלט שלהם ) ושכבות tf.Variables יוצרים tf.Variables s כאשר הם נקראים לראשונה. ייתכן שאתה מנסה לאתחל את המשתנים האלה בתוך Function שכבר נקראה. כדי להימנע משגיאה זו, נסה להתקשר model.build(input_shape) כדי לאתחל את כל המשקולות לפני אימון המודל.

לקריאה נוספת

למידע על אופן הייצוא והטעינה של Function , עיין במדריך SavedModel . למידע נוסף על אופטימיזציות גרפים המבוצעות לאחר מעקב, עיין במדריך Grappler . כדי ללמוד כיצד לבצע אופטימיזציה של צינור הנתונים שלך ולפרופיל את המודל שלך, עיין במדריך Profiler .