עזרה להגן על שונית המחסום הגדולה עם TensorFlow על Kaggle הצטרפו אתגר

היכרות עם גרפים ותפקוד tf

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת

סקירה כללית

מדריך זה עובר מתחת לפני השטח של TensorFlow ו-Keras כדי להדגים כיצד TensorFlow פועל. אם במקום זאת אתה רוצה להתחיל מיד עם Keras, עיין באוסף המדריכים של Keras .

במדריך זה, תלמד כיצד TensorFlow מאפשר לך לבצע שינויים פשוטים בקוד שלך כדי לקבל גרפים, כיצד גרפים מאוחסנים ומיוצגים, וכיצד אתה יכול להשתמש בהם כדי להאיץ את המודלים שלך.

זוהי סקירה כללית הכוללת את האופן שבו tf.function מאפשר לך לעבור מביצוע נלהב לביצוע גרף. למפרט מלא יותר של tf.function , עבור למדריך tf.function .

מה זה גרפים?

בשלושת המדריכים הקודמים, הפעלת את TensorFlow בשקיקה . המשמעות היא שפעולות TensorFlow מבוצעות על ידי Python, פעולה אחר פעולה, והחזרת תוצאות חזרה לפייתון.

בעוד שלביצוע להוט יש כמה יתרונות ייחודיים, ביצוע גרפים מאפשר ניידות מחוץ ל-Python ונוטה להציע ביצועים טובים יותר. ביצוע גרף פירושו שחישובי טנסור מבוצעים כגרף TensorFlow , המכונה לפעמים tf.Graph או פשוט "גרף".

גרפים הם מבני נתונים המכילים קבוצה של אובייקטי tf.Operation , המייצגים יחידות חישוב; ו- tf.Tensor objects, המייצגים את יחידות הנתונים הזורמות בין פעולות. הם מוגדרים בהקשר tf.Graph . מכיוון שהגרפים הללו הם מבני נתונים, ניתן לשמור, להפעיל ולשחזר אותם ללא קוד Python המקורי.

כך נראה גרף TensorFlow המייצג רשת עצבית דו-שכבתית כאשר הוא מוצג ב-TensorBoard.

גרף TensorFlow פשוט

היתרונות של גרפים

עם גרף, יש לך מידה רבה של גמישות. אתה יכול להשתמש בגרף TensorFlow שלך בסביבות שאין בהן מתורגמן של Python, כמו יישומים ניידים, מכשירים משובצים ושרתי קצה. TensorFlow משתמש בגרפים כפורמט למודלים שמורים כאשר הוא מייצא אותם מ-Python.

גם גרפים עוברים אופטימיזציה בקלות, מה שמאפשר למהדר לבצע טרנספורמציות כמו:

  • הסיק באופן סטטי את הערך של הטנזורים על ידי קיפול צמתים קבועים בחישוב שלך ("קיפול מתמיד") .
  • הפרד חלקי משנה של חישוב שאינם תלויים ומפצל אותם בין שרשורים או התקנים.
  • פשט פעולות אריתמטיות על ידי ביטול ביטויי משנה נפוצים.

יש מערכת אופטימיזציה שלמה, Grappler , לביצוע מהירות זו ואחרות.

בקיצור, גרפים שימושיים ביותר ומאפשרים ל-TensorFlow שלך לרוץ מהר , לרוץ במקביל ולרוץ ביעילות על מספר מכשירים .

עם זאת, אתה עדיין רוצה להגדיר את מודלים למידת המכונה שלך (או חישובים אחרים) ב-Python מטעמי נוחות, ולאחר מכן לבנות באופן אוטומטי גרפים כאשר אתה צריך אותם.

להכין

import tensorflow as tf
import timeit
from datetime import datetime

ניצול של גרפים

אתה יוצר ומפעיל גרף ב-TensorFlow באמצעות tf.function , כקריאה ישירה או כדקורטור. tf.function לוקח פונקציה רגילה כקלט ומחזירה Function . A Function הוא פייתון הניתן להתקשרות שבונה גרפים של TensorFlow מפונקציית Python. אתה משתמש Function באותו אופן כמו המקבילה שלה לפייתון.

