سوالی دارید؟ در انجمن بازدید از انجمن TensorFlow با انجمن ارتباط برقرار کنید

برآوردگرها

مشاهده در TensorFlow.org در Google Colab اجرا کنید مشاهده منبع در GitHub دانلود دفترچه یادداشت

این سند tf.estimator —یک سطح بالا tf.estimator API را معرفی می کند. برآورد کنندگان اقدامات زیر را کپسول می کنند:

  • آموزش
  • ارزیابی
  • پیش بینی
  • صادرات برای خدمت

TensorFlow چندین برآوردگر از پیش ساخته شده را پیاده سازی می کند. برآوردگرهای سفارشی هنوز پشتیبانی می شوند ، اما عمدتاً به عنوان یک معیار سازگاری به عقب. از برآوردگرهای سفارشی نباید برای کد جدید استفاده شود . همه برآوردگرها - از پیش ساخته یا سفارشی - کلاسهایی هستند که براساس کلاس tf.estimator.Estimator ساخته شده اند.

برای یک مثال سریع ، آموزش های Estimator را امتحان کنید. برای مروری بر طراحی API ، مقاله سفید را بررسی کنید.

برپایی

pip install -U tensorflow_datasets
import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

مزایای

مشابه tf.keras.Model ، estimator یک انتزاع در سطح مدل است. tf.estimator برخی از قابلیت هایی را که در حال حاضر برای tf.keras هنوز در دست توسعه است فراهم می کند. اینها هستند:

  • پارامتر آموزش مبتنی بر سرور
  • یکپارچه سازی کامل TFX

قابلیت های برآوردگر

برآورد کنندگان مزایای زیر را ارائه می دهند:

  • بدون تغییر در مدل خود ، می توانید مدل های مبتنی برآورد را بر روی یک میزبان محلی یا در یک محیط توزیع شده چند سرور اجرا کنید. علاوه بر این ، می توانید مدل های مبتنی بر برآورد را بر روی CPU ، GPU یا TPU بدون رمزگذاری مجدد مدل خود اجرا کنید.
  • برآورد کنندگان یک حلقه آموزش توزیع شده ایمن ارائه می دهند که نحوه و زمان انجام موارد زیر را کنترل می کند:
    • داده را بارگیری کنید
    • استثنائات را اداره کنید
    • پرونده های ایست بازرسی ایجاد کنید و از خرابی ها بازیابی کنید
    • خلاصه ها را برای TensorBoard ذخیره کنید

هنگام نوشتن برنامه با Estimators ، باید خط لوله ورودی داده را از مدل جدا کنید. این جداسازی آزمایشات با مجموعه داده های مختلف را ساده می کند.

استفاده از برآوردگرهای از پیش ساخته شده

برآوردگرهای از پیش ساخته شده شما را قادر می سازند تا در سطح مفهومی بسیار بالاتری نسبت به API های پایه TensorFlow کار کنید. دیگر لازم نیست نگران ایجاد نمودار محاسباتی یا جلسات باشید زیرا ارزیابی کنندگان همه "لوله کشی" را برای شما مدیریت می کنند. علاوه بر این ، برآوردگرهای از پیش ساخته شده به شما امکان می دهند با ایجاد حداقل تغییرات کد ، با معماری مدل های مختلف آزمایش کنید. tf.estimator.DNNClassifier ، به عنوان مثال ، یک کلاس برآوردگر از پیش ساخته شده است که مدل های طبقه بندی را بر اساس شبکه های عصبی متراکم و فوروارد آموزش می دهد.

برنامه TensorFlow با تکیه بر یک برآوردگر از پیش ساخته شده معمولاً شامل چهار مرحله زیر است:

1. توابع ورودی را بنویسید

به عنوان مثال ، ممکن است یک تابع برای وارد کردن مجموعه آموزش و یک تابع دیگر برای وارد کردن مجموعه آزمون ایجاد کنید. برآوردگران انتظار دارند که ورودی های آنها به صورت یک جفت آبجکت قالب بندی شود:

  • فرهنگ لغتی که در آن کلیدها نام ویژگی هستند و مقادیر آن Tensors (یا SparseTensors) است که حاوی داده های مربوطه مربوطه است.
  • تنسور حاوی یک یا چند برچسب

input_fn باید یکtf.data.Dataset که در آن قالب جفت ایجاد کند.

به عنوان مثال ، کد زیر یکtf.data.Dataset از پرونده train.csv مجموعه داده تایتانیک ایجاد می کند:

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.AUTOTUNE))
  return titanic_batches

input_fn در یک اعدام tf.Graph و همچنین می تواند به طور مستقیم بازگشت (features_dics, labels) جفت حاوی تانسورها نمودار، اما این خطا خارج مستعد از موارد ساده مانند ثابت بازگشت است.

2. ستون های ویژگی را تعریف کنید.

هر tf.feature_column یک نام ویژگی ، نوع آن و هرگونه پردازش ورودی را مشخص می کند.

به عنوان مثال ، قطعه زیر سه ستون ویژگی ایجاد می کند.

  • اولین مورد از ویژگی age به طور مستقیم به عنوان ورودی با نقطه شناور استفاده می کند.
  • مورد دوم از ویژگی class به عنوان ورودی طبقه ای استفاده می کند.
  • مورد سوم از embark_town به عنوان ورودی طبقه ای استفاده می کند ، اما برای جلوگیری از نیاز به برشمردن گزینه ها و تنظیم تعداد گزینه ها ، از hashing trick استفاده می کند.

برای اطلاعات بیشتر ، آموزش ستون ویژگی را بررسی کنید.

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. برآوردگر از قبل ساخته شده مربوطه را نمونه بگیرید.

به عنوان مثال ، در اینجا نمونه ای از برآوردگر از قبل ساخته شده به نام LinearClassifier :

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpeqzx9get', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

برای اطلاعات بیشتر می توانید به آموزش طبقه بندی خطی بروید.

4- روش آموزش ، ارزیابی یا استنباط را فراخوانی کنید.

همه برآوردگرها روش های train ، evaluate و predict ارائه می دهند.

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1700: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:149: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeqzx9get/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpeqzx9get/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.54946315.
result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-06-19T01:21:21
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpeqzx9get/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.63381s
INFO:tensorflow:Finished evaluation at 2021-06-19-01:21:21
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.646875, accuracy_baseline = 0.6, auc = 0.69405115, auc_precision_recall = 0.6043487, average_loss = 0.64180285, global_step = 100, label/mean = 0.4, loss = 0.64180285, precision = 0.72727275, prediction/mean = 0.3058043, recall = 0.1875
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpeqzx9get/model.ckpt-100
accuracy : 0.646875
accuracy_baseline : 0.6
auc : 0.69405115
auc_precision_recall : 0.6043487
average_loss : 0.64180285
label/mean : 0.4
loss : 0.64180285
precision : 0.72727275
prediction/mean : 0.3058043
recall : 0.1875
global_step : 100
for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpeqzx9get/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-1.5908554]
logistic : [0.16926359]
probabilities : [0.83073646 0.16926359]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']

مزایای برآوردگرهای از پیش ساخته شده

برآوردگرهای از پیش ساخته شده بهترین روش ها را رمزگذاری می کنند و مزایای زیر را فراهم می کنند:

  • بهترین روش ها برای تعیین محل های مختلف نمودار محاسباتی ، اجرای استراتژی ها روی یک ماشین یا خوشه.
  • بهترین روش ها برای نوشتن رویداد (خلاصه) و خلاصه های مفید جهانی.

اگر از برآوردگرهای پیش ساخته استفاده نمی کنید ، باید ویژگی های قبلی را خودتان پیاده سازی کنید.

برآوردگرهای سفارشی

قلب هر برآوردگر - چه از قبل ساخته شده و چه سفارشی - عملکرد مدل آن ، model_fn است که روشی است که نمودارها را برای آموزش ، ارزیابی و پیش بینی می سازد. وقتی از برآوردگر از پیش ساخته شده استفاده می کنید ، شخص دیگری قبلاً عملکرد مدل را اجرا کرده است. هنگام اعتماد به یک برآوردگر سفارشی ، باید خودتان تابع مدل را بنویسید.

یک برآوردگر از مدل Keras ایجاد کنید

با tf.keras.estimator.model_to_estimator می توانید مدل های موجود Keras را به Estimators tf.keras.estimator.model_to_estimator . اگر می خواهید کد مدل خود را مدرن کنید ، این کار مفید است ، اما خطوط آموزش شما هنوز هم به برآوردگر نیاز دارد.

یک مدل Keras MobileNet V2 تهیه کنید و مدل را با بهینه ساز ، افت و معیارها برای آموزش با آنها تدوین کنید:

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

یک Estimator از مدل Keras وارد شده ایجاد کنید. حالت اولیه مدل Keras در Estimator ایجاد شده حفظ می شود:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpamo11374
INFO:tensorflow:Using the Keras model provided.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/backend.py:435: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  category=CustomMaskWarning)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpamo11374', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

با Estimator مشتق مانند هر Estimator دیگری رفتار کنید.

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

برای آموزش ، با عملکرد قطار برآوردگر تماس بگیرید:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpamo11374/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpamo11374/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmpamo11374/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmp/tmpamo11374/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpamo11374/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpamo11374/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.70249426, step = 0
INFO:tensorflow:loss = 0.70249426, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpamo11374/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpamo11374/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.6834691.
INFO:tensorflow:Loss for final step: 0.6834691.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fcaf1c34cd0>

به همین ترتیب ، برای ارزیابی ، عملکرد ارزیابی Estimator را فراخوانی کنید:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:2426: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`Model.state_updates` will be removed in a future version. '
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-06-19T01:21:56
INFO:tensorflow:Starting evaluation at 2021-06-19T01:21:56
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpamo11374/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmp/tmpamo11374/model.ckpt-50
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 5.68835s
INFO:tensorflow:Inference Time : 5.68835s
INFO:tensorflow:Finished evaluation at 2021-06-19-01:22:02
INFO:tensorflow:Finished evaluation at 2021-06-19-01:22:02
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.478125, global_step = 50, loss = 0.671334
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.478125, global_step = 50, loss = 0.671334
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpamo11374/model.ckpt-50
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpamo11374/model.ckpt-50
{'accuracy': 0.478125, 'loss': 0.671334, 'global_step': 50}

برای جزئیات بیشتر ، لطفاً به اسناد tf.keras.estimator.model_to_estimator .

ذخیره ایست های بازرسی مبتنی بر شی با Estimator

برآوردگران به طور پیش فرض ، ایست های بازرسی را با نام های متغیر ذخیره می کنند تا نمودار شی object توصیف شده در راهنمای Checkpoint . tf.train.Checkpoint ایست های بازرسی مبتنی بر نام را می خواند ، اما ممکن است هنگام جابجایی قطعات یک مدل به خارج از model_fn برآوردگر ، نام متغیرها تغییر کند. صرفه جویی در سازگاری به جلو ، ایست های بازرسی مبتنی بر شی object آموزش مدل در داخل برآوردگر و سپس استفاده از آن را در خارج از مدل آسان می کند.

import tensorflow.compat.v1 as tf_compat
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.531598, step = 0
INFO:tensorflow:loss = 4.531598, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 37.41622.
INFO:tensorflow:Loss for final step: 37.41622.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fcaf1b75210>

tf.train.Checkpoint می تواند ایست های بازرسی برآوردگر را از model_dir خود model_dir .

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

مدلهای ذخیره شده از برآوردگرها

برآوردگران از طریق tf.Estimator.export_saved_model .

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpnh9mbjji
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpnh9mbjji
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpnh9mbjji', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpnh9mbjji', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpnh9mbjji/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpnh9mbjji/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpnh9mbjji/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpnh9mbjji/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.41857475.
INFO:tensorflow:Loss for final step: 0.41857475.
<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7fca8c0e7f90>

برای ذخیره Estimator باید یک serving_input_receiver ایجاد کنید. این تابع بخشی از tf.Graph می tf.Graph که داده های خام دریافت شده توسط SavedModel را تجزیه می کند.

ماژول tf.estimator.export شامل توابعی برای کمک به ساخت این receivers .

کد زیر یک گیرنده ایجاد می کند ، بر اساس feature_columns ، سریال tf.Example را قبول می کند. tf.Example بافر پروتکل ، که اغلب با tf- tf.Example استفاده می شود.

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmp/tmpnh9mbjji/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmp/tmpnh9mbjji/model.ckpt-50
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /tmp/tmpueh7p422/from_estimator/temp-1624065724/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmp/tmpueh7p422/from_estimator/temp-1624065724/saved_model.pb

همچنین می توانید آن مدل را از پایتون بارگیری و اجرا کنید:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.3068818]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.57612395]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.42387608, 0.57612395]], dtype=float32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>}
{'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.1466763]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.24109668]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7589033 , 0.24109669]], dtype=float32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}

tf.estimator.export.build_raw_serving_input_receiver_fn به شما امکان می دهد توابع ورودی ایجاد کنید که به جای tf.train.Example خام را می tf.train.Example . tf.train.Example s.

استفاده از tf.distribute.Strategy با Estimator (پشتیبانی محدود)

tf.estimator یک توزیع توزیع شده tf.estimator API است که در ابتدا از رویکرد سرور پارامتر async پشتیبانی می کند. tf.estimator اکنون از tf.distribute.Strategy پشتیبانی می tf.distribute.Strategy . اگر از tf.estimator استفاده می کنید ، می توانید با تغییرات بسیار کمی در کد خود به آموزش توزیع شده تغییر دهید. با استفاده از این ، کاربران برآوردگر هم اکنون می توانند آموزش توزیع همزمان بر روی چندین GPU و چندین کارگر را انجام دهند و همچنین از TPU استفاده کنند. این پشتیبانی در Estimator محدود است. برای جزئیات بیشتر به بخش "چه اکنون پشتیبانی می شود" در زیر مراجعه کنید.

استفاده از tf.distribute.Strategy با Estimator کمی متفاوت از مورد Keras است. به جای استفاده از strategy.scope ، اکنون شی استراتژی را به RunConfig برای برآوردگر منتقل می کنید.

برای اطلاعات بیشتر می توانید به راهنمای توزیع شده آموزش مراجعه کنید.

در اینجا یک قطعه از کد است که نشان می دهد این کار را با از پیش ساخته شده برآورد LinearRegressor و MirroredStrategy :

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmphjmg1q2m
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmphjmg1q2m
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmphjmg1q2m', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fca2c1de350>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fca2c1de350>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmphjmg1q2m', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fca2c1de350>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fca2c1de350>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

در اینجا ، شما از یک برآوردگر پیش ساخته استفاده می کنید ، اما همان کد با یک برآوردگر سفارشی نیز کار می کند. train_distribute نحوه توزیع آموزش را تعیین می کند و eval_distribute نحوه توزیع ارزیابی را تعیین می کند. این تفاوت دیگری با کراس است که شما از یک استراتژی هم برای آموزش و هم برای نتیجه گیری استفاده می کنید.

اکنون می توانید این برآوردگر را با یک تابع ورودی آموزش داده و ارزیابی کنید:

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmphjmg1q2m/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmphjmg1q2m/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmphjmg1q2m/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmphjmg1q2m/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-06-19T01:22:08
INFO:tensorflow:Starting evaluation at 2021-06-19T01:22:08
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmphjmg1q2m/model.ckpt-10
INFO:tensorflow:Restoring parameters from /tmp/tmphjmg1q2m/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.22064s
INFO:tensorflow:Inference Time : 0.22064s
INFO:tensorflow:Finished evaluation at 2021-06-19-01:22:08
INFO:tensorflow:Finished evaluation at 2021-06-19-01:22:08
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmphjmg1q2m/model.ckpt-10
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmphjmg1q2m/model.ckpt-10
{'average_loss': 1.4210855e-14,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994,
 'global_step': 10}

تفاوت دیگری که در اینجا بین Estimator و Keras برجسته می شود ، مدیریت ورودی است. در Keras ، هر دسته از مجموعه داده ها به طور خودکار در چندین کپی تقسیم می شوند. با این حال ، در برآوردگر ، شما تقسیم دسته ای خودکار را انجام نمی دهید و یا داده ها را بین کارگران مختلف به طور خودکار خرد نمی کنید. شما کنترل کاملی بر نحوه توزیع داده های خود در بین کارگران و دستگاه ها دارید و باید یک input_fn ارائه input_fn تا نحوه توزیع داده های شما مشخص شود.

input_fn شما برای هر کارگر یک بار فراخوانی می شود ، بنابراین به ازای هر کارگر یک مجموعه داده داده می شود. سپس یک دسته از آن مجموعه داده به یک ماکت در آن کارگر تغذیه می شود و بدین ترتیب تعداد N گروه برای 1 کارگر مصرف می شود. به عبارت دیگر ، مجموعه داده ای که توسط input_fn باید دسته هایی از اندازه PER_REPLICA_BATCH_SIZE . و اندازه دسته جهانی یک مرحله را می توان به عنوان PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync بدست آورد.

هنگام انجام آموزش چند کارگری ، باید داده های خود را بین کارگران تقسیم کنید ، یا با یک دانه تصادفی بر روی هر کدام از آنها مرتب شوید. می توانید نمونه ای از نحوه انجام این کار را در آموزش چند کاره با آموزش برآوردگر بررسی کنید .

و به همین ترتیب ، می توانید از استراتژی های چند کارگر و سرور پارامتر نیز استفاده کنید. کد ثابت باقی مانده است ، اما شما باید از tf.estimator.train_and_evaluate استفاده کنید و متغیرهای محیط TF_CONFIG برای هر باینری که در خوشه شما اجرا می شود تنظیم کنید.

اکنون چه چیزی پشتیبانی می شود؟

پشتیبانی محدودی از آموزش با Estimator با استفاده از همه استراتژی ها به جز TPUStrategy . آموزش و ارزیابی v1.train.Scaffold باید مفید باشد ، اما تعدادی از ویژگی های پیشرفته مانند v1.train.Scaffold . v1.train.Scaffold این کار را نمی کند. همچنین ممکن است تعدادی اشکال در این ادغام وجود داشته باشد و هیچ برنامه ای برای بهبود فعال این پشتیبانی وجود ندارد (تمرکز بر پشتیبانی از حلقه Keras و آموزش سفارشی است). در صورت امکان ، شما ترجیح می دهید از tf.distribute با آن API ها استفاده کنید.

آموزش API MirroredStrategy TPUStrategy MultiWorkerMirroredStrategy استراتژی CentralStorage ParameterServerStrategy
API برآوردگر پشتیبانی محدود پشتیبانی نشده پشتیبانی محدود پشتیبانی محدود پشتیبانی محدود

مثالها و آموزشها

در اینجا چند نمونه از انتها به انتها آورده شده است که نحوه استفاده از استراتژی های مختلف با Estimator را نشان می دهد:

  1. آموزش آموزش چند کارگر با برآوردگر نشان می دهد که چگونه می توانید با استفاده از MultiWorkerMirroredStrategy در مجموعه داده MNIST با چندین کارگر آموزش MultiWorkerMirroredStrategy .
  2. یک مثال پایان به پایان اجرای آموزش چند کارگری با استراتژی های توزیع در tensorflow/ecosystem با استفاده از الگوهای Kubernetes. این با یک مدل Keras شروع می شود و با استفاده از tf.keras.estimator.model_to_estimator API آن را به Estimator تبدیل می کند.
  3. این مقام ResNet50 مدل، که می تواند با استفاده از آموزش دیده MirroredStrategy یا MultiWorkerMirroredStrategy .