Эта страница переведена с помощью Cloud Translation API.
Switch to English

Оценщики

Посмотреть на TensorFlow.org Запустить в Google Colab Посмотреть исходный код на GitHub Скачать блокнот

Этот документ представляет tf.estimator - высокоуровневый API tf.estimator . Оценщики включают в себя следующие действия:

  • Тренировка
  • Оценка
  • Предсказание
  • Экспорт для сервировки

TensorFlow реализует несколько готовых оценщиков. Пользовательские оценки все еще поддерживаются, но в основном в качестве меры обратной совместимости. Пользовательские оценщики не следует использовать для нового кода . Все оценщики - готовые или настраиваемые - представляют собой классы, основанные на классе tf.estimator.Estimator .

В качестве быстрого примера попробуйте учебные пособия по Оценщику . Для обзора дизайна API ознакомьтесь с официальным документом .

Настроить

pip install -q -U tensorflow_datasets
import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

Преимущества

Подобно tf.keras.Model , estimator представляет собой абстракцию на уровне модели. tf.estimator предоставляет некоторые возможности, которые в настоящее время все еще разрабатываются для tf.keras . Эти:

  • Обучение на основе сервера параметров
  • Полная интеграция с TFX

Возможности оценщиков

Оценщики предоставляют следующие преимущества:

  • Вы можете запускать модели на основе оценщика на локальном хосте или в распределенной многосерверной среде, не меняя модель. Кроме того, вы можете запускать модели на основе Оценщика на процессорах, графических процессорах или TPU без перекодирования вашей модели.
  • Оценщики обеспечивают безопасный распределенный цикл обучения, который контролирует, как и когда:
    • Загрузить данные
    • Обработка исключений
    • Создание файлов контрольных точек и восстановление после сбоев
    • Сохраните сводки для TensorBoard

При написании приложения с помощью оценщиков необходимо отделить конвейер ввода данных от модели. Такое разделение упрощает эксперименты с разными наборами данных.

Использование готовых оценщиков

Готовые оценщики позволяют работать на гораздо более высоком концептуальном уровне, чем базовые API TensorFlow. Вам больше не нужно беспокоиться о создании вычислительного графа или сеансов, поскольку оценщики берут на себя всю "сантехнику" за вас. Более того, готовые оценщики позволяют экспериментировать с различными архитектурами моделей, внося лишь минимальные изменения в код. tf.estimator.DNNClassifier , например, представляет собой предварительно созданный класс Estimator, который обучает модели классификации на основе плотных нейронных сетей с прямой связью.

Программа TensorFlow, основанная на готовом оценщике, обычно состоит из следующих четырех шагов:

1. Напишите функции ввода

Например, вы можете создать одну функцию для импорта обучающего набора и другую функцию для импорта набора тестов. Оценщики ожидают, что их входные данные будут отформатированы как пара объектов:

  • Словарь, в котором ключи являются именами функций, а значения - тензорами (или SparseTensors), содержащими соответствующие данные функций.
  • Тензор, содержащий одну или несколько меток

input_fn должен возвращатьtf.data.Dataset который дает пары в этом формате.

Например, следующий код создаетtf.data.Dataset из файла train.csv набора данных train.csv :

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.AUTOTUNE))
  return titanic_batches

input_fn выполняется в tf.Graph и также может напрямую возвращать пару (features_dics, labels) содержащую тензоры графа, но это подвержено ошибкам за пределами простых случаев, таких как возврат констант.

2. Определите столбцы функций.

Каждый tf.feature_column определяет имя функции, ее тип и любую предварительную обработку ввода.

Например, следующий фрагмент кода создает три столбца функций.

  • Первый использует функцию age непосредственно в качестве входных данных с плавающей запятой.
  • Второй использует свойство class как категориальный вход.
  • Третий использует embark_town как категориальный ввод, но использует hashing trick чтобы избежать необходимости перечислять параметры и устанавливать количество параметров.

Для получения дополнительной информации см. Руководство по столбцам функций .

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. Создайте экземпляр соответствующего предварительно созданного оценщика.

Например, вот пример создания LinearClassifier оценщика с именем LinearClassifier :

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpu27sw9ie', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Для получения дополнительной информации вы можете перейти к руководству по линейному классификатору .

4. Назовите метод обучения, оценки или вывода.

Все оценщики предоставляют методы train , evaluate и predict .

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step
INFO:tensorflow:Calling model_fn.

