Réserve cette date! Google I / O revient du 18 au 20 mai S'inscrire maintenant
Cette page a été traduite par l'API Cloud Translation.
Switch to English

Estimateurs

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher la source sur GitHub Télécharger le cahier

Ce document présente tf.estimator une API TensorFlow de haut niveau. Les estimateurs encapsulent les actions suivantes:

  • Formation
  • Évaluation
  • Prédiction
  • Exporter pour servir

TensorFlow implémente plusieurs Estimateurs prédéfinis. Les estimateurs personnalisés sont toujours pris en charge, mais principalement comme mesure de compatibilité ascendante. Les estimateurs personnalisés ne doivent pas être utilisés pour le nouveau code . Tous les Estimators, prédéfinis ou personnalisés, sont des classes basées sur la classe tf.estimator.Estimator .

Pour un exemple rapide, essayez les didacticiels Estimator . Pour un aperçu de la conception de l'API, consultez le livre blanc .

Installer

pip install -q -U tensorflow_datasets
import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

Avantages

Semblable à un tf.keras.Model , un estimator est une abstraction au niveau du modèle. Le tf.estimator fournit certaines fonctionnalités actuellement encore en développement pour tf.keras . Ceux-ci sont:

  • Formation basée sur le serveur de paramètres
  • Intégration TFX complète

Capacités des estimateurs

Les estimateurs offrent les avantages suivants:

  • Vous pouvez exécuter des modèles basés sur Estimator sur un hôte local ou sur un environnement multi-serveur distribué sans modifier votre modèle. En outre, vous pouvez exécuter des modèles basés sur Estimator sur des processeurs, des GPU ou des TPU sans recoder votre modèle.
  • Les estimateurs fournissent une boucle de formation distribuée sûre qui contrôle comment et quand:
    • Charger des données
    • Gérer les exceptions
    • Créer des fichiers de point de contrôle et récupérer des échecs
    • Enregistrer les résumés pour TensorBoard

Lors de l'écriture d'une application avec Estimators, vous devez séparer le pipeline d'entrée de données du modèle. Cette séparation simplifie les expériences avec différents ensembles de données.

Utilisation d'estimateurs prédéfinis

Les estimateurs prédéfinis vous permettent de travailler à un niveau conceptuel beaucoup plus élevé que les API TensorFlow de base. Vous n'avez plus à vous soucier de la création du graphique ou des sessions de calcul puisque les estimateurs gèrent toute la «plomberie» pour vous. De plus, les estimateurs prédéfinis vous permettent d'expérimenter différentes architectures de modèle en n'apportant que des modifications minimes au code. tf.estimator.DNNClassifier , par exemple, est une classe Estimator prédéfinie qui entraîne des modèles de classification basés sur des réseaux de neurones denses et à réaction directe.

Un programme TensorFlow reposant sur un Estimator prédéfini comprend généralement les quatre étapes suivantes:

1. Ecrire une fonction d'entrée

Par exemple, vous pouvez créer une fonction pour importer l'ensemble d'apprentissage et une autre fonction pour importer l'ensemble de test. Les estimateurs s'attendent à ce que leurs entrées soient formatées sous la forme d'une paire d'objets:

  • Un dictionnaire dans lequel les clés sont des noms de caractéristiques et les valeurs sont des Tensors (ou SparseTensors) contenant les données de caractéristiques correspondantes
  • Un Tensor contenant une ou plusieurs étiquettes

Le input_fn doit renvoyer untf.data.Dataset qui produit des paires dans ce format.

Par exemple, le code suivant construit untf.data.Dataset de l'ensemble de données Titanic train.csv fichier:

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.AUTOTUNE))
  return titanic_batches

Le input_fn est exécuté dans un tf.Graph et peut aussi renvoyer directement une paire (features_dics, labels) contenant des tenseurs de graphe, mais ceci est sujet aux erreurs en dehors des cas simples comme le retour de constantes.

2. Définissez les colonnes de caractéristiques.

Chaque tf.feature_column identifie un nom de fonction, son type et tout prétraitement d'entrée.

Par exemple, l'extrait de code suivant crée trois colonnes de fonctionnalités.

  • Le premier utilise la fonction d' age directement comme entrée à virgule flottante.
  • Le second utilise la fonction de class comme entrée catégorielle.
  • Le troisième utilise embark_town comme entrée catégorielle, mais utilise l' hashing trick pour éviter d'avoir à énumérer les options et de définir le nombre d'options.

Pour plus d'informations, consultez le didacticiel sur les colonnes de fonctionnalités .

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. Instanciez l'estimateur préfabriqué pertinent.

Par exemple, voici un exemple d'instanciation d'un Estimator LinearClassifier nommé LinearClassifier :

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpu27sw9ie', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Pour plus d'informations, vous pouvez consulter le didacticiel du classificateur linéaire .

4. Appelez une méthode de formation, d'évaluation ou d'inférence.

Tous les estimateurs fournissent des méthodes de train , d' evaluate et de predict .

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1727: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:134: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpu27sw9ie/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpu27sw9ie/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.62258995.
result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-01-08T02:56:30Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpu27sw9ie/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.67613s
INFO:tensorflow:Finished evaluation at 2021-01-08-02:56:31
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.715625, accuracy_baseline = 0.60625, auc = 0.7403657, auc_precision_recall = 0.6804854, average_loss = 0.5836128, global_step = 100, label/mean = 0.39375, loss = 0.5836128, precision = 0.739726, prediction/mean = 0.34897345, recall = 0.42857143
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpu27sw9ie/model.ckpt-100
accuracy : 0.715625
accuracy_baseline : 0.60625
auc : 0.7403657
auc_precision_recall : 0.6804854
average_loss : 0.5836128
label/mean : 0.39375
loss : 0.5836128
precision : 0.739726
prediction/mean : 0.34897345
recall : 0.42857143
global_step : 100
for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpu27sw9ie/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-0.73942876]
logistic : [0.32312906]
probabilities : [0.6768709 0.3231291]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']

Avantages des estimateurs préfabriqués

Les estimateurs prédéfinis codent les meilleures pratiques, offrant les avantages suivants:

  • Meilleures pratiques pour déterminer où les différentes parties du graphe de calcul doivent s'exécuter, implémenter des stratégies sur une seule machine ou sur un cluster.
  • Meilleures pratiques pour la rédaction d'événements (résumés) et résumés universellement utiles.

Si vous n'utilisez pas d'estimateurs prédéfinis, vous devez implémenter vous-même les fonctionnalités précédentes.

Estimateurs personnalisés

Le cœur de chaque Estimator, qu'il soit model_fn ou personnalisé, est sa fonction de modèle , model_fn , qui est une méthode qui crée des graphiques pour l'entraînement, l'évaluation et la prédiction. Lorsque vous utilisez un Estimator prédéfini, quelqu'un d'autre a déjà implémenté la fonction de modèle. Lorsque vous vous appuyez sur un Estimator personnalisé, vous devez écrire vous-même la fonction de modèle.

Créer un estimateur à partir d'un modèle Keras

Vous pouvez convertir des modèles Keras existants en Estimators avec tf.keras.estimator.model_to_estimator . Cela est utile si vous souhaitez moderniser votre code de modèle, mais que votre pipeline de formation nécessite toujours des estimateurs.

Instanciez un modèle Keras MobileNet V2 et compilez le modèle avec l'optimiseur, la perte et les métriques pour s'entraîner avec:

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

Créez un Estimator partir du modèle Keras compilé. L'état initial du modèle Keras est conservé dans l' Estimator créé:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpeaonpwe8
INFO:tensorflow:Using the Keras model provided.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/backend.py:434: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpeaonpwe8', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Traitez l' Estimator dérivé comme vous le feriez avec n'importe quel autre Estimator .

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

Pour vous entraîner, appelez la fonction train d'Estimator:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
Downloading and preparing dataset 786.68 MiB (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...
WARNING:absl:1738 images were corrupted and were skipped
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpeaonpwe8/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpeaonpwe8/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmpeaonpwe8/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmp/tmpeaonpwe8/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeaonpwe8/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeaonpwe8/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6884984, step = 0
INFO:tensorflow:loss = 0.6884984, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpeaonpwe8/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpeaonpwe8/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.67705643.
INFO:tensorflow:Loss for final step: 0.67705643.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f3d7c3822b0>

De même, pour évaluer, appelez la fonction d'évaluation de l'estimateur:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:2325: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`Model.state_updates` will be removed in a future version. '
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:32Z
INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:32Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpeaonpwe8/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmp/tmpeaonpwe8/model.ckpt-50
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 2.42050s
INFO:tensorflow:Inference Time : 2.42050s
INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:35
INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:35
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.515625, global_step = 50, loss = 0.6688157
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.515625, global_step = 50, loss = 0.6688157
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpeaonpwe8/model.ckpt-50
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpeaonpwe8/model.ckpt-50
{'accuracy': 0.515625, 'loss': 0.6688157, 'global_step': 50}

Pour plus de détails, reportez-vous à la documentation de tf.keras.estimator.model_to_estimator .

Enregistrement de points de contrôle basés sur des objets avec Estimator

Les estimateurs enregistrent par défaut les points de contrôle avec des noms de variables plutôt que le graphique d'objet décrit dans le guide des points de contrôle . tf.train.Checkpoint lira les points de contrôle basés sur les noms, mais les noms de variables peuvent changer lors du déplacement de parties d'un modèle en dehors de model_fn de l'estimateur. Pour la compatibilité ascendante, l'enregistrement des points de contrôle basés sur les objets facilite l'apprentissage d'un modèle dans un Estimator, puis son utilisation en dehors de celui-ci.

import tensorflow.compat.v1 as tf_compat
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.4040537, step = 0
INFO:tensorflow:loss = 4.4040537, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 35.247967.
INFO:tensorflow:Loss for final step: 35.247967.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f3d64534518>

tf.train.Checkpoint peut alors charger les points de contrôle de l'estimateur à partir de son model_dir .

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

Modèles enregistrés à partir des estimateurs

Les estimateurs exportent SavedModels via tf.Estimator.export_saved_model .

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpczwhe6jk
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpczwhe6jk
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpczwhe6jk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpczwhe6jk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpczwhe6jk/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpczwhe6jk/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpczwhe6jk/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpczwhe6jk/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.48830828.
INFO:tensorflow:Loss for final step: 0.48830828.
<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f3d6452eb00>

Pour enregistrer un Estimator vous devez créer un serving_input_receiver . Cette fonction crée une partie d'un tf.Graph qui analyse les données brutes reçues par SavedModel.

Le module tf.estimator.export contient des fonctions pour aider à construire ces receivers .

Le code suivant crée un récepteur, basé sur les feature_columns , qui accepte les tampons de protocole tf.Example sérialisés, qui sont souvent utilisés avec le service tf .

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmp/tmpczwhe6jk/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmp/tmpczwhe6jk/model.ckpt-50
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /tmp/tmp16t8uhub/from_estimator/temp-1610074656/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmp/tmp16t8uhub/from_estimator/temp-1610074656/saved_model.pb

Vous pouvez également charger et exécuter ce modèle, à partir de python:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.581246]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.32789052]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.418754, 0.581246]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>}
{'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.24376468]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.1321492]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7562353 , 0.24376468]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}

tf.estimator.export.build_raw_serving_input_receiver_fn vous permet de créer des fonctions d'entrée qui prennent des tenseurs bruts plutôt que tf.train.Example s.

Utilisation de tf.distribute.Strategy avec Estimator (prise en charge limitée)

tf.estimator est une API TensorFlow de formation distribuée qui prenait à l'origine en charge l'approche de serveur de paramètres async. tf.estimator prend désormais en charge tf.distribute.Strategy . Si vous utilisez tf.estimator , vous pouvez passer à la formation distribuée avec très peu de changements dans votre code. Grâce à cela, les utilisateurs d'Estimator peuvent désormais effectuer une formation distribuée synchrone sur plusieurs GPU et plusieurs travailleurs, ainsi qu'utiliser des TPU. Cette prise en charge dans Estimator est cependant limitée. Consultez la section Ce qui est pris en charge maintenant ci-dessous pour plus de détails.

L'utilisation de tf.distribute.Strategy avec Estimator est légèrement différente de celle de Keras. Au lieu d'utiliser strategy.scope , vous passez maintenant l'objet de stratégie dans RunConfig pour l'estimateur.

Vous pouvez vous référer au guide de formation distribué pour plus d'informations.

Voici un extrait de code qui montre cela avec un Estimator LinearRegressor et une MirroredStrategy LinearRegressor :

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp4uihzu_a
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp4uihzu_a
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp4uihzu_a', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp4uihzu_a', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

Ici, vous utilisez un Estimator prédéfini, mais le même code fonctionne également avec un Estimator personnalisé. train_distribute détermine comment la formation sera distribuée et eval_distribute détermine comment l'évaluation sera distribuée. C'est une autre différence par rapport à Keras où vous utilisez la même stratégie pour l'entraînement et l'évaluation.

Vous pouvez maintenant entraîner et évaluer cet estimateur avec une fonction d'entrée:

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp4uihzu_a/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp4uihzu_a/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmp4uihzu_a/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmp4uihzu_a/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:41Z
INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:41Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp4uihzu_a/model.ckpt-10
INFO:tensorflow:Restoring parameters from /tmp/tmp4uihzu_a/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.26266s
INFO:tensorflow:Inference Time : 0.26266s
INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:42
INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:42
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmp4uihzu_a/model.ckpt-10
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmp4uihzu_a/model.ckpt-10
{'average_loss': 1.4210855e-14,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994,
 'global_step': 10}

Une autre différence à souligner ici entre Estimator et Keras est la gestion des entrées. Dans Keras, chaque lot de l'ensemble de données est divisé automatiquement entre les multiples réplicas. Dans Estimator, cependant, vous n'effectuez pas de fractionnement automatique des lots, ni ne partitionnez automatiquement les données entre différents travailleurs. Vous avez un contrôle total sur la façon dont vous souhaitez que vos données soient distribuées entre les travailleurs et les périphériques, et vous devez fournir un input_fn pour spécifier comment distribuer vos données.

Votre input_fn est appelé une fois par travailleur, donnant ainsi un ensemble de données par travailleur. Ensuite, un lot de cet ensemble de données est transmis à un réplica sur ce worker, consommant ainsi N lots pour N réplicas sur 1 worker. En d'autres termes, l'ensemble de données renvoyé par input_fn doit fournir des lots de taille PER_REPLICA_BATCH_SIZE . Et la taille globale du lot pour une étape peut être obtenue sous la forme PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync .

Lorsque vous effectuez une formation multi-travailleurs, vous devez soit diviser vos données entre les nœuds de calcul, soit les mélanger avec une graine aléatoire sur chacun. Vous pouvez consulter un exemple de procédure dans le didacticiel Formation multi-travailleurs avec Estimator .

Et de la même manière, vous pouvez également utiliser des stratégies de serveurs multi-ouvriers et de paramètres. Le code reste le même, mais vous devez utiliser tf.estimator.train_and_evaluate et définir les variables d'environnement TF_CONFIG pour chaque binaire exécuté dans votre cluster.

Qu'est-ce qui est pris en charge maintenant?

La prise en charge de la formation avec Estimator en utilisant toutes les stratégies à l'exception de TPUStrategy . La formation et l'évaluation de base devraient fonctionner, mais un certain nombre de fonctionnalités avancées telles que v1.train.Scaffold ne le font pas. Il peut également y avoir un certain nombre de bogues dans cette intégration et il n'est pas prévu d'améliorer activement cette prise en charge (l'accent est mis sur Keras et la prise en charge de la boucle d'entraînement personnalisée). Dans la mesure du possible, vous devriez préférer utiliser tf.distribute avec ces API à la place.

API de formation MiroirStratégie TPUStratégie MultiWorkerMirroredStrategy CentralStorageStrategy ParameterServerStrategy
API Estimator Assistance limitée Non supporté Assistance limitée Assistance limitée Assistance limitée

Exemples et tutoriels

Voici quelques exemples de bout en bout qui montrent comment utiliser diverses stratégies avec Estimator:

  1. Le didacticiel Formation multi-travailleurs avec Estimator montre comment vous pouvez vous entraîner avec plusieurs travailleurs à l'aide de MultiWorkerMirroredStrategy sur l'ensemble de données MNIST.
  2. Un exemple de bout en bout d' exécution d'une formation multi-travailleurs avec des stratégies de distribution dans tensorflow/ecosystem aide de modèles Kubernetes. Il commence par un modèle Keras et le convertit en Estimator à l'aide de l'API tf.keras.estimator.model_to_estimator .
  3. Le modèle officiel ResNet50 , qui peut être formé en utilisant MirroredStrategy ou MultiWorkerMirroredStrategy .