אומדנים

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת

מסמך זה מציג את tf.estimator - API ברמה גבוהה של TensorFlow. אומדנים כוללים את הפעולות הבאות:

  • הַדְרָכָה
  • הַעֲרָכָה
  • נְבוּאָה
  • ייצוא להגשה

TensorFlow מיישמת מספר אומדנים מוכנים מראש. אומדנים מותאמים אישית עדיין נתמכים, אבל בעיקר כאמצעי תאימות לאחור. אין להשתמש באומדנים מותאמים אישית עבור קוד חדש . כל האומדנים - מוכנים מראש או מותאמים אישית - הם מחלקות המבוססות על המחלקה tf.estimator.Estimator .

לדוגמא מהירה, נסה את מדריכי האומדן . לסקירה כללית של עיצוב ה-API, עיין בנייר הלבן .

להכין

pip install -U tensorflow_datasets
import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

יתרונות

בדומה ל- tf.keras.Model , estimator הוא הפשטה ברמת המודל. ה- tf.estimator מספק כמה יכולות כרגע עדיין בפיתוח עבור tf.keras . אלו הם:

  • אימון מבוסס שרת פרמטרים
  • אינטגרציה מלאה של TFX

יכולות אומדנים

אומדנים מספקים את היתרונות הבאים:

  • אתה יכול להפעיל מודלים מבוססי Estimator על מארח מקומי או על סביבת ריבוי שרתים מבוזרת מבלי לשנות את המודל שלך. יתרה מזאת, אתה יכול להריץ מודלים מבוססי Estimator על CPUs, GPUs או TPUs מבלי לקוד מחדש את המודל שלך.
  • האומדנים מספקים לולאת אימון מבוזרת בטוחה השולטת כיצד ומתי:
    • לטעון מידע
    • לטפל בחריגים
    • צור קבצי מחסום והתאושש מתקלות
    • שמור סיכומים עבור TensorBoard

בעת כתיבת אפליקציה עם Estimators, עליך להפריד את צינור קלט הנתונים מהמודל. הפרדה זו מפשטת ניסויים עם מערכי נתונים שונים.

שימוש באומדנים מוכנים מראש

אומדנים מוכנים מראש מאפשרים לך לעבוד ברמה קונספטואלית גבוהה בהרבה מזו של ממשקי ה-API של TensorFlow הבסיסיים. אתה כבר לא צריך לדאוג לגבי יצירת הגרף החישובי או ההפעלות מכיוון שאומדנים מטפלים בכל ה"צנרת" עבורך. יתר על כן, אומדנים מוכנים מראש מאפשרים לך להתנסות בארכיטקטורות מודלים שונות על ידי ביצוע שינויים קוד מינימליים בלבד. tf.estimator.DNNClassifier , למשל, היא מחלקת Estimator שהוכנה מראש, המאמנת מודלים של סיווג המבוססים על רשתות עצביות צפופות ומזינות קדימה.

תוכנית TensorFlow הנשענת על אומדן שהוכן מראש מורכבת בדרך כלל מארבעת השלבים הבאים:

1. כתוב פונקציות קלט

לדוגמה, תוכל ליצור פונקציה אחת לייבוא ​​ערכת ההדרכה ופונקציה אחרת לייבא את ערכת הבדיקות. מעריכים מצפים שהקלטים שלהם יהיו מעוצבים כזוג אובייקטים:

  • מילון שבו המפתחות הם שמות תכונה והערכים הם Tensors (או SparseTensors) המכילים את נתוני התכונה המתאימים
  • טנסור המכיל תווית אחת או יותר

ה- input_fn צריך להחזיר tf.data.Dataset שמניב זוגות בפורמט הזה.

לדוגמה, הקוד הבא בונה tf.data.Dataset מקובץ train.csv של מערך הנתונים של Titanic:

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.AUTOTUNE))
  return titanic_batches

ה- input_fn מבוצע ב- tf.Graph ויכול גם להחזיר ישירות זוג (features_dics, labels) המכיל טנסור גרפים, אבל זה נוטה לשגיאה מחוץ למקרים פשוטים כמו החזרת קבועים.

2. הגדר את עמודות התכונה.

כל tf.feature_column מזהה שם תכונה, סוגה וכל עיבוד מקדים של קלט.

לדוגמה, הקטע הבא יוצר שלוש עמודות תכונה.

  • הראשון משתמש בתכונת age ישירות כקלט של נקודה צפה.
  • השני משתמש בתכונת class כקלט קטגורי.
  • השלישי משתמש ב- embark_town כקלט קטגורי, אך משתמש hashing trick כדי למנוע את הצורך למנות את האפשרויות ולקבוע את מספר האפשרויות.

