Schätzer

Auf TensorFlow.org ansehen In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

Dieses Dokument stellt tf.estimator -a High-Level - API TensorFlow. Schätzer kapseln die folgenden Aktionen:

  • Ausbildung
  • Auswertung
  • Vorhersage
  • Zur Auslieferung exportieren

TensorFlow implementiert mehrere vorgefertigte Schätzer. Benutzerdefinierte Schätzer werden weiterhin unterstützt, jedoch hauptsächlich als Maß für die Abwärtskompatibilität. Benutzerdefinierte Schätzer sollten nicht für neuen Code verwendet werden. All Schätzer-vorgefertigte oder benutzerdefinierte diejenigen-Klassen sind auf der Basis der tf.estimator.Estimator Klasse.

Für ein schnelles Beispiel versuchen Estimator Tutorials . Einen Überblick über das API - Design, überprüfen Sie das Whitepaper .

Aufstellen

pip install -U tensorflow_datasets
import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

Vorteile

Ähnlich wie bei einem tf.keras.Model , ein estimator ist eine Modell-Ebene Abstraktion. Die tf.estimator bietet einige Möglichkeiten zur Zeit noch in der Entwicklung für tf.keras . Diese sind:

  • Parameterserverbasiertes Training
  • Volle TFX Integration

Schätzer-Funktionen

Schätzer bieten die folgenden Vorteile:

  • Sie können Estimator-basierte Modelle auf einem lokalen Host oder in einer verteilten Umgebung mit mehreren Servern ausführen, ohne Ihr Modell zu ändern. Darüber hinaus können Sie Estimator-basierte Modelle auf CPUs, GPUs oder TPUs ausführen, ohne Ihr Modell neu zu codieren.
  • Schätzer bieten eine sichere verteilte Trainingsschleife, die steuert, wie und wann:
    • Lade Daten
    • Ausnahmen behandeln
    • Prüfpunktdateien erstellen und nach Fehlern wiederherstellen
    • Zusammenfassungen für TensorBoard speichern

Beim Schreiben einer Anwendung mit Estimators müssen Sie die Dateneingabepipeline vom Modell trennen. Diese Trennung vereinfacht Experimente mit unterschiedlichen Datensätzen.

Verwendung vorgefertigter Schätzer

Mit vorgefertigten Estimatoren können Sie auf einer viel höheren konzeptionellen Ebene arbeiten als die Basis-APIs von TensorFlow. Sie müssen sich nicht mehr um das Erstellen des Rechendiagramms oder der Sitzungen kümmern, da Estimators die gesamte "Installation" für Sie übernimmt. Darüber hinaus können Sie mit vorgefertigten Estimators mit verschiedenen Modellarchitekturen experimentieren, indem Sie nur minimale Codeänderungen vornehmen. tf.estimator.DNNClassifier , zum Beispiel, ist eine vorgefertigte Estimator Klasse dass Modelle basierten Klassifizierungs Züge auf dicht, mit Störgrößenaufschaltung neuronaler Netzen.

Ein TensorFlow-Programm, das auf einem vorgefertigten Estimator basiert, besteht normalerweise aus den folgenden vier Schritten:

1. Schreiben Sie eine Eingangsfunktion

Sie können beispielsweise eine Funktion zum Importieren des Trainingssatzes und eine weitere Funktion zum Importieren des Testsatzes erstellen. Schätzer erwarten, dass ihre Eingaben als ein Objektpaar formatiert sind:

  • Ein Wörterbuch, in dem die Schlüssel Merkmalsnamen sind und die Werte Tensoren (oder SparseTensors) sind, die die entsprechenden Merkmalsdaten enthalten
  • Ein Tensor mit einem oder mehreren Labels

Die input_fn sollte eine Rückkehr tf.data.Dataset dass Ausbeuten Paare in diesem Format.

