עזרה להגן על שונית המחסום הגדולה עם TensorFlow על Kaggle הצטרפו אתגר

היכרות עם גרפים ותפקוד tf

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת

סקירה כללית

מדריך זה עובר מתחת לפני השטח של TensorFlow ו-Keras כדי להדגים כיצד TensorFlow פועל. אם אתה רוצה במקום מייד להתחיל עם Keras, לבדוק את האוסף של מדריכי Keras .

במדריך זה, תלמד כיצד TensorFlow מאפשר לך לבצע שינויים פשוטים בקוד שלך כדי לקבל גרפים, כיצד גרפים מאוחסנים ומיוצגים, וכיצד אתה יכול להשתמש בהם כדי להאיץ את המודלים שלך.

זוהי סקירת תמונה גדולה שמכסה איך tf.function מאפשר לך לעבור ביצוע להוט ביצוע גרף. לפירוט מלא יותר של tf.function , ללכת tf.function מדריך .

מה זה גרפים?

בשלושת המדריכים הקודמים, אתה רץ TensorFlow בשקיקה. המשמעות היא שפעולות TensorFlow מבוצעות על ידי Python, פעולה אחר פעולה, והחזרת תוצאות חזרה לפייתון.

בעוד שלביצוע להוט יש כמה יתרונות ייחודיים, ביצוע גרפים מאפשר ניידות מחוץ ל-Python ונוטה להציע ביצועים טובים יותר. ביצוע גרף אמצעי כי חישובים מותחים מבוצעים כגרף TensorFlow, לפעמים המכונה tf.Graph או פשוט "גרף".

גרפים אינם מבני נתונים המכילים סט של tf.Operation חפץ, המייצגים יחידות חישוב; ו tf.Tensor אובייקטים, אשר מייצג את היחידות של נתונים שזורמות בין פעולות. הם מוגדרים בתוך tf.Graph בהקשר. מכיוון שהגרפים הללו הם מבני נתונים, ניתן לשמור, להפעיל ולשחזר אותם ללא קוד Python המקורי.

כך נראה גרף TensorFlow המייצג רשת עצבית דו-שכבתית כאשר הוא מוצג ב-TensorBoard.

גרף TensorFlow פשוט

היתרונות של גרפים

עם גרף, יש לך מידה רבה של גמישות. אתה יכול להשתמש בגרף TensorFlow שלך בסביבות שאין בהן מתורגמן Python, כמו יישומים ניידים, מכשירים משובצים ושרתי קצה עורפיים. TensorFlow משתמשת גרפים כפורמט עבור דגמים הצילו כאשר היא מייצאת אותם Python.

גם גרפים עוברים אופטימיזציה בקלות, ומאפשרים למהדר לבצע טרנספורמציות כמו:

  • סטטי להסיק את הערך של טנזורים ידי קיפול צמתים קבוע בחישוב שלך ( "קיפול קבוע").
  • הפרד חלקי משנה של חישוב שאינם תלויים ומפצל אותם בין שרשורים או התקנים.
  • פשט פעולות אריתמטיות על ידי ביטול ביטויי משנה נפוצים.

קיימת מערכת האופטימיזציה במלואה, grappler , לבצע הזה speedups אחרים.

בקיצור, גרפים הם מאוד שימושיים ולתת TensorFlow שלך לרוץ מהר, לרוץ במקביל, ולהפעיל ביעילות על מספר מכשירים.

עם זאת, אתה עדיין רוצה להגדיר את מודלים למידת המכונה שלך (או חישובים אחרים) ב- Python מטעמי נוחות, ולאחר מכן לבנות באופן אוטומטי גרפים כאשר אתה צריך אותם.

להכין

import tensorflow as tf
import timeit
from datetime import datetime

ניצול של גרפים

אתה ליצור ולהפעיל גרף TensorFlow באמצעות tf.function , או כקריאה ישירה או בתור מעצב. tf.function לוקח פונקציה קבוע כקלט ומחזירה Function . Function היא callable Python שבונה גרפי TensorFlow מפונקציית Python. אתה משתמש Function באותו אופן כמו המקבילה Python שלה.