/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1727: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:134: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpu27sw9ie/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpu27sw9ie/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.62258995.

result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-01-08T02:56:30Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpu27sw9ie/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.67613s
INFO:tensorflow:Finished evaluation at 2021-01-08-02:56:31
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.715625, accuracy_baseline = 0.60625, auc = 0.7403657, auc_precision_recall = 0.6804854, average_loss = 0.5836128, global_step = 100, label/mean = 0.39375, loss = 0.5836128, precision = 0.739726, prediction/mean = 0.34897345, recall = 0.42857143
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpu27sw9ie/model.ckpt-100
accuracy : 0.715625
accuracy_baseline : 0.60625
auc : 0.7403657
auc_precision_recall : 0.6804854
average_loss : 0.5836128
label/mean : 0.39375
loss : 0.5836128
precision : 0.739726
prediction/mean : 0.34897345
recall : 0.42857143
global_step : 100

for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpu27sw9ie/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-0.73942876]
logistic : [0.32312906]
probabilities : [0.6768709 0.3231291]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']

Преимущества готовых оценщиков

Готовые оценщики кодируют передовой опыт, обеспечивая следующие преимущества:

  • Лучшие практики для определения того, где должны выполняться различные части вычислительного графа, реализации стратегий на одной машине или в кластере.
  • Лучшие практики для написания (резюме) мероприятий и универсально полезных резюме.

Если вы не используете готовые оценщики, вы должны реализовать указанные выше функции самостоятельно.

Пользовательские оценщики

В основе каждого оценщика - model_fn или настраиваемого - является его модельная функция model_fn , которая представляет собой метод построения графиков для обучения, оценки и прогнозирования. Когда вы используете предварительно созданный оценщик, кто-то другой уже реализовал функцию модели. Если вы полагаетесь на собственный оценщик, вы должны сами написать функцию модели.

Создание оценщика из модели Кераса

Вы можете преобразовать существующие модели tf.keras.estimator.model_to_estimator с помощью tf.keras.estimator.model_to_estimator . Это полезно, если вы хотите модернизировать код модели, но для вашего конвейера обучения по-прежнему требуются оценщики.

Создайте экземпляр модели Keras MobileNet V2 и скомпилируйте модель с оптимизатором, потерями и метриками для обучения:

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

Создайте Estimator из скомпилированной модели Кераса. Исходное состояние модели Keras сохраняется в созданном Estimator :

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpeaonpwe8
INFO:tensorflow:Using the Keras model provided.

/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/backend.py:434: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpeaonpwe8', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Обращайтесь с производным Estimator как с любым другим Estimator .

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