# Define a Python function.
def a_regular_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# `a_function_that_uses_a_graph` is a TensorFlow `Function`.
a_function_that_uses_a_graph = tf.function(a_regular_function)

# Make some tensors.
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

orig_value = a_regular_function(x1, y1, b1).numpy()
# Call a `Function` like a Python function.
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
assert(orig_value == tf_function_value)

מבחוץ, Function נראית כמו פונקציה רגילה שאתה כותב באמצעות פעולות TensorFlow. אולם מתחת , זה שונה מאוד . Function מקפלת כמה tf.Graph מאחורי API אחד . כך Function מסוגלת לתת לך את היתרונות של ביצוע גרפים , כמו מהירות ופריסה.

tf.function חל על פונקציה ועל כל הפונקציות האחרות שהיא קוראת לה :

def inner_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# Use the decorator to make `outer_function` a `Function`.
@tf.function
def outer_function(x):
  y = tf.constant([[2.0], [3.0]])
  b = tf.constant(4.0)

  return inner_function(x, y, b)

# Note that the callable will create a graph that
# includes `inner_function` as well as `outer_function`.
outer_function(tf.constant([[1.0, 2.0]])).numpy()
array([[12.]], dtype=float32)

אם השתמשת ב-TensorFlow 1.x, תבחין שבשום זמן לא היית צריך להגדיר Placeholder או tf.Session .

המרת פונקציות Python לגרפים

כל פונקציה שתכתוב עם TensorFlow תכיל תערובת של פעולות TF מובנות ולוגיקת Python, כגון פסקאות if-then , לולאות, break , return , continue ועוד. בעוד שפעולות TensorFlow נלכדות בקלות על ידי tf.Graph , הלוגיקה הספציפית לפייתון צריכה לעבור שלב נוסף כדי להפוך לחלק מהגרף. tf.function משתמש בספרייה בשם AutoGraph ( tf.autograph ) כדי להמיר קוד Python לקוד יצירת גרפים.

def simple_relu(x):
  if tf.greater(x, 0):
    return x
  else:
    return 0

# `tf_simple_relu` is a TensorFlow `Function` that wraps `simple_relu`.
tf_simple_relu = tf.function(simple_relu)

print("First branch, with graph:", tf_simple_relu(tf.constant(1)).numpy())
print("Second branch, with graph:", tf_simple_relu(tf.constant(-1)).numpy())
First branch, with graph: 1
Second branch, with graph: 0

למרות שלא סביר שתצטרך להציג גרפים ישירות, אתה יכול לבדוק את הפלטים כדי לבדוק את התוצאות המדויקות. אלה לא קלים לקריאה, אז אין צורך להסתכל בזהירות רבה מדי!

# This is the graph-generating output of AutoGraph.
print(tf.autograph.to_code(simple_relu))
def tf__simple_relu(x):
    with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal retval_, do_return
            (do_return, retval_) = vars_

        def if_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = ag__.ld(x)
            except:
                do_return = False
                raise

        def else_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = 0
            except:
                do_return = False
                raise
        ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
        return fscope.ret(retval_, do_return)