# Define a Python function.
def a_regular_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# `a_function_that_uses_a_graph` is a TensorFlow `Function`.
a_function_that_uses_a_graph = tf.function(a_regular_function)

# Make some tensors.
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

orig_value = a_regular_function(x1, y1, b1).numpy()
# Call a `Function` like a Python function.
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
assert(orig_value == tf_function_value)

על בחוץ, Function נראית כמו פונקציה רגילה שלכם לכתוב באמצעות פעולות TensorFlow. מתחת , עם זאת, זה מאוד שונה. Function מתמצתת מספר tf.Graph ים מאחורי API אחד . כך Function הוא מסוגל לתת לך את היתרונות של ביצוע הגרף , כמו מהירות פריסה.

tf.function חל פונקציה וכל פונקציות אחרות שהיא מכנה:

def inner_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# Use the decorator to make `outer_function` a `Function`.
@tf.function
def outer_function(x):
  y = tf.constant([[2.0], [3.0]])
  b = tf.constant(4.0)

  return inner_function(x, y, b)

# Note that the callable will create a graph that
# includes `inner_function` as well as `outer_function`.
outer_function(tf.constant([[1.0, 2.0]])).numpy()
array([[12.]], dtype=float32)

אם השתמשת 1.x TensorFlow, תבחין כי בשום שלב לא אתה צריך להגדיר Placeholder או tf.Session .

המרת פונקציות Python לגרפים

כל פונקציה שאתה כותב עם TensorFlow יכיל תערובת של מובנית פעולות TF ולוגיקה Python, כגון if-then סעיפים, לולאות, break , return , continue , ועוד. בעוד פעולות TensorFlow נלכדים בקלות על ידי tf.Graph הצרכים ההיגיון ספציפי Python, לעבור שלב נוסף כדי להיות חלק של הגרף. tf.function משתמשת בספריה בשם חתימה ( tf.autograph ) להמיר קוד פיתון לתוך קוד ליצירת הגרף.

def simple_relu(x):
  if tf.greater(x, 0):
    return x
  else:
    return 0

# `tf_simple_relu` is a TensorFlow `Function` that wraps `simple_relu`.
tf_simple_relu = tf.function(simple_relu)

print("First branch, with graph:", tf_simple_relu(tf.constant(1)).numpy())
print("Second branch, with graph:", tf_simple_relu(tf.constant(-1)).numpy())
First branch, with graph: 1
Second branch, with graph: 0

למרות שלא סביר שתצטרך להציג גרפים ישירות, אתה יכול לבדוק את הפלטים כדי לבדוק את התוצאות המדויקות. אלה לא קלים לקריאה, אז אין צורך להסתכל בזהירות רבה מדי!

# This is the graph-generating output of AutoGraph.
print(tf.autograph.to_code(simple_relu))
def tf__simple_relu(x):
    with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal retval_, do_return
            (do_return, retval_) = vars_

        def if_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = ag__.ld(x)
            except:
                do_return = False
                raise

        def else_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = 0
            except:
                do_return = False
                raise
        ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
        return fscope.ret(retval_, do_return)