Для обучения вызовите функцию поезда Оценщика:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
Downloading and preparing dataset 786.68 MiB (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

Warning:absl:1738 images were corrupted and were skipped

Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpeaonpwe8/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpeaonpwe8/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting from: /tmp/tmpeaonpwe8/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting from: /tmp/tmpeaonpwe8/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-started 158 variables.

INFO:tensorflow:Warm-started 158 variables.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeaonpwe8/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeaonpwe8/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 0.6884984, step = 0

INFO:tensorflow:loss = 0.6884984, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpeaonpwe8/model.ckpt.

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpeaonpwe8/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Loss for final step: 0.67705643.

INFO:tensorflow:Loss for final step: 0.67705643.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f3d7c3822b0>

Точно так же для оценки вызовите функцию оценки оценщика:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:2325: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`Model.state_updates` will be removed in a future version. '

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:32Z

INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:32Z

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Restoring parameters from /tmp/tmpeaonpwe8/model.ckpt-50

INFO:tensorflow:Restoring parameters from /tmp/tmpeaonpwe8/model.ckpt-50

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Inference Time : 2.42050s

INFO:tensorflow:Inference Time : 2.42050s

INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:35

INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:35

INFO:tensorflow:Saving dict for global step 50: accuracy = 0.515625, global_step = 50, loss = 0.6688157

INFO:tensorflow:Saving dict for global step 50: accuracy = 0.515625, global_step = 50, loss = 0.6688157

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpeaonpwe8/model.ckpt-50

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpeaonpwe8/model.ckpt-50

{'accuracy': 0.515625, 'loss': 0.6688157, 'global_step': 50}

Дополнительные сведения см. В документации к tf.keras.estimator.model_to_estimator .

Сохранение объектных контрольных точек с помощью оценщика

Оценщики по умолчанию сохраняют контрольные точки с именами переменных, а не с графом объектов, описанным в руководстве по контрольным точкам. tf.train.Checkpoint будет читать контрольные точки на основе имен, но имена переменных могут измениться при перемещении частей модели за пределы model_fn . Для обеспечения прямой совместимости сохранение объектных контрольных точек упрощает обучение модели внутри оценщика, а затем использование ее вне его.

import tensorflow.compat.v1 as tf_compat
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.

INFO:tensorflow:Using default config.

INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 4.4040537, step = 0

INFO:tensorflow:loss = 4.4040537, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...

INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.

INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...

INFO:tensorflow:Loss for final step: 35.247967.

INFO:tensorflow:Loss for final step: 35.247967.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f3d64534518>

tf.train.Checkpoint может затем загрузить контрольные точки Оценщика из его model_dir .

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

Сохраненные модели из оценщиков

Оценщики экспортируют SavedModels через tf.Estimator.export_saved_model .

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.

INFO:tensorflow:Using default config.

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmpczwhe6jk

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmpczwhe6jk

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpczwhe6jk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpczwhe6jk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpczwhe6jk/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpczwhe6jk/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 0.6931472, step = 0

INFO:tensorflow:loss = 0.6931472, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpczwhe6jk/model.ckpt.

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpczwhe6jk/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Loss for final step: 0.48830828.

INFO:tensorflow:Loss for final step: 0.48830828.

<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f3d6452eb00>

Чтобы сохранить Estimator вам необходимо создать serving_input_receiver . Эта функция создает часть tf.Graph которая анализирует необработанные данные, полученные SavedModel.

Модуль tf.estimator.export содержит функции, помогающие создавать эти receivers .

Следующий код создает приемник на основе feature_columns , который принимает сериализованные tf.Example протокола tf.Example , которые часто используются с tf- tf.Example .

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']

INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']

INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']

INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Restoring parameters from /tmp/tmpczwhe6jk/model.ckpt-50

INFO:tensorflow:Restoring parameters from /tmp/tmpczwhe6jk/model.ckpt-50

INFO:tensorflow:Assets added to graph.

INFO:tensorflow:Assets added to graph.

INFO:tensorflow:No assets to write.

INFO:tensorflow:No assets to write.

INFO:tensorflow:SavedModel written to: /tmp/tmp16t8uhub/from_estimator/temp-1610074656/saved_model.pb

INFO:tensorflow:SavedModel written to: /tmp/tmp16t8uhub/from_estimator/temp-1610074656/saved_model.pb

Вы также можете загрузить и запустить эту модель из python:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.581246]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.32789052]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.418754, 0.581246]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>}
{'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.24376468]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.1321492]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7562353 , 0.24376468]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}

tf.estimator.export.build_raw_serving_input_receiver_fn позволяет создавать функции ввода, которые принимают необработанные тензоры, а не tf.train.Example s.

Использование tf.distribute.Strategy с tf.distribute.Strategy (ограниченная поддержка)

tf.estimator - это API tf.estimator для распределенного обучения, который изначально поддерживал подход сервера параметров async. tf.estimator теперь поддерживает tf.distribute.Strategy . Если вы используете tf.estimator , вы можете перейти на распределенное обучение с очень небольшими изменениями в коде. Благодаря этому пользователи Estimator теперь могут выполнять синхронное распределенное обучение на нескольких графических процессорах и нескольких рабочих процессах, а также использовать TPU. Однако эта поддержка в Оценщике ограничена. Более подробную информацию можно найти в разделе « Что сейчас поддерживается » ниже.

Использование tf.distribute.Strategy с tf.distribute.Strategy немного отличается от tf.distribute.Strategy с tf.distribute.Strategy . Вместо использования strategy.scope теперь вы передаете объект стратегии в RunConfig для оценщика.

Вы можете обратиться к распределенному учебному руководству для получения дополнительной информации.

Вот фрагмент кода, который показывает это с предварительно LinearRegressor и MirroredStrategy :

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Initializing RunConfig with distribution strategies.

INFO:tensorflow:Initializing RunConfig with distribution strategies.

INFO:tensorflow:Not using Distribute Coordinator.

INFO:tensorflow:Not using Distribute Coordinator.

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmp4uihzu_a

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmp4uihzu_a

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp4uihzu_a', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp4uihzu_a', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

Здесь вы используете готовый оценщик, но тот же код работает и со специальным оценщиком. train_distribute определяет, как будет распределяться обучение, а eval_distribute определяет, как будет распределяться оценка. Это еще одно отличие от Keras, где вы используете одну и ту же стратегию как для обучения, так и для оценки.

Теперь вы можете обучить и оценить этот Оценщик с помощью функции ввода:

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp4uihzu_a/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp4uihzu_a/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 1.0, step = 0

INFO:tensorflow:loss = 1.0, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...

INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmp4uihzu_a/model.ckpt.

INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmp4uihzu_a/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...

INFO:tensorflow:Loss for final step: 2.877698e-13.

INFO:tensorflow:Loss for final step: 2.877698e-13.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:41Z

INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:41Z

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Restoring parameters from /tmp/tmp4uihzu_a/model.ckpt-10

INFO:tensorflow:Restoring parameters from /tmp/tmp4uihzu_a/model.ckpt-10

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Inference Time : 0.26266s

INFO:tensorflow:Inference Time : 0.26266s

INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:42

INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:42

INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994

INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmp4uihzu_a/model.ckpt-10

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmp4uihzu_a/model.ckpt-10

{'average_loss': 1.4210855e-14,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994,
 'global_step': 10}

Еще одно отличие, которое следует выделить здесь между оценщиком и Keras, - это обработка ввода. В Keras каждый пакет набора данных автоматически разделяется на несколько реплик. Однако в Оценщике вы не выполняете ни автоматическое разделение пакетов, ни автоматическое сегментирование данных между разными исполнителями. У вас есть полный контроль над тем, как вы хотите, чтобы ваши данные распределялись между рабочими и устройствами, и вы должны предоставить input_fn чтобы указать, как распределять ваши данные.

Ваш input_fn вызывается один раз для каждого рабочего, что дает один набор данных для каждого рабочего. Затем один пакет из этого набора данных передается в одну реплику на этом работнике, тем самым потребляя N пакетов для N реплик на 1 работнике. Другими словами, набор данных, возвращаемый input_fn должен предоставлять пакеты размером PER_REPLICA_BATCH_SIZE . И глобальный размер пакета для шага можно получить как PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync .

При выполнении обучения с несколькими рабочими вы должны либо разделить свои данные между рабочими, либо перемешать их со случайным начальным числом для каждого. Вы можете проверить пример того, как это сделать, в учебнике « Обучение нескольких сотрудников с помощью оценщика» .

Точно так же вы можете использовать стратегии с несколькими рабочими и сервером параметров. Код остается тем же, но вам нужно использовать tf.estimator.train_and_evaluate и установить переменные среды TF_CONFIG для каждого двоичного tf.estimator.train_and_evaluate , запущенного в вашем кластере.

Что сейчас поддерживается?

Существует ограниченная поддержка обучения с помощью Оценщика со всеми стратегиями, кроме TPUStrategy . Базовое обучение и оценка должны работать, но ряд дополнительных функций, таких как v1.train.Scaffold , не работают. В этой интеграции также может быть ряд ошибок, и нет планов активно улучшать эту поддержку (основное внимание уделяется Keras и поддержке пользовательского цикла обучения). Если это вообще возможно, лучше вместо этого использовать tf.distribute с этими API.

API обучения ЗеркальныйСтратегия TPUStrategy MultiWorkerMirroredStrategy CentralStorageСтратегия ParameterServerStrategy
API оценщика Ограниченная поддержка Не поддерживается Ограниченная поддержка Ограниченная поддержка Ограниченная поддержка

Примеры и руководства

Вот несколько сквозных примеров, показывающих, как использовать различные стратегии с Оценщиком:

  1. В учебном пособии «Обучение нескольких сотрудников с помощью оценщика» показано, как можно обучать нескольких сотрудников с помощью MultiWorkerMirroredStrategy в наборе данных MNIST.
  2. tensorflow/ecosystem пример выполнения обучения нескольких сотрудников со стратегиями распределения в tensorflow/ecosystem с использованием шаблонов Kubernetes. Он начинается с модели tf.keras.estimator.model_to_estimator и преобразует ее в tf.keras.estimator.model_to_estimator API tf.keras.estimator.model_to_estimator .
  3. Официальная модель ResNet50 , которую можно обучить с помощью MirroredStrategy или MultiWorkerMirroredStrategy .