Zum Beispiel baut der folgende Code ein tf.data.Dataset aus der Titanic - Datensatz train.csv Datei:

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.AUTOTUNE))
  return titanic_batches

Die input_fn wird in einem ausgeführt tf.Graph und kann auch eine direkt zurück (features_dics, labels) Paar enthält Graph Tensoren, dies ist jedoch fehleranfällig außerhalb einfachen Fällen wie eine Rückkehr Konstanten.

2. Definieren Sie die Feature-Spalten.

Jede tf.feature_column identifiziert einen Merkmalsname, Typ, und alle Eingangs Vorverarbeitung.

Das folgende Snippet erstellt beispielsweise drei Feature-Spalten.

  • Die erste verwendet die age Funktion direkt als Gleitkommazahl - Eingang.
  • Der zweite verwendet die class Funktion als kategorischen Eingang.
  • Die dritten Verwendungen die embark_town als kategorischer Eingang, sondern verwendet den hashing trick - hashing trick die Notwendigkeit zu vermeiden , die Möglichkeiten aufzuzählen, und die Anzahl der Optionen einzustellen.

Für weitere Informationen besuchen Sie das Feature Spalten Tutorial .

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. Instanziieren Sie den entsprechenden vorgefertigten Estimator.

Zum Beispiel, hier ist eine Probe Instantiierung eines vorgefertigten Estimator namens LinearClassifier :

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpr_ditsvt', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Für weitere Informationen können Sie das gehen lineare Klassifizierer Tutorial .

4. Rufen Sie eine Trainings-, Bewertungs- oder Inferenzmethode auf.

Alle Schätzer bieten train , evaluate und predict Methoden.

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_v1.py:1684: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/ftrl.py:147: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpr_ditsvt/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpr_ditsvt/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.617588.
2021-08-28 01:41:00.871385: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-08-28T01:41:01
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpr_ditsvt/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.62957s
INFO:tensorflow:Finished evaluation at 2021-08-28-01:41:02
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.6375, accuracy_baseline = 0.6125, auc = 0.71408004, auc_precision_recall = 0.59793115, average_loss = 0.63505936, global_step = 100, label/mean = 0.3875, loss = 0.63505936, precision = 0.525, prediction/mean = 0.51330584, recall = 0.67741936
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpr_ditsvt/model.ckpt-100
accuracy : 0.6375
accuracy_baseline : 0.6125
auc : 0.71408004
auc_precision_recall : 0.59793115
average_loss : 0.63505936
label/mean : 0.3875
loss : 0.63505936
precision : 0.525
prediction/mean : 0.51330584
recall : 0.67741936
global_step : 100
2021-08-28 01:41:02.211817: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpr_ditsvt/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-0.67344487]
logistic : [0.3377259]
probabilities : [0.66227406 0.3377259 ]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']
2021-08-28 01:41:03.085864: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Vorteile vorgefertigter Schätzer

Vorgefertigte Estimators codieren Best Practices und bieten die folgenden Vorteile:

  • Best Practices zum Bestimmen, wo verschiedene Teile des Berechnungsdiagramms ausgeführt werden sollen, Implementieren von Strategien auf einem einzelnen Computer oder in einem Cluster.
  • Best Practices für das Schreiben von Ereignissen (Zusammenfassungen) und universell nützliche Zusammenfassungen.

Wenn Sie keine vorgefertigten Schätzer verwenden, müssen Sie die vorstehenden Funktionen selbst implementieren.

Benutzerdefinierte Schätzer

Das Herz jeden Estimator-ob vorgefertigte oder individuell ist seine Modellfunktion, model_fn , das ein Verfahren ist , die grafischen Darstellungen für die Ausbildung, Bewertung und Vorhersage aufbaut. Wenn Sie einen vorgefertigten Estimator verwenden, hat bereits jemand anderes die Modellfunktion implementiert. Wenn Sie sich auf einen benutzerdefinierten Estimator verlassen, müssen Sie die Modellfunktion selbst schreiben.

Erstellen Sie einen Estimator aus einem Keras-Modell

Sie können vorhandene Keras Modelle Schätzer mit konvertieren tf.keras.estimator.model_to_estimator . Dies ist hilfreich, wenn Sie Ihren Modellcode modernisieren möchten, Ihre Trainingspipeline jedoch weiterhin Estimators erfordert.

Instanziieren Sie ein Keras MobileNet V2-Modell und kompilieren Sie das Modell mit dem Optimierer, dem Verlust und den Metriken zum Trainieren:

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step
9420800/9406464 [==============================] - 0s 0us/step

Erstellen Sie einen Estimator aus dem Keras Modell zusammengestellt. Das Einstiegsmodell Zustand des Keras Modell wird im erstellt erhalten Estimator :

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp3_e4p5uk
INFO:tensorflow:Using the Keras model provided.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/backend.py:401: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  category=CustomMaskWarning)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp3_e4p5uk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Behandeln Sie die abgeleitete Estimator wie bei jedem anderen Estimator .

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

Rufen Sie zum Trainieren die Train-Funktion von Estimator auf:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp3_e4p5uk/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp3_e4p5uk/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmp3_e4p5uk/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmp/tmp3_e4p5uk/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp3_e4p5uk/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp3_e4p5uk/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.7921847, step = 0
INFO:tensorflow:loss = 0.7921847, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp3_e4p5uk/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp3_e4p5uk/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.64612174.
INFO:tensorflow:Loss for final step: 0.64612174.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fe8c4524310>

Rufen Sie zum Auswerten auf ähnliche Weise die Auswertungsfunktion des Estimators auf:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/training.py:2470: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`Model.state_updates` will be removed in a future version. '
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-08-28T01:41:32
INFO:tensorflow:Starting evaluation at 2021-08-28T01:41:32
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp3_e4p5uk/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmp/tmp3_e4p5uk/model.ckpt-50
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 5.56003s
INFO:tensorflow:Inference Time : 5.56003s
INFO:tensorflow:Finished evaluation at 2021-08-28-01:41:37
INFO:tensorflow:Finished evaluation at 2021-08-28-01:41:37
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.5125, global_step = 50, loss = 0.6737833
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.5125, global_step = 50, loss = 0.6737833
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmp3_e4p5uk/model.ckpt-50
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmp3_e4p5uk/model.ckpt-50
{'accuracy': 0.5125, 'loss': 0.6737833, 'global_step': 50}

Weitere Einzelheiten finden Sie in der Dokumentation zu tf.keras.estimator.model_to_estimator .

Speichern von objektbasierten Prüfpunkten mit Estimator

Schätzern von Standardspeicher Checkpoints mit Variablennamen anstelle des Objektgraphen in der beschriebenen Checkpoint Führung . tf.train.Checkpoint wird namensbasierte Checkpoints lesen, aber Variablennamen ändern können , wenn Teile eines Modells außerhalb des Estimator bewegt sich model_fn . Aus Gründen der Vorwärtskompatibilität macht es das Speichern objektbasierter Prüfpunkte einfacher, ein Modell innerhalb eines Estimators zu trainieren und es dann außerhalb eines Estimators zu verwenden.

import tensorflow.compat.v1 as tf_compat
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.2813673, step = 0
INFO:tensorflow:loss = 4.2813673, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 33.1623.
INFO:tensorflow:Loss for final step: 33.1623.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fe8b8393b90>

tf.train.Checkpoint können dann die Checkpoints des Estimator laden von seinem model_dir .

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

Gespeicherte Modelle von Schätzern

Schätzer exportieren SavedModels durch tf.Estimator.export_saved_model .

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpzoxucggh
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpzoxucggh
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpzoxucggh', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpzoxucggh', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpzoxucggh/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpzoxucggh/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpzoxucggh/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpzoxucggh/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.44376355.
INFO:tensorflow:Loss for final step: 0.44376355.
<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7fe8b857a7d0>

Um eine zu sparen Estimator benötigen Sie einen erstellen serving_input_receiver . Diese Funktion baut einen Teil eines tf.Graph , die die empfangenen Rohdaten von der SavedModel analysiert.

Das tf.estimator.export Modul enthält Funktionen zu helfen , diese bauen receivers .

Der folgende Code baut einen Empfänger, basierend auf dem feature_columns , die serialisierten nimmt tf.Example Protokollpuffer, die häufig verwendet werden , mit tf-serving .

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmp/tmpzoxucggh/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmp/tmpzoxucggh/model.ckpt-50
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /tmp/tmpdexgg0w2/from_estimator/temp-1630114899/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmp/tmpdexgg0w2/from_estimator/temp-1630114899/saved_model.pb

Sie können dieses Modell auch von Python aus laden und ausführen:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.43685353, 0.5631464 ]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.5631464]], dtype=float32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.25394177]], dtype=float32)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>}
{'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.77109075, 0.22890931]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.22890928]], dtype=float32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.2144802]], dtype=float32)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}

tf.estimator.export.build_raw_serving_input_receiver_fn können Sie Eingabefunktionen schaffen , die roh Tensoren nehmen statt tf.train.Example s.

Mit tf.distribute.Strategy mit Estimator (eingeschränkte Unterstützung)

tf.estimator ist ein verteilte Ausbildung TensorFlow API , die von dem Asynchron - Parameter - Server - Ansatz ursprünglich unterstützt. tf.estimator unterstützt jetzt tf.distribute.Strategy . Wenn Sie mit tf.estimator , können Sie mit sehr wenigen Änderungen an Ihrem Code auf verteilte Training ändern. Damit können Estimator-Benutzer jetzt synchron verteiltes Training auf mehreren GPUs und mehreren Workern durchführen sowie TPUs verwenden. Diese Unterstützung in Estimator ist jedoch begrenzt. Überprüfen Sie die out Was unterstützten nun folgenden Abschnitt für weitere Details.

Mit tf.distribute.Strategy mit Estimator ist etwas anders als im Fall Keras. Anstelle der Verwendung von strategy.scope , übergeben Sie jetzt die Strategie Objekt in die RunConfig für den Estimator.

Sie können auf die beziehen verteilt Schulungsleitfaden für weitere Informationen.

Hier ist ein Ausschnitt von Code, zeigt dies mit einem vorgefertigten Estimator LinearRegressor und MirroredStrategy :

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpjvzbia52
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpjvzbia52
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpjvzbia52', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fe8c438ded0>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fe8c438ded0>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpjvzbia52', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fe8c438ded0>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fe8c438ded0>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

Hier verwenden Sie einen vorgefertigten Estimator, aber derselbe Code funktioniert auch mit einem benutzerdefinierten Estimator. train_distribute legt fest , wie Ausbildung verteilt werden, und eval_distribute bestimmt , wie Auswertung verteilt werden. Dies ist ein weiterer Unterschied zu Keras, bei dem Sie die gleiche Strategie für Training und Evaluierung verwenden.

Jetzt können Sie diesen Estimator mit einer Eingabefunktion trainieren und auswerten:

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:374: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpjvzbia52/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpjvzbia52/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
2021-08-28 01:41:43.214993: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2021-08-28 01:41:43.216335: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpjvzbia52/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpjvzbia52/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-08-28T01:41:43
INFO:tensorflow:Starting evaluation at 2021-08-28T01:41:43
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpjvzbia52/model.ckpt-10
INFO:tensorflow:Restoring parameters from /tmp/tmpjvzbia52/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
2021-08-28 01:41:44.060848: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2021-08-28 01:41:44.062121: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.22219s
INFO:tensorflow:Inference Time : 0.22219s
INFO:tensorflow:Finished evaluation at 2021-08-28-01:41:44
INFO:tensorflow:Finished evaluation at 2021-08-28-01:41:44
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpjvzbia52/model.ckpt-10
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpjvzbia52/model.ckpt-10
{'average_loss': 1.4210855e-14,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994,
 'global_step': 10}

Ein weiterer hier hervorzuhebender Unterschied zwischen Estimator und Keras ist die Eingabebehandlung. In Keras wird jeder Batch des Datasets automatisch auf die mehreren Replikate aufgeteilt. In Estimator führen Sie jedoch weder eine automatische Stapelaufteilung durch, noch teilen Sie die Daten automatisch auf verschiedene Worker auf. Sie haben die volle Kontrolle darüber , wie Sie Ihre Daten über Arbeitnehmer und Geräte verteilt werden, und Sie müssen einen liefern input_fn angeben , wie die Daten zu verteilen.

Ihr input_fn heißt einmal pro Arbeiter, also einen Datensatz pro Arbeitnehmer geben. Dann wird ein Batch aus diesem Dataset einem Replikat auf diesem Worker zugeführt, wodurch N Batches für N Replikate auf einem Worker verbraucht werden. Mit anderen Worten, kehrte der Datensatz durch die input_fn Chargen Größe zur Verfügung stellen sollte PER_REPLICA_BATCH_SIZE . Und die globale Losgröße für einen Schritt kann erhalten werden als PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync .

Wenn Sie ein Training mit mehreren Mitarbeitern durchführen, sollten Sie Ihre Daten entweder auf die Mitarbeiter aufteilen oder mit einem zufälligen Seed auf jedem mischen. Sie können ein Beispiel überprüfen , wie dies in dem tun mit Estimator Multi-Arbeiterausbildung Tutorial.

Ebenso können Sie Multi-Worker- und Parameterserver-Strategien verwenden. Der Code bleibt gleich, aber Sie müssen verwenden tf.estimator.train_and_evaluate und Satz TF_CONFIG Umgebungsvariablen für jeden binären Lauf im Cluster.

Was wird jetzt unterstützt?

Es gibt eine begrenzte Unterstützung für die Ausbildung mit Estimator alle Strategien , außer der Verwendung TPUStrategy . Die Grundausbildung und Bewertung sollte funktionieren, aber eine Reihe von erweiterten Funktionen wie v1.train.Scaffold nicht. Es kann auch eine Reihe von Fehlern in dieser Integration geben und es gibt keine Pläne, diese Unterstützung aktiv zu verbessern (der Fokus liegt auf Keras und benutzerdefinierten Trainingsschleifen-Unterstützung). Wenn möglich, sollten Sie es vorziehen , verwenden tf.distribute stattdessen mit diesen APIs.

Schulungs-API Gespiegelte Strategie TPUS-Strategie MultiWorkerMirroredStrategy CentralStorageStrategie ParameterServerStrategie
Schätzer-API Eingeschränkter Support Nicht unterstützt Eingeschränkter Support Eingeschränkter Support Eingeschränkter Support

Beispiele und Tutorials

Hier sind einige End-to-End-Beispiele, die zeigen, wie verschiedene Strategien mit Estimator verwendet werden:

  1. Das Multi-Arbeiter Training mit Estimator Tutorial zeigt , wie Sie mit mehreren Arbeitern mit trainieren können MultiWorkerMirroredStrategy auf dem MNIST - Datensatz.
  2. Ein End-to-End - Beispiel mit Vertriebsstrategien Mehrarbeiterausbildung läuft in tensorflow/ecosystem mit Kubernetes Vorlagen. Es beginnt mit einem Keras Modell und wandelt es in ein Estimator die Verwendung tf.keras.estimator.model_to_estimator API.
  3. Das offizielle ResNet50 Modell, das trainiert werden kann entweder mit MirroredStrategy oder MultiWorkerMirroredStrategy .