# This is the graph itself.
print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())
node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "x"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "Greater/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Greater"
  op: "Greater"
  input: "x"
  input: "Greater/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "cond"
  op: "StatelessIf"
  input: "Greater"
  input: "x"
  attr {
    key: "Tcond"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "Tin"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "Tout"
    value {
      list {
        type: DT_BOOL
        type: DT_INT32
      }
    }
  }
  attr {
    key: "_lower_using_switch_merge"
    value {
      b: true
    }
  }
  attr {
    key: "_read_only_resource_inputs"
    value {
      list {
      }
    }
  }
  attr {
    key: "else_branch"
    value {
      func {
        name: "cond_false_34"
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
        shape {
        }
      }
    }
  }
  attr {
    key: "then_branch"
    value {
      func {
        name: "cond_true_33"
      }
    }
  }
}
node {
  name: "cond/Identity"
  op: "Identity"
  input: "cond"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "cond/Identity_1"
  op: "Identity"
  input: "cond:1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Identity"
  op: "Identity"
  input: "cond/Identity_1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
library {
  function {
    signature {
      name: "cond_false_34"
      input_arg {
        name: "cond_placeholder"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const"
      }
    }
    node_def {
      name: "cond/Const_1"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const_1"
      }
    }
    node_def {
      name: "cond/Const_2"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const_2"
      }
    }
    node_def {
      name: "cond/Const_3"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const_3"
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const_3:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Identity"
      }
    }
    node_def {
      name: "cond/Const_4"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const_4"
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond/Const_4:output:0"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Identity_1"
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
  function {
    signature {
      name: "cond_true_33"
      input_arg {
        name: "cond_identity_1_x"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Const"
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Identity"
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond_identity_1_x"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
      experimental_debug_info {
        original_node_names: "cond/Identity_1"
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
}
versions {
  producer: 808
  min_consumer: 12
}

רוב הזמן, tf.function יעבוד בלי שיקולים מיוחדים. עם זאת, ישנם כמה אזהרות, ואת מדריך tf.function יכול לעזור כאן, כמו גם התייחסות חתימה המלאה

פולימורפיזם: אחד Function , גרפים רבים

tf.Graph מתמחה לסוג מסוים של תשומות (למשל, tensors עם ספציפיים dtype או חפצים עם אותו id() ).

בכל פעם שאתה לעורר Function עם חדש dtypes וצורות בטיעוניה, Function יוצרת חדשה tf.Graph עבור טיעונים חדשים. dtypes וצורות של tf.Graph תשומות של" ידועים כמו חתימה קלט או סתם חתימה.

Function וחנויות tf.Graph מתאימה חתימה כי בתוך ConcreteFunction . ConcreteFunction הוא מעטפת סביב tf.Graph .

@tf.function
def my_relu(x):
  return tf.maximum(0., x)

# `my_relu` creates new graphs as it observes more signatures.
print(my_relu(tf.constant(5.5)))
print(my_relu([1, -1]))
print(my_relu(tf.constant([3., -3.])))
tf.Tensor(5.5, shape=(), dtype=float32)
tf.Tensor([1. 0.], shape=(2,), dtype=float32)
tf.Tensor([3. 0.], shape=(2,), dtype=float32)

אם Function כבר נקרא עם חתימה, Function אינו יוצר חדש tf.Graph .

# These two calls do *not* create new graphs.
print(my_relu(tf.constant(-2.5))) # Signature matches `tf.constant(5.5)`.
print(my_relu(tf.constant([-1., 1.]))) # Signature matches `tf.constant([3., -3.])`.
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor([0. 1.], shape=(2,), dtype=float32)

בגלל זה מגובה על ידי גרפים מרובים, Function היא פולימורפיים. זה מאפשר לה לתמוך בסוגים קלט יותר בודד tf.Graph יכול לייצג, כמו גם כדי לייעל כל tf.Graph עבור ביצועים טובים יותר.

# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.
# The `ConcreteFunction` also knows the return type and shape!
print(my_relu.pretty_printed_concrete_signatures())
my_relu(x)
  Args:
    x: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

my_relu(x=[1, -1])
  Returns:
    float32 Tensor, shape=(2,)

my_relu(x)
  Args:
    x: float32 Tensor, shape=(2,)
  Returns:
    float32 Tensor, shape=(2,)

שימוש tf.function

עד כה, למד כיצד להמיר פונקצית Python לתוך גרף פשוט באמצעות tf.function בתור מעצב או מעטפת. אבל בפועל, מקבל tf.function לעבודה נכונה יכול להיות מסובך! בפרקים הבאים, תלמד איך אתה יכול לעשות את העבודה שלך קוד כצפוי עם tf.function .

ביצוע גרף לעומת ביצוע להוט

הקוד בתוך Function ניתן לבצע בשתי בשקיקה ו כגרף. כברירת מחדל, Function מבצעת הקוד שלה כגרף:

@tf.function
def get_MSE(y_true, y_pred):
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)
y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)
print(y_true)
print(y_pred)
tf.Tensor([6 1 7 8 0], shape=(5,), dtype=int32)
tf.Tensor([6 0 1 8 6], shape=(5,), dtype=int32)
get_MSE(y_true, y_pred)
<tf.Tensor: shape=(), dtype=int32, numpy=14>

כדי לוודא שלך Function הגרף של" עושה את אותו חישוב כמו פונקצית Python המקבילה שלה, אתה יכול לעשות את זה לפועל בשקיקה עם tf.config.run_functions_eagerly(True) . זהו מתג שמכבה Function היכולת של ליצור גרפים לרוץ, במקום ביצוע הקוד נורמאלי.

tf.config.run_functions_eagerly(True)
get_MSE(y_true, y_pred)
<tf.Tensor: shape=(), dtype=int32, numpy=14>
# Don't forget to set it back when you are done.
tf.config.run_functions_eagerly(False)

עם זאת, Function יכולה להתנהג אחרת תחת גרף וביצוע להוט. פייתון print הפונקציה היא דוגמה אחת לאופן שונה בשני המצבים הללו. המחאה של פלטה מה קורה כשמכניסים print הצהרה הפונקציה שלך ולקרוא אותו שוב ושוב.

@tf.function
def get_MSE(y_true, y_pred):
  print("Calculating MSE!")
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)

שימו לב למה מודפס:

error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
Calculating MSE!

האם התפוקה מפתיעה? get_MSE מודפס רק פעם אחת למרות שזה נקרא שלוש פעמים.

כדי להסביר את print האמירה מבוצעת כאשר Function המפעילה את הקוד המקורי כדי ליצור את הגרף בתהליך המכונה "התחקות" . איתור לוכד את פעולות TensorFlow לתוך גרף, ואת print לא כבשו בגרף. הגרף אשר יבוצע עבור כל שלוש שיחות מבלי הרצת קוד פיתון שוב.

כבדיקת שפיות, בואו נשבית את ביצוע הגרפים כדי להשוות:

# Now, globally set everything to run eagerly to force eager execution.
tf.config.run_functions_eagerly(True)
# Observe what is printed below.
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
Calculating MSE!
Calculating MSE!
Calculating MSE!
tf.config.run_functions_eagerly(False)

print היא תופעת לוואי Python, ויש הבדלים נוספים שאתה צריך להיות מודע בעת המרת פונקציה לתוך Function .

ביצוע לא קפדני

ביצוע גרף מבצע רק את הפעולות הדרושות להפקת ההשפעות הניתנות לצפייה, הכוללות:

  • ערך ההחזרה של הפונקציה
  • תועדו תופעות לוואי ידועות כמו:

התנהגות זו מכונה בדרך כלל "ביצוע לא קפדני", ושונה מביצוע להוט, שעובר את כל פעולות התוכנית, נחוץ או לא.

בפרט, בדיקת שגיאות בזמן ריצה אינה נחשבת כאפקט שניתן לראות. אם מדלגים על פעולה מכיוון שהיא מיותרת, היא לא יכולה להעלות שגיאות בזמן ריצה.

בדוגמה הבאה, המבצע "מיותר" tf.gather הוא דילג במהלך הביצוע הגרף, כך שגיאת זמן ריצה InvalidArgumentError לא העלתה ככל שזה יהיה בביצוע להוט. אל תסתמך על העלאת שגיאה בזמן ביצוע גרף.

def unused_return_eager(x):
  # Get index 1 will fail when `len(x) == 1`
  tf.gather(x, [1]) # unused 
  return x

try:
  print(unused_return_eager(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:
  # All operations are run during eager execution so an error is raised.
  print(f'{type(e).__name__}: {e}')
tf.Tensor([0.], shape=(1,), dtype=float32)
@tf.function
def unused_return_graph(x):
  tf.gather(x, [1]) # unused
  return x

# Only needed operations are run during graph exection. The error is not raised.
print(unused_return_graph(tf.constant([0.0])))
tf.Tensor([0.], shape=(1,), dtype=float32)

tf.function מומלצת

זה עלול לקחת קצת זמן להתרגל ההתנהגות של Function . כדי להתחיל לעבוד במהירות, אם זו פעם הראשונה צריכה לשחק עם קישוט פונקציות צעצוע עם @tf.function כדי לקבל ניסיון עם יוצאים מן להוט ביצוע גרף.

תכנון עבור tf.function עשוי להיות הפתרון הטוב ביותר עבור כתיבת תוכניות TensorFlow תואמת גרף. הנה כמה עצות:

  • מעבר בין ביצוע הלהוט גרף מוקדם ולעתים קרובות עם tf.config.run_functions_eagerly להצביע אם / כאשר שני המצבים לסטות.
  • צור tf.Variable מחוץ ים הפונקציה Python ולשנות אותם מבפנים. כנ"ל לגבי אובייקטים שמשתמשים tf.Variable , כמו keras.layers , keras.Model ים ו tf.optimizers .
  • הימנע כתיבת פונקציות תלוי במשתני Python חיצוניים , למעט tf.Variable ים ו Keras אובייקטים.
  • העדיפו לכתוב פונקציות שלוקחות טנזורים וסוגי TensorFlow אחרים כקלט. אתה יכול לעבור בסוגים חפצים אחרים אבל להיזהר !
  • לכלול בחישוב רב ככל האפשר תחת tf.function כדי למקסם את הרווח הביצועים. לדוגמה, לקשט שלב אימון שלם או את כל לולאת האימון.

רואים את המהירות

tf.function בדרך כלל משפר את הביצועים של הקוד שלך, אבל בסך-תאוצה תלוי בסוג של חישוב לך לרוץ. חישובים קטנים יכולים להיות נשלטים על ידי התקורה של קריאת גרף. אתה יכול למדוד את ההבדל בביצועים כך:

x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)

def power(x, y):
  result = tf.eye(10, dtype=tf.dtypes.int32)
  for _ in range(y):
    result = tf.matmul(x, result)
  return result
print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=1000))
Eager execution: 2.0122516460000384
power_as_graph = tf.function(power)
print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=1000))
Graph execution: 0.6084441319999883

tf.function נפוץ להאיץ לולאות אימונים, ואתה יכול ללמוד עוד על זה ב כתיבת לולאת הכשרה מאפסת עם Keras.

ביצועים ופשרות

גרפים יכולים להאיץ את הקוד שלך, אבל לתהליך יצירתם יש תקורה מסוימת. עבור פונקציות מסוימות, יצירת הגרף לוקחת יותר זמן מאשר ביצוע הגרף. השקעה זו מוחזרת בדרך כלל במהירות עם שיפור הביצועים של הוצאות להורג עוקבות, אך חשוב להיות מודע לכך שהשלבים הראשונים של כל אימון מודלים גדולים יכולים להיות איטיים יותר עקב מעקב.

לא משנה כמה גדול הדגם שלך, אתה רוצה להימנע ממעקב לעתים קרובות. tf.function הנזכרים המדריך כיצד מפרט קלט סט ויכוחים מותחים שימוש כדי למנוע retracing. אם אתה מגלה שאתה מקבל ביצועים גרועים בצורה יוצאת דופן, מומלץ לבדוק אם אתה חוזר בטעות.

כאשר הוא Function התחקות?

כדי להבין מתי שלך Function היא התחקות, להוסיף print הצהרת הקוד שלה. ככלל אצבע, Function תריץ את print האמירה בכול פעם שהוא עוקב.

@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!") # An eager-only side effect.
  return x * x + tf.constant(2)

# This is traced the first time.
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect.
print(a_function_with_python_side_effect(tf.constant(3)))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter.
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)

ארגומנטים חדשים של Python תמיד מפעילים יצירת גרף חדש, ומכאן המעקב הנוסף.

הצעדים הבאים

אתה יכול ללמוד עוד על tf.function בדף הפניה API ו-ידי ביצוע ביצועים טובים יותר עם tf.function מדריך.