# This is the graph itself.
print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())
node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "x"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "Greater/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Greater"
  op: "Greater"
  input: "x"
  input: "Greater/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "cond"
  op: "StatelessIf"
  input: "Greater"
  input: "x"
  attr {
    key: "Tcond"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "Tin"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "Tout"
    value {
      list {
        type: DT_BOOL
        type: DT_INT32
      }
    }
  }
  attr {
    key: "_lower_using_switch_merge"
    value {
      b: true
    }
  }
  attr {
    key: "_read_only_resource_inputs"
    value {
      list {
      }
    }
  }
  attr {
    key: "else_branch"
    value {
      func {
        name: "cond_false_34"
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
        shape {
        }
      }
    }
  }
  attr {
    key: "then_branch"
    value {
      func {
        name: "cond_true_33"
      }
    }
  }
}
node {
  name: "cond/Identity"
  op: "Identity"
  input: "cond"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "cond/Identity_1"
  op: "Identity"
  input: "cond:1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Identity"
  op: "Identity"
  input: "cond/Identity_1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
library {
  function {
    signature {
      name: "cond_false_34"
      input_arg {
        name: "cond_placeholder"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Const_1"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Const_2"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node_def {
      name: "cond/Const_3"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const_3:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
    }
    node_def {
      name: "cond/Const_4"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond/Const_4:output:0"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
  function {
    signature {
      name: "cond_true_33"
      input_arg {
        name: "cond_identity_1_x"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond_identity_1_x"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
}
versions {
  producer: 898
  min_consumer: 12
}

לרוב, tf.function יעבוד ללא שיקולים מיוחדים. עם זאת, יש כמה אזהרות, והמדריך tf.function יכול לעזור כאן, כמו גם ההפניה המלאה ל-AutoGraph

פולימורפיזם: Function אחת, גרפים רבים

tf.Graph מתמחה בסוג מסוים של קלט (לדוגמה, טנסורים עם dtype ספציפי או אובייקטים עם אותו id() ).

בכל פעם שאתה מפעיל Function עם dtypes וצורות חדשות בארגומנטים שלה, Function יוצר tf.Graph חדש עבור הארגומנטים החדשים. סוגי ה- dtypes והצורות של הקלט של tf.Graph קלט או רק כחתימה .

Function מאחסנת את ה- tf.Graph המתאים לאותה חתימה ב- ConcreteFunction . ConcreteFunction היא עטיפה סביב tf.Graph .

@tf.function
def my_relu(x):
  return tf.maximum(0., x)

# `my_relu` creates new graphs as it observes more signatures.
print(my_relu(tf.constant(5.5)))
print(my_relu([1, -1]))
print(my_relu(tf.constant([3., -3.])))
tf.Tensor(5.5, shape=(), dtype=float32)
tf.Tensor([1. 0.], shape=(2,), dtype=float32)
tf.Tensor([3. 0.], shape=(2,), dtype=float32)

אם Function כבר נקראה עם החתימה הזו, Function לא יוצרת tf.Graph חדש.

# These two calls do *not* create new graphs.
print(my_relu(tf.constant(-2.5))) # Signature matches `tf.constant(5.5)`.
print(my_relu(tf.constant([-1., 1.]))) # Signature matches `tf.constant([3., -3.])`.
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor([0. 1.], shape=(2,), dtype=float32)

מכיוון שהיא מגובה במספר גרפים, Function היא פולימורפית . זה מאפשר לו לתמוך ביותר סוגי קלט ממה tf.Graph בודד יכול לייצג, כמו גם לבצע אופטימיזציה של כל tf.Graph לביצועים טובים יותר.

# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.
# The `ConcreteFunction` also knows the return type and shape!
print(my_relu.pretty_printed_concrete_signatures())
my_relu(x)
  Args:
    x: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

my_relu(x=[1, -1])
  Returns:
    float32 Tensor, shape=(2,)

my_relu(x)
  Args:
    x: float32 Tensor, shape=(2,)
  Returns:
    float32 Tensor, shape=(2,)

שימוש ב- tf.function

עד כה, למדת כיצד להמיר פונקציית Python לגרף פשוט על ידי שימוש ב- tf.function בתור דקורטור או עטיפה. אבל בפועל, לגרום ל- tf.function לעבוד כהלכה יכול להיות מסובך! בסעיפים הבאים, תלמד כיצד תוכל לגרום לקוד שלך לעבוד כמצופה עם tf.function .

ביצוע גרף לעומת ביצוע להוט

ניתן לבצע את הקוד Function הן בשקיקה והן כגרף. כברירת מחדל, Function מבצעת את הקוד שלה בתור גרף:

@tf.function
def get_MSE(y_true, y_pred):
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)
y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)
print(y_true)
print(y_pred)
tf.Tensor([1 0 4 4 7], shape=(5,), dtype=int32)
tf.Tensor([3 6 3 0 6], shape=(5,), dtype=int32)
get_MSE(y_true, y_pred)
<tf.Tensor: shape=(), dtype=int32, numpy=11>

כדי לוודא שהגרף של Function שלך עושה את אותו חישוב כמו פונקציית Python המקבילה שלו, אתה יכול לגרום לו לפעול בשקיקה עם tf.config.run_functions_eagerly(True) . זהו מתג שמבטל את היכולת של Function ליצור ולהריץ גרפים , במקום זאת לבצע את הקוד כרגיל.

tf.config.run_functions_eagerly(True)
get_MSE(y_true, y_pred)
<tf.Tensor: shape=(), dtype=int32, numpy=11>
# Don't forget to set it back when you are done.
tf.config.run_functions_eagerly(False)

עם זאת, Function יכולה להתנהג אחרת תחת גרף וביצוע להוט. פונקציית print של Python היא דוגמה אחת להבדלים בין שני המצבים הללו. בוא נבדוק מה קורה כאשר אתה מכניס משפט print לפונקציה שלך וקורא לה שוב ושוב.

@tf.function
def get_MSE(y_true, y_pred):
  print("Calculating MSE!")
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)

שימו לב למה שמודפס:

error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
Calculating MSE!

האם התפוקה מפתיעה? get_MSE הודפס פעם אחת בלבד למרות שהוא נקרא שלוש פעמים.

לשם הסבר, משפט ה- print מבוצע כאשר Function מריץ את הקוד המקורי על מנת ליצור את הגרף בתהליך המכונה "מעקב" . המעקב לוכד את פעולות TensorFlow לתוך גרף, print אינה נלכדת בגרף. לאחר מכן, הגרף הזה מבוצע עבור כל שלוש השיחות מבלי להפעיל שוב את קוד Python .

כבדיקת שפיות, בואו נשבית את ביצוע הגרפים כדי להשוות:

# Now, globally set everything to run eagerly to force eager execution.
tf.config.run_functions_eagerly(True)
# Observe what is printed below.
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
Calculating MSE!
Calculating MSE!
Calculating MSE!
tf.config.run_functions_eagerly(False)

print היא תופעת לוואי של Python , ויש הבדלים נוספים שכדאי להיות מודעים אליהם בעת המרת Function . למידע נוסף בסעיף מגבלות במדריך ביצועים טובים יותר עם tf.function .

ביצוע לא קפדני

ביצוע גרף מבצע רק את הפעולות הדרושות להפקת ההשפעות הניתנות לצפייה, הכוללות:

  • ערך ההחזרה של הפונקציה
  • תועדו תופעות לוואי ידועות כמו:
    • פעולות קלט/פלט, כמו tf.print
    • פעולות איתור באגים, כגון פונקציות ה-assert ב- tf.debugging
    • מוטציות של tf.Variable

התנהגות זו מכונה בדרך כלל "ביצוע לא קפדני", ושונה מביצוע להוט, שעובר את כל פעולות התוכנית, נחוץ או לא.

בפרט, בדיקת שגיאות בזמן ריצה אינה נחשבת כאפקט שניתן לראות. אם מדלגים על פעולה מכיוון שהיא מיותרת, היא לא יכולה להעלות שגיאות זמן ריצה.

בדוגמה הבאה, פעולת "המיותרת" tf.gather מדלגת במהלך ביצוע הגרף, כך ששגיאת זמן הריצה InvalidArgumentError לא מועלית כפי שתהיה בביצוע להוט. אל תסתמך על העלאת שגיאה בזמן ביצוע גרף.

def unused_return_eager(x):
  # Get index 1 will fail when `len(x) == 1`
  tf.gather(x, [1]) # unused 
  return x

try:
  print(unused_return_eager(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:
  # All operations are run during eager execution so an error is raised.
  print(f'{type(e).__name__}: {e}')
tf.Tensor([0.], shape=(1,), dtype=float32)
@tf.function
def unused_return_graph(x):
  tf.gather(x, [1]) # unused
  return x

# Only needed operations are run during graph exection. The error is not raised.
print(unused_return_graph(tf.constant([0.0])))
tf.Tensor([0.], shape=(1,), dtype=float32)

שיטות עבודה מומלצות tf.function

עשוי לקחת זמן להתרגל להתנהגות של Function . כדי להתחיל במהירות, משתמשים הראשונים צריכים לשחק עם פונקציות צעצוע לקשט עם @tf.function כדי לקבל ניסיון במעבר מלהוט לביצוע גרף.

עיצוב עבור tf.function עשוי להיות ההימור הטוב ביותר עבור כתיבת תוכניות TensorFlow תואמות גרפים. הנה כמה עצות:

  • החלף בין ביצוע להוט לביצוע גרף מוקדם ולעתים קרובות עם tf.config.run_functions_eagerly כדי לאתר אם/כאשר שני המצבים מתפצלים.
  • צור tf.Variable s מחוץ לפונקציית Python ושנה אותם מבפנים. אותו דבר לגבי אובייקטים המשתמשים ב- tf.Variable , כמו keras.layers , keras.Model s ו- tf.optimizers .
  • הימנע מכתיבת פונקציות התלויות במשתני Python חיצוניים , למעט אובייקטי tf.Variable s ו-Keras.
  • העדיפו לכתוב פונקציות שלוקחות טנזורים וסוגי TensorFlow אחרים כקלט. אתה יכול להעביר סוגים אחרים של אובייקטים אבל היזהר !
  • כלול כמה שיותר חישוב תחת פונקציה tf.function . כדי למקסם את רווח הביצועים. לדוגמה, לקשט שלב אימון שלם או את כל לולאת האימון.

רואים את המהירות

tf.function בדרך כלל משפר את הביצועים של הקוד שלך, אבל כמות המהירות תלויה בסוג החישוב שאתה מפעיל. חישובים קטנים יכולים להשתלט על ידי התקורה של קריאת גרף. אתה יכול למדוד את ההבדל בביצועים כך:

x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)

def power(x, y):
  result = tf.eye(10, dtype=tf.dtypes.int32)
  for _ in range(y):
    result = tf.matmul(x, result)
  return result
print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=1000))
Eager execution: 2.5637862179974036
power_as_graph = tf.function(power)
print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=1000))
Graph execution: 0.6832536700021592

tf.function משמש בדרך כלל להאצת לולאות אימון, ותוכלו ללמוד עליה עוד בכתיבת לולאת אימון מאפס עם Keras.

ביצועים ופשרות

גרפים יכולים להאיץ את הקוד שלך, אבל לתהליך יצירתם יש תקורה מסוימת. עבור פונקציות מסוימות, יצירת הגרף לוקחת יותר זמן מאשר ביצוע הגרף. השקעה זו מוחזרת בדרך כלל במהירות עם שיפור הביצועים של ביצועים הבאים, אך חשוב להיות מודע לכך שהשלבים הראשונים של כל אימון מודלים גדולים יכולים להיות איטיים יותר עקב מעקב.

לא משנה כמה גדול הדגם שלך, אתה רוצה להימנע ממעקב לעתים קרובות. המדריך tf.function דן כיצד להגדיר מפרטי קלט ולהשתמש בארגומנטים של טנסור כדי להימנע מחזרה. אם אתה מגלה שאתה מקבל ביצועים גרועים בצורה יוצאת דופן, מומלץ לבדוק אם אתה חוזר בטעות.

מתי מתבצע מעקב אחר Function ?

כדי להבין מתי Function שלך מתחקה, הוסף משפט print לקוד שלה. ככלל אצבע, Function יבצע את הצהרת print בכל פעם שהיא מתחקה.

@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!") # An eager-only side effect.
  return x * x + tf.constant(2)

# This is traced the first time.
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect.
print(a_function_with_python_side_effect(tf.constant(3)))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter.
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)

ארגומנטים חדשים של Python תמיד מפעילים יצירת גרף חדש, ומכאן המעקב הנוסף.

הצעדים הבאים

תוכל ללמוד עוד על tf.function בדף העזרה של API ועל ידי ביצוע המדריך לביצועים טובים יותר עם tf.function .