למידע נוסף, עיין במדריך עמודות תכונה .

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. הצג את האומד הרלוונטי המוכן מראש.

לדוגמה, הנה דוגמה לדוגמה של הערכה מוכנה מראש בשם LinearClassifier :

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpl24pp3cp', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

למידע נוסף, תוכל לעבור למדריך המסווג הליניארי .

4. קרא לשיטת אימון, הערכה או מסקנות.

כל האומדנים מספקים שיטות train , evaluate predict .

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_v1.py:1684: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/ftrl.py:147: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpl24pp3cp/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpl24pp3cp/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.6319582.
2021-09-22 20:49:10.453286: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:11
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpl24pp3cp/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.74609s
INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:12
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.734375, accuracy_baseline = 0.640625, auc = 0.7373913, auc_precision_recall = 0.64306235, average_loss = 0.563341, global_step = 100, label/mean = 0.359375, loss = 0.563341, precision = 0.734375, prediction/mean = 0.3463129, recall = 0.40869564
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpl24pp3cp/model.ckpt-100
accuracy : 0.734375
accuracy_baseline : 0.640625
auc : 0.7373913
auc_precision_recall : 0.64306235
average_loss : 0.563341
label/mean : 0.359375
loss : 0.563341
precision : 0.734375
prediction/mean : 0.3463129
recall : 0.40869564
global_step : 100
2021-09-22 20:49:12.168629: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpl24pp3cp/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-1.5173098]
logistic : [0.17985801]
probabilities : [0.820142   0.17985801]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']
2021-09-22 20:49:13.076528: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

היתרונות של אומדנים מוכנים מראש

מעריכים מוכנים מראש מקודדים שיטות עבודה מומלצות, ומספקים את היתרונות הבאים:

  • שיטות עבודה מומלצות לקביעה היכן חלקים שונים של הגרף החישובי צריכים לפעול, יישום אסטרטגיות במכונה בודדת או באשכול.
  • שיטות עבודה מומלצות לכתיבת אירועים (סיכום) וסיכומים שימושיים אוניברסליים.

אם אינך משתמש באומדנים מוכנים מראש, עליך ליישם את התכונות הקודמות בעצמך.

אומדנים מותאמים אישית

הלב של כל הערכה - בין אם הוא מוכן מראש או בהתאמה אישית - הוא פונקציית המודל שלו, model_fn , שהיא שיטה שבונה גרפים לאימון, הערכה וחיזוי. כאשר אתה משתמש באומדן שהוכן מראש, מישהו אחר כבר יישם את פונקציית המודל. כאשר מסתמכים על אומדן מותאם אישית, עליך לכתוב את פונקציית המודל בעצמך.

צור מעריך מדגם Keras

אתה יכול להמיר מודלים קיימים של Keras ל-Estimators עם tf.keras.estimator.model_to_estimator . זה מועיל אם ברצונך לחדש את קוד הדגם שלך, אך צינור ההדרכה שלך עדיין דורש מעריכים.

הצג דגם של Keras MobileNet V2 והרכב את המודל עם כלי האופטימיזציה, האובדן והמדדים להתאמן איתם:

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step
9420800/9406464 [==============================] - 0s 0us/step

צור Estimator ממודל Keras המלוקט. מצב המודל הראשוני של מודל Estimator נשמר באומדן שנוצר:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpmosnmied
INFO:tensorflow:Using the Keras model provided.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/backend.py:401: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  category=CustomMaskWarning)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpmosnmied', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

התייחס Estimator הנגזר כפי שהיית מתייחס לכל Estimator אחר.

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

כדי להתאמן, התקשר לפונקציית הרכבת של Estimator:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpmosnmied/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpmosnmied/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmpmosnmied/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmp/tmpmosnmied/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpmosnmied/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpmosnmied/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6994096, step = 0
INFO:tensorflow:loss = 0.6994096, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpmosnmied/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpmosnmied/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.68789804.
INFO:tensorflow:Loss for final step: 0.68789804.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f4b1c1e9890>

באופן דומה, כדי להעריך, קרא לפונקציית ההערכה של האומד:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/training.py:2470: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`Model.state_updates` will be removed in a future version. '
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:36
INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:36
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpmosnmied/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmp/tmpmosnmied/model.ckpt-50
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 3.89658s
INFO:tensorflow:Inference Time : 3.89658s
INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:39
INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:39
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.525, global_step = 50, loss = 0.6723582
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.525, global_step = 50, loss = 0.6723582
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpmosnmied/model.ckpt-50
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpmosnmied/model.ckpt-50
{'accuracy': 0.525, 'loss': 0.6723582, 'global_step': 50}

לפרטים נוספים, עיין בתיעוד עבור tf.keras.estimator.model_to_estimator .

שמירת מחסומים מבוססי אובייקטים עם Estimator

מעריכים כברירת מחדל שומרים נקודות ביקורת עם שמות משתנים במקום עם גרף האובייקטים המתואר במדריך המחסום . tf.train.Checkpoint יקרא נקודות ביקורת מבוססות שמות, אך שמות משתנים עשויים להשתנות בעת הזזת חלקים של מודל מחוץ ל- model_fn של האומד. עבור תאימות קדימה, שמירת נקודות ביקורת מבוססות אובייקטים מקלה על אימון מודל בתוך Estimator ולאחר מכן להשתמש בו מחוץ לאחד.

import tensorflow.compat.v1 as tf_compat
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.659403, step = 0
INFO:tensorflow:loss = 4.659403, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 39.58891.
INFO:tensorflow:Loss for final step: 39.58891.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f4b7c451fd0>

לאחר מכן, tf.train.Checkpoint יכול לטעון את נקודות הבידוק של האומד מה- model_dir שלו.

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

SavedModels מאת מעריכים

מעריכים מייצאים SavedModels דרך tf.Estimator.export_saved_model .

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp30_d7xz6
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp30_d7xz6
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp30_d7xz6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp30_d7xz6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp30_d7xz6/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp30_d7xz6/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp30_d7xz6/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp30_d7xz6/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.4022895.
INFO:tensorflow:Loss for final step: 0.4022895.
<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f4b1c10fd10>

כדי לשמור Estimator אתה צריך ליצור serving_input_receiver . פונקציה זו בונה חלק מ- tf.Graph את הנתונים הגולמיים שהתקבלו על ידי SavedModel.

מודול tf.estimator.export מכיל פונקציות שיעזרו לבנות receivers אלה.

הקוד הבא בונה מקלט, המבוסס על feature_columns , שמקבל מאגרי פרוטוקול tf.Example בסידרה, המשמשים לעתים קרובות עם tf-serving .

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmp/tmp30_d7xz6/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmp/tmp30_d7xz6/model.ckpt-50
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /tmp/tmpi_szzuj1/from_estimator/temp-1632343781/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmp/tmpi_szzuj1/from_estimator/temp-1632343781/saved_model.pb

אתה יכול גם לטעון ולהפעיל את המודל הזה, מ-python:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2974025]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.5738074]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.42619258, 0.5738074 ]], dtype=float32)>}
{'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.1919093]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.23291764]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7670824 , 0.23291762]], dtype=float32)>}

tf.estimator.export.build_raw_serving_input_receiver_fn מאפשר לך ליצור פונקציות קלט שלוקחות טנזורים גולמיים ולא tf.train.Example s.

שימוש ב- tf.distribute.Strategy עם Estimator (תמיכה מוגבלת)

tf.estimator הוא API של הדרכה מבוזר של TensorFlow שתמך במקור בגישת שרת פרמטרים אסינכרוניים. tf.estimator תומך כעת ב- tf.distribute.Strategy . אם אתה משתמש ב- tf.estimator , אתה יכול לעבור לאימון מבוזר עם מעט מאוד שינויים בקוד שלך. עם זה, משתמשי Estimator יכולים כעת לבצע אימון מבוזר סינכרוני על מספר GPUs ומספר עובדים, כמו גם להשתמש ב-TPUs. עם זאת, תמיכה זו באומדן מוגבלת. עיין בסעיף מה נתמך עכשיו למטה לפרטים נוספים.

השימוש ב- tf.distribute.Strategy עם Estimator שונה במקצת מאשר במקרה של Keras. במקום להשתמש ב-strategi.scope, כעת אתה מעביר את אובייקט strategy.scope ל- RunConfig עבור האומד.

אתה יכול לעיין במדריך ההדרכה המבוזר למידע נוסף.

להלן קטע קוד שמראה זאת עם Estimator LinearRegressor ו- MirroredStrategy שהוכן מראש:

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpftw63jyd
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpftw63jyd
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpftw63jyd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpftw63jyd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

כאן, אתה משתמש באומד מראש, אבל אותו קוד עובד גם עם אומד מותאם אישית. train_distribute קובע כיצד ההדרכה תחולק, ו- eval_distribute קובעת כיצד תחולק הערכה. זהו הבדל נוסף מ-Keras שבו אתה משתמש באותה אסטרטגיה גם לאימון וגם לאיוואל.

כעת אתה יכול לאמן ולהעריך את האומד הזה עם פונקציית קלט:

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:374: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpftw63jyd/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpftw63jyd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
2021-09-22 20:49:45.706166: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2021-09-22 20:49:45.707521: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpftw63jyd/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpftw63jyd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:46
INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:46
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpftw63jyd/model.ckpt-10
INFO:tensorflow:Restoring parameters from /tmp/tmpftw63jyd/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
2021-09-22 20:49:46.680821: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2021-09-22 20:49:46.682161: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.26514s
INFO:tensorflow:Inference Time : 0.26514s
INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:46
INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:46
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpftw63jyd/model.ckpt-10
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpftw63jyd/model.ckpt-10
{'average_loss': 1.4210855e-14,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994,
 'global_step': 10}

הבדל נוסף שיש להדגיש כאן בין Estimator ל-Keras הוא הטיפול בקלט. ב-Keras, כל אצווה של מערך הנתונים מפוצלת באופן אוטומטי על פני ההעתקים המרובים. עם זאת, ב-Estimator, אינך מבצע פיצול אצווה אוטומטי, ואינו מפיץ את הנתונים באופן אוטומטי בין עובדים שונים. יש לך שליטה מלאה על האופן שבו אתה רוצה שהנתונים שלך יופצו על פני עובדים ומכשירים, ועליך לספק input_fn כדי לציין כיצד להפיץ את הנתונים שלך.

input_fn שלך נקרא פעם אחת לכל עובד, ובכך נותן מערך נתונים אחד לכל עובד. אז אצווה אחת מאותו מערך נתונים מוזנת לעתק אחד באותו עובד, ובכך צורכת N אצווה עבור N העתקים בעובד אחד. במילים אחרות, מערך הנתונים המוחזר על ידי ה- input_fn צריך לספק אצוות בגודל PER_REPLICA_BATCH_SIZE . וניתן לקבל את גודל האצווה הגלובלי עבור שלב כ- PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync .

בעת ביצוע הכשרה מרובה עובדים, עליך לפצל את הנתונים שלך בין העובדים, או לערבב עם סיד אקראי על כל אחד מהם. אתה יכול לבדוק דוגמה כיצד לעשות זאת במדריך הדרכה ריבוי עובדים עם הערכה .

ובאופן דומה, אתה יכול להשתמש גם באסטרטגיות ריבוי עובדים ושרת פרמטרים. הקוד נשאר זהה, אבל עליך להשתמש ב- tf.estimator.train_and_evaluate ולהגדיר משתני סביבה TF_CONFIG עבור כל בינארי הפועל באשכול שלך.

מה נתמך עכשיו?

קיימת תמיכה מוגבלת לאימון עם Estimator תוך שימוש בכל האסטרטגיות למעט TPUStrategy . הכשרה והערכה בסיסית אמורות לעבוד, אך מספר תכונות מתקדמות כגון v1.train.Scaffold לא. ייתכנו גם מספר באגים באינטגרציה הזו ואין תוכניות לשפר את התמיכה הזו באופן אקטיבי (ההתמקדות היא ב-Keras ותמיכה בלולאת אימון מותאמת אישית). אם זה אפשרי, עליך להעדיף להשתמש ב- tf.distribute עם ממשקי API אלה במקום זאת.

הדרכה API אסטרטגיית מראה אסטרטגיה של TPUS אסטרטגיית MultiWorkerMirrored אסטרטגיית אחסון מרכזית ParameterServerStrategy
API של Estimator תמיכה מוגבלת אינו נתמך תמיכה מוגבלת תמיכה מוגבלת תמיכה מוגבלת

דוגמאות והדרכות

הנה כמה דוגמאות מקצה לקצה שמראות כיצד להשתמש באסטרטגיות שונות עם אומדן:

  1. ערכת הדרכה מרובה עובדים עם אומדן מראה כיצד ניתן להתאמן עם מספר עובדים באמצעות MultiWorkerMirroredStrategy במערך הנתונים של MNIST.
  2. דוגמה מקצה לקצה להפעלת הכשרה מרובת עובדים עם אסטרטגיות הפצה ב- tensorflow/ecosystem באמצעות תבניות Kubernetes. זה מתחיל במודל Keras וממיר אותו ל-Estimator באמצעות tf.keras.estimator.model_to_estimator API.
  3. הדגם הרשמי של ResNet50 , שניתן לאמן אותו באמצעות MirroredStrategy או MultiWorkerMirroredStrategy .