![]() | ![]() | ![]() | ![]() |
این سند tf.estimator
سطح بالا tf.estimator
API را معرفی می كند. برآورد کنندگان اقدامات زیر را کپسول می کنند:
- آموزش
- ارزیابی
- پیش بینی
- صادرات برای خدمت
TensorFlow چندین برآوردگر از پیش ساخته شده را پیاده سازی می کند. برآوردگرهای سفارشی هنوز پشتیبانی می شوند ، اما عمدتا به عنوان یک معیار سازگاری با عقب از برآوردگرهای سفارشی نباید برای کد جدید استفاده شود . همه برآوردگرها - از پیش ساخته شده یا سفارشی - کلاسهایی هستند که براساس کلاس tf.estimator.Estimator
ساخته شده اند.
برای یک مثال سریع ، آموزش های Estimator را امتحان کنید. برای مروری بر طراحی API ، مقاله سفید را بررسی کنید.
برپایی
pip install -q -U tensorflow_datasets
import tempfile
import os
import tensorflow as tf
import tensorflow_datasets as tfds
مزایای
مشابه tf.keras.Model
، estimator
یک انتزاع در سطح مدل است. tf.estimator
برخی از قابلیت هایی را که در حال حاضر هنوز در دست توسعه است برای tf.keras
فراهم می کند. اینها هستند:
- پارامتر آموزش مبتنی بر سرور
- یکپارچه سازی کامل TFX
قابلیت های برآوردگر
برآورد کنندگان مزایای زیر را ارائه می دهند:
- شما می توانید مدل های مبتنی برآورد را بر روی یک میزبان محلی یا در یک محیط توزیع شده چند سرور بدون تغییر مدل خود اجرا کنید. علاوه بر این ، می توانید مدل های مبتنی بر برآورد را بر روی CPU ، GPU یا TPU بدون رمزگذاری مجدد مدل خود اجرا کنید.
- برآورد کنندگان یک حلقه آموزش توزیع شده ایمن ارائه می دهند که نحوه و زمان انجام کار را کنترل می کند:
- داده را بارگیری کنید
- استثنائات را اداره کنید
- پرونده های ایست بازرسی ایجاد کنید و از خرابی ها بازیابی کنید
- خلاصه ها را برای TensorBoard ذخیره کنید
هنگام نوشتن برنامه با Estimators ، باید خط لوله ورودی داده را از مدل جدا کنید. این جداسازی آزمایشات با مجموعه داده های مختلف را ساده می کند.
استفاده از برآوردگرهای از پیش ساخته شده
برآوردگرهای از پیش ساخته شده شما را قادر می سازند تا در سطح مفهومی بسیار بالاتری نسبت به API های پایه TensorFlow کار کنید. دیگر لازم نیست نگران ایجاد نمودار محاسباتی یا جلسات باشید زیرا ارزیابی کنندگان همه "لوله کشی" را برای شما مدیریت می کنند. علاوه بر این ، برآوردگرهای از پیش ساخته شده به شما امکان می دهند با ایجاد حداقل تغییرات کد ، با معماری مدل های مختلف آزمایش کنید. tf.estimator.DNNClassifier
، به عنوان مثال ، یک کلاس برآوردگر از پیش ساخته شده است که مدل های طبقه بندی را بر اساس شبکه های عصبی متراکم و فوروارد آموزش می دهد.
برنامه TensorFlow با تکیه بر یک برآوردگر از پیش ساخته شده معمولاً شامل چهار مرحله زیر است:
1. توابع ورودی را بنویسید
به عنوان مثال ، ممکن است یک تابع برای وارد کردن مجموعه آموزش و یک تابع دیگر برای وارد کردن مجموعه آزمون ایجاد کنید. برآورد کنندگان انتظار دارند که ورودی های آنها به صورت یک جفت شی قالب بندی شود:
- فرهنگ لغتی که در آن کلیدها نام ویژگی هستند و مقادیر آن Tensors (یا SparseTensors) است که حاوی داده های ویژگی مربوطه است.
- تنسور حاوی یک یا چند برچسب
input_fn
باید یکtf.data.Dataset
که در آن قالب جفت ایجاد کند.
به عنوان مثال، کد زیر را ایجادtf.data.Dataset
را از مجموعه داده تایتانیک train.csv
فایل:
def train_input_fn():
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic = tf.data.experimental.make_csv_dataset(
titanic_file, batch_size=32,
label_name="survived")
titanic_batches = (
titanic.cache().repeat().shuffle(500)
.prefetch(tf.data.AUTOTUNE))
return titanic_batches
input_fn
در یک اعدام tf.Graph
و همچنین می تواند به طور مستقیم بازگشت (features_dics, labels)
جفت حاوی تانسورها نمودار، اما این خطا خارج مستعد از موارد ساده مانند ثابت بازگشت است.
2. ستون های ویژگی را تعریف کنید.
هر tf.feature_column
یک نام ویژگی ، نوع آن و هرگونه پردازش ورودی را مشخص می کند.
به عنوان مثال ، قطعه زیر سه ستون ویژگی ایجاد می کند.
- اولین مورد از ویژگی
age
به طور مستقیم به عنوان ورودی با نقطه شناور استفاده می کند. - مورد دوم از ویژگی
class
به عنوان ورودی طبقه ای استفاده می کند. - مورد سوم از
embark_town
به عنوان یک ورودی طبقه ای استفاده می کند ، اما ازhashing trick
برای جلوگیری از نیاز به برشمردن گزینه ها و تنظیم تعداد گزینه ها استفاده می کند.
برای اطلاعات بیشتر ، آموزش ستون ویژگی را بررسی کنید.
age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third'])
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)
3. برآوردگر از قبل ساخته شده مربوطه را نمونه بگیرید.
به عنوان مثال ، در اینجا نمونه ای از برآوردگر از قبل ساخته شده به نام LinearClassifier
:
model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
model_dir=model_dir,
feature_columns=[embark, cls, age],
n_classes=2
)
INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpu27sw9ie', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
برای اطلاعات بیشتر می توانید به آموزش طبقه بندی خطی بروید.
4- یک روش آموزش ، ارزیابی یا استنباط را فراخوانی کنید.
همه برآوردگرها روش های train
، evaluate
و predict
ارائه می دهند.
model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv 32768/30874 [===============================] - 0s 0us/step INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1727: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead. warnings.warn('`layer.add_variable` is deprecated and ' WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:134: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpu27sw9ie/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100... INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpu27sw9ie/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100... INFO:tensorflow:Loss for final step: 0.62258995.
result = model.evaluate(train_input_fn, steps=10)
for key, value in result.items():
print(key, ":", value)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-01-08T02:56:30Z INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpu27sw9ie/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.67613s INFO:tensorflow:Finished evaluation at 2021-01-08-02:56:31 INFO:tensorflow:Saving dict for global step 100: accuracy = 0.715625, accuracy_baseline = 0.60625, auc = 0.7403657, auc_precision_recall = 0.6804854, average_loss = 0.5836128, global_step = 100, label/mean = 0.39375, loss = 0.5836128, precision = 0.739726, prediction/mean = 0.34897345, recall = 0.42857143 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpu27sw9ie/model.ckpt-100 accuracy : 0.715625 accuracy_baseline : 0.60625 auc : 0.7403657 auc_precision_recall : 0.6804854 average_loss : 0.5836128 label/mean : 0.39375 loss : 0.5836128 precision : 0.739726 prediction/mean : 0.34897345 recall : 0.42857143 global_step : 100
for pred in model.predict(train_input_fn):
for key, value in pred.items():
print(key, ":", value)
break
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpu27sw9ie/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. logits : [-0.73942876] logistic : [0.32312906] probabilities : [0.6768709 0.3231291] class_ids : [0] classes : [b'0'] all_class_ids : [0 1] all_classes : [b'0' b'1']
مزایای برآوردگرهای از پیش ساخته شده
برآوردگرهای از پیش ساخته شده بهترین روش ها را رمزگذاری می کنند و مزایای زیر را فراهم می کنند:
- بهترین روش ها برای تعیین محل های مختلف نمودار محاسباتی ، اجرای استراتژی ها روی یک ماشین یا خوشه.
- بهترین روش ها برای نوشتن رویداد (خلاصه) و خلاصه های مفید جهانی.
اگر از برآوردگرهای پیش ساخته استفاده نمی کنید ، باید ویژگی های قبلی را خودتان پیاده سازی کنید.
برآوردگرهای سفارشی
قلب هر برآوردگر - چه از قبل ساخته شده و چه سفارشی - عملکرد مدل آن ، model_fn
است که روشی است که نمودارها را برای آموزش ، ارزیابی و پیش بینی می سازد. وقتی از برآوردگر از پیش ساخته شده استفاده می کنید ، شخص دیگری قبلاً عملکرد مدل را اجرا کرده است. هنگام اعتماد به یک برآوردگر سفارشی ، باید خودتان تابع مدل را بنویسید.
یک برآوردگر از مدل Keras ایجاد کنید
با tf.keras.estimator.model_to_estimator
می توانید مدل های موجود Keras را به Estimators tf.keras.estimator.model_to_estimator
. اگر می خواهید کد مدل خود را مدرن کنید ، این مفید است ، اما خطوط آموزش شما هنوز به برآوردگرها نیاز دارد.
یک مدل Keras MobileNet V2 تهیه کنید و مدل را با بهینه ساز ، افت و معیارها برای آموزش با آنها تدوین کنید:
keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False
estimator_model = tf.keras.Sequential([
keras_mobilenet_v2,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(1)
])
# Compile the model
estimator_model.compile(
optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5 9412608/9406464 [==============================] - 0s 0us/step
یک Estimator
از مدل Keras وارد شده ایجاد کنید. حالت اولیه مدل Keras در Estimator
ایجاد شده حفظ می شود:
est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpeaonpwe8 INFO:tensorflow:Using the Keras model provided. /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/backend.py:434: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model. warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and ' INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpeaonpwe8', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
با Estimator
مشتق مانند هر Estimator
دیگری رفتار کنید.
IMG_SIZE = 160 # All images will be resized to 160x160
def preprocess(image, label):
image = tf.cast(image, tf.float32)
image = (image/127.5) - 1
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
return image, label
def train_input_fn(batch_size):
data = tfds.load('cats_vs_dogs', as_supervised=True)
train_data = data['train']
train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
return train_data
برای آموزش ، با عملکرد قطار برآوردگر تماس بگیرید:
est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
Downloading and preparing dataset 786.68 MiB (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0... WARNING:absl:1738 images were corrupted and were skipped Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpeaonpwe8/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpeaonpwe8/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting from: /tmp/tmpeaonpwe8/keras/keras_model.ckpt INFO:tensorflow:Warm-starting from: /tmp/tmpeaonpwe8/keras/keras_model.ckpt INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-started 158 variables. INFO:tensorflow:Warm-started 158 variables. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeaonpwe8/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeaonpwe8/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6884984, step = 0 INFO:tensorflow:loss = 0.6884984, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpeaonpwe8/model.ckpt. INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpeaonpwe8/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Loss for final step: 0.67705643. INFO:tensorflow:Loss for final step: 0.67705643. <tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f3d7c3822b0>
به همین ترتیب ، برای ارزیابی ، عملکرد ارزیابی Estimator را فراخوانی کنید:
est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:2325: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically. warnings.warn('`Model.state_updates` will be removed in a future version. ' INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:32Z INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:32Z INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpeaonpwe8/model.ckpt-50 INFO:tensorflow:Restoring parameters from /tmp/tmpeaonpwe8/model.ckpt-50 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 2.42050s INFO:tensorflow:Inference Time : 2.42050s INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:35 INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:35 INFO:tensorflow:Saving dict for global step 50: accuracy = 0.515625, global_step = 50, loss = 0.6688157 INFO:tensorflow:Saving dict for global step 50: accuracy = 0.515625, global_step = 50, loss = 0.6688157 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpeaonpwe8/model.ckpt-50 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpeaonpwe8/model.ckpt-50 {'accuracy': 0.515625, 'loss': 0.6688157, 'global_step': 50}
برای جزئیات بیشتر ، لطفاً به اسناد tf.keras.estimator.model_to_estimator
.
ذخیره ایست های بازرسی مبتنی بر شی با Estimator
برآوردگران به طور پیش فرض ، ایست های بازرسی را با نام های متغیر ذخیره می کنند تا نمودار شی object توصیف شده در راهنمای Checkpoint . tf.train.Checkpoint
ایست های بازرسی مبتنی بر نام را می خواند ، اما ممکن است هنگام جابجایی قطعات یک مدل به خارج از model_fn
برآوردگر ، نام متغیرها تغییر کند. صرفه جویی در سازگاری به جلو ، ایست های بازرسی مبتنی بر شی object آموزش یک مدل در داخل برآوردگر و سپس استفاده از آن را در خارج از مدل آسان می کند.
import tensorflow.compat.v1 as tf_compat
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
def model_fn(features, labels, mode):
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
optimizer=opt, net=net)
with tf.GradientTape() as tape:
output = net(features['x'])
loss = tf.reduce_mean(tf.abs(output - features['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
return tf.estimator.EstimatorSpec(
mode,
loss=loss,
train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
ckpt.step.assign_add(1)),
# Tell the Estimator to save "ckpt" in an object-based format.
scaffold=tf_compat.train.Scaffold(saver=ckpt))
tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 4.4040537, step = 0 INFO:tensorflow:loss = 4.4040537, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 35.247967. INFO:tensorflow:Loss for final step: 35.247967. <tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f3d64534518>
tf.train.Checkpoint
می تواند ایست های بازرسی برآوردگر را از model_dir
خود model_dir
.
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy() # From est.train(..., steps=10)
10
مدلهای ذخیره شده از برآوردگرها
برآورد کنندگان از طریق tf.Estimator.export_saved_model
.
input_column = tf.feature_column.numeric_column("x")
estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])
def input_fn():
return tf.data.Dataset.from_tensor_slices(
({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpczwhe6jk WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpczwhe6jk INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpczwhe6jk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpczwhe6jk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpczwhe6jk/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpczwhe6jk/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpczwhe6jk/model.ckpt. INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpczwhe6jk/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Loss for final step: 0.48830828. INFO:tensorflow:Loss for final step: 0.48830828. <tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f3d6452eb00>
برای ذخیره Estimator
باید یک serving_input_receiver
ایجاد کنید. این تابع بخشی از tf.Graph
می tf.Graph
که داده های خام دریافت شده توسط SavedModel را تجزیه می کند.
ماژول tf.estimator.export
شامل توابعی برای کمک به ساخت این receivers
.
کد زیر یک گیرنده ایجاد می کند ، بر اساس feature_columns
، سری tf.Example
را قبول می کند. tf.Example
بافر پروتکل ، که اغلب با tf- tf.Example
استفاده می شود.
tmpdir = tempfile.mkdtemp()
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
tf.feature_column.make_parse_example_spec([input_column]))
estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Restoring parameters from /tmp/tmpczwhe6jk/model.ckpt-50 INFO:tensorflow:Restoring parameters from /tmp/tmpczwhe6jk/model.ckpt-50 INFO:tensorflow:Assets added to graph. INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: /tmp/tmp16t8uhub/from_estimator/temp-1610074656/saved_model.pb INFO:tensorflow:SavedModel written to: /tmp/tmp16t8uhub/from_estimator/temp-1610074656/saved_model.pb
همچنین می توانید آن مدل را از پایتون بارگیری و اجرا کنید:
imported = tf.saved_model.load(estimator_path)
def predict(x):
example = tf.train.Example()
example.features.feature["x"].float_list.value.extend([x])
return imported.signatures["predict"](
examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.581246]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.32789052]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.418754, 0.581246]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>} {'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.24376468]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.1321492]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7562353 , 0.24376468]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}
tf.estimator.export.build_raw_serving_input_receiver_fn
به شما امکان می دهد توابع ورودی را ایجاد کنید که به جای tf.train.Example
تانسورهای خام را بگیرد. به tf.train.Example
s.
استفاده از tf.distribute.Strategy
با Estimator (پشتیبانی محدود)
tf.estimator
یک توزیع توزیع شده tf.estimator
API است که در اصل از رویکرد سرور پارامتر async پشتیبانی می کند. tf.estimator
اکنون از tf.distribute.Strategy
پشتیبانی می tf.distribute.Strategy
. اگر از tf.estimator
استفاده می کنید ، می توانید با تغییرات بسیار کمی در کد خود به آموزش توزیع شده تغییر دهید. با استفاده از این ، کاربران برآوردگر هم اکنون می توانند آموزش توزیع همزمان بر روی چندین GPU و چندین کارگر را انجام دهند و همچنین از TPU استفاده کنند. این پشتیبانی در Estimator محدود است. برای جزئیات بیشتر به بخش "چه اکنون پشتیبانی می شود" در زیر مراجعه کنید.
استفاده از tf.distribute.Strategy
با Estimator کمی متفاوت از مورد Keras است. به جای استفاده از strategy.scope
، اکنون شی استراتژی را به RunConfig
برای برآوردگر منتقل می کنید.
برای اطلاعات بیشتر می توانید به راهنمای توزیع شده آموزش مراجعه کنید.
در اینجا یک قطعه از کد است که نشان می دهد این کار را با از پیش ساخته شده برآورد LinearRegressor
و MirroredStrategy
:
mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
feature_columns=[tf.feature_column.numeric_column('feats')],
optimizer='SGD',
config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Not using Distribute Coordinator. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp4uihzu_a WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp4uihzu_a INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp4uihzu_a', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp4uihzu_a', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
در اینجا ، شما از یک برآوردگر پیش ساخته استفاده می کنید ، اما همان کد با یک برآوردگر سفارشی نیز کار می کند. train_distribute
نحوه توزیع آموزش را تعیین می کند ، و eval_distribute
نحوه توزیع ارزیابی را تعیین می کند. این تفاوت دیگری با کراس است که در آن از یک استراتژی هم برای آموزش و هم برای نتیجه گیری استفاده می کنید.
اکنون می توانید این برآوردگر را با یک تابع ورودی آموزش داده و ارزیابی کنید:
def input_fn():
dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp4uihzu_a/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp4uihzu_a/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmp4uihzu_a/model.ckpt. INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmp4uihzu_a/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 2.877698e-13. INFO:tensorflow:Loss for final step: 2.877698e-13. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:41Z INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:41Z INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmp4uihzu_a/model.ckpt-10 INFO:tensorflow:Restoring parameters from /tmp/tmp4uihzu_a/model.ckpt-10 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.26266s INFO:tensorflow:Inference Time : 0.26266s INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:42 INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:42 INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994 INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmp4uihzu_a/model.ckpt-10 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmp4uihzu_a/model.ckpt-10 {'average_loss': 1.4210855e-14, 'label/mean': 1.0, 'loss': 1.4210855e-14, 'prediction/mean': 0.99999994, 'global_step': 10}
تفاوت دیگری که در اینجا باید بین Estimator و Keras برجسته شود ، مدیریت ورودی است. در Keras ، هر دسته از مجموعه داده ها به طور خودکار در چندین کپی تقسیم می شوند. با این حال ، در برآوردگر ، شما تقسیم دسته ای خودکار را انجام نمی دهید و یا داده ها را بین کارگران مختلف به طور خودکار خرد نمی کنید. شما کنترل کاملی بر نحوه توزیع داده های خود در بین کارگران و دستگاه ها دارید و باید یک input_fn
ارائه input_fn
تا نحوه توزیع داده های شما مشخص شود.
input_fn
شما برای هر کارگر یک بار فراخوانی می شود ، بنابراین به ازای هر کارگر یک مجموعه داده داده می شود. سپس یک دسته از آن مجموعه داده به یک ماکت در آن کارگر تغذیه می شود ، در نتیجه N دسته برای 1 نسخه از کارگران N مصرف می شود. به عبارت دیگر ، مجموعه داده ای که توسط input_fn
باید دسته هایی از اندازه PER_REPLICA_BATCH_SIZE
. و اندازه دسته جهانی یک مرحله را می توان به عنوان PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
بدست آورد.
هنگام انجام آموزش های چند کارگر ، باید داده های خود را بین کارگران تقسیم کنید ، یا با یک دانه تصادفی روی هر کدام مرتب شوید. می توانید نمونه ای از نحوه انجام این کار را در آموزش چند کاره با آموزش برآوردگر بررسی کنید .
و به همین ترتیب ، می توانید از استراتژی های چند کارگر و سرور پارامتر نیز استفاده کنید. کد ثابت باقی مانده است ، اما شما باید از tf.estimator.train_and_evaluate
استفاده کنید و متغیرهای محیط TF_CONFIG
برای هر باینری که در خوشه شما اجرا می شود تنظیم کنید.
اکنون چه چیزی پشتیبانی می شود؟
پشتیبانی محدودی از آموزش با Estimator با استفاده از همه استراتژی ها به جز TPUStrategy
. آموزش و ارزیابی v1.train.Scaffold
باید مفید باشد ، اما تعدادی از ویژگی های پیشرفته مانند v1.train.Scaffold
. v1.train.Scaffold
این کار را نمی کند. همچنین ممکن است تعدادی اشکال در این ادغام وجود داشته باشد و هیچ برنامه ای برای بهبود فعال این پشتیبانی وجود ندارد (تمرکز بر پشتیبانی از حلقه Keras و آموزش سفارشی است). در صورت امکان ، شما ترجیح می دهید از tf.distribute
با آن API ها استفاده کنید.
آموزش API | MirroredStrategy | TPUStrategy | MultiWorkerMirroredStrategy | استراتژی CentralStorage | ParameterServerStrategy |
---|---|---|---|---|---|
API برآوردگر | پشتیبانی محدود | پشتیبانی نشده | پشتیبانی محدود | پشتیبانی محدود | پشتیبانی محدود |
مثالها و آموزشها
در اینجا چند نمونه از انتها به انتها آورده شده است که نحوه استفاده از استراتژی های مختلف با Estimator را نشان می دهد:
- آموزش آموزش چند کارگر با برآوردگر نشان می دهد که چگونه می توانید با استفاده از
MultiWorkerMirroredStrategy
در مجموعه داده MNIST با چند کارگر آموزشMultiWorkerMirroredStrategy
. - یک مثال پایان به پایان اجرای آموزش چند کارگری با استراتژی های توزیع در
tensorflow/ecosystem
با استفاده از الگوهای Kubernetes. با یک مدل Keras شروع می شود و با استفاده از APItf.keras.estimator.model_to_estimator
آن را به Estimator تبدیل می کند. - این مقام ResNet50 مدل، که می تواند با استفاده از آموزش دیده
MirroredStrategy
یاMultiWorkerMirroredStrategy
.