Cette page a été traduite par l'API Cloud Translation.
Switch to English

Estimateurs

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher la source sur GitHub Télécharger le carnet

Ce document présente tf.estimator une API TensorFlow de haut niveau. Les estimateurs encapsulent les actions suivantes:

  • entraînement
  • évaluation
  • prédiction
  • exporter pour servir

TensorFlow implémente plusieurs estimateurs prédéfinis. Les estimateurs personnalisés sont toujours pris en charge, mais principalement comme mesure de compatibilité ascendante. Les estimateurs personnalisés ne doivent pas être utilisés pour le nouveau code . Tous les Estimators, qu'ils soient tf.estimator.Estimator ou personnalisés, sont des classes basées sur la classe tf.estimator.Estimator .

Pour un exemple rapide, essayez les didacticiels Estimator . Pour un aperçu de la conception de l'API, consultez le livre blanc .

Installer

 pip install -q -U tensorflow_datasets
import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

Avantages

Semblable à un tf.keras.Model , un estimator est une abstraction au niveau du modèle. Le tf.estimator fournit des fonctionnalités actuellement encore en développement pour tf.keras . Ceux-ci sont:

  • Formation basée sur le serveur de paramètres
  • Intégration TFX complète.

Capacités des estimateurs

Les estimateurs offrent les avantages suivants:

  • Vous pouvez exécuter des modèles basés sur Estimator sur un hôte local ou sur un environnement multi-serveur distribué sans modifier votre modèle. De plus, vous pouvez exécuter des modèles basés sur Estimator sur des processeurs, des GPU ou des TPU sans recoder votre modèle.
  • Les estimateurs fournissent une boucle de formation distribuée sûre qui contrôle comment et quand:
    • charger des données
    • gérer les exceptions
    • créer des fichiers de point de contrôle et récupérer des échecs
    • enregistrer les résumés pour TensorBoard

Lors de l'écriture d'une application avec Estimators, vous devez séparer le pipeline d'entrée de données du modèle. Cette séparation simplifie les expériences avec différents ensembles de données.

Utilisation d'estimateurs prédéfinis

Les estimateurs prédéfinis vous permettent de travailler à un niveau conceptuel beaucoup plus élevé que les API TensorFlow de base. Vous n'avez plus à vous soucier de la création du graphe ou des sessions de calcul puisque les estimateurs gèrent toute la «plomberie» pour vous. De plus, les Estimators prédéfinis vous permettent d'expérimenter différentes architectures de modèle en n'apportant que des modifications minimes au code. tf.estimator.DNNClassifier , par exemple, est une classe Estimator prédéfinie qui entraîne des modèles de classification basés sur des réseaux de neurones denses et à réaction directe.

Un programme TensorFlow reposant sur un Estimator prédéfini comprend généralement les quatre étapes suivantes:

1. Ecrire une fonction d'entrée

Par exemple, vous pouvez créer une fonction pour importer l'ensemble d'apprentissage et une autre fonction pour importer l'ensemble de test. Les estimateurs s'attendent à ce que leurs entrées soient formatées sous forme de paire d'objets:

  • Un dictionnaire dans lequel les clés sont des noms de caractéristiques et les valeurs sont des Tensors (ou SparseTensors) contenant les données de caractéristiques correspondantes
  • Un Tensor contenant une ou plusieurs étiquettes

Le input_fn doit renvoyer un tf.data.Dataset qui produit des paires dans ce format.

Par exemple, le code suivant construit un tf.data.Dataset de l'ensemble de données Titanic train.csv fichier:

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.experimental.AUTOTUNE))
  return titanic_batches

Le input_fn est exécuté dans un tf.Graph et peut aussi renvoyer directement une paire (features_dics, labels) contenant des tenseurs de graphe, mais ceci est sujet aux erreurs en dehors des cas simples comme le retour de constantes.

2. Définissez les colonnes de caractéristiques.

Chaque tf.feature_column identifie un nom de fonction, son type et tout prétraitement d'entrée.

Par exemple, l'extrait de code suivant crée trois colonnes de fonctionnalités.

  • Le premier utilise la fonction d' age directement comme entrée à virgule flottante.
  • Le second utilise la fonction de class comme entrée catégorielle.
  • Le troisième utilise embark_town comme entrée catégorielle, mais utilise l' hashing trick pour éviter d'avoir à énumérer les options et de définir le nombre d'options.

Pour plus d'informations, consultez le didacticiel sur les colonnes de fonctionnalités .

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. Instanciez l'estimateur prédéfini pertinent.

Par exemple, voici un exemple d'instanciation d'un Estimator LinearClassifier nommé LinearClassifier :

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp_2fgw1gd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Pour plus d'informations, consultez le didacticiel du classificateur linéaire .

4. Appelez une méthode de formation, d'évaluation ou d'inférence.

Tous les estimateurs fournissent des méthodes de train , d' evaluate et de predict .

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:1481: Layer.add_variable (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:112: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp_2fgw1gd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmp_2fgw1gd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.6098593.

result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-10-15T01:25:18Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp_2fgw1gd/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.63935s
INFO:tensorflow:Finished evaluation at 2020-10-15-01:25:19
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.7, accuracy_baseline = 0.603125, auc = 0.70968133, auc_precision_recall = 0.6162292, average_loss = 0.6068252, global_step = 100, label/mean = 0.396875, loss = 0.6068252, precision = 0.6962025, prediction/mean = 0.3867289, recall = 0.43307087
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmp_2fgw1gd/model.ckpt-100
accuracy : 0.7
accuracy_baseline : 0.603125
auc : 0.70968133
auc_precision_recall : 0.6162292
average_loss : 0.6068252
label/mean : 0.396875
loss : 0.6068252
precision : 0.6962025
prediction/mean : 0.3867289
recall : 0.43307087
global_step : 100

for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp_2fgw1gd/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [0.6824188]
logistic : [0.6642783]
probabilities : [0.33572164 0.6642783 ]
class_ids : [1]
classes : [b'1']
all_class_ids : [0 1]
all_classes : [b'0' b'1']

Avantages des estimateurs préfabriqués

Les estimateurs prédéfinis codent les meilleures pratiques, offrant les avantages suivants:

  • Meilleures pratiques pour déterminer où les différentes parties du graphe de calcul doivent s'exécuter, implémenter des stratégies sur une seule machine ou sur un cluster.
  • Meilleures pratiques pour la rédaction d'événements (résumés) et résumés universellement utiles.

Si vous n'utilisez pas d'estimateurs prédéfinis, vous devez implémenter vous-même les fonctionnalités précédentes.

Estimateurs personnalisés

Le cœur de chaque Estimator, qu'il soit model_fn ou personnalisé, est sa fonction de modèle , model_fn , qui est une méthode qui crée des graphiques pour l'entraînement, l'évaluation et la prédiction. Lorsque vous utilisez un Estimator prédéfini, quelqu'un d'autre a déjà implémenté la fonction de modèle. Lorsque vous vous appuyez sur un Estimator personnalisé, vous devez écrire vous-même la fonction de modèle.

Créer un estimateur à partir d'un modèle Keras

Vous pouvez convertir des modèles Keras existants en Estimators avec tf.keras.estimator.model_to_estimator . Cela est utile si vous souhaitez moderniser le code de votre modèle, mais que votre pipeline de formation nécessite toujours des estimateurs.

Instanciez un modèle Keras MobileNet V2 et compilez le modèle avec l'optimiseur, la perte et les métriques pour s'entraîner avec:

import tensorflow comme tf importation tensorflow_datasets comme tfds

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

Créez un Estimator partir du modèle Keras compilé. L'état initial du modèle du modèle Keras est conservé dans l' Estimator créé:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpy7tafocp
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py:220: set_learning_phase (from tensorflow.python.keras.backend) is deprecated and will be removed after 2020-10-11.
Instructions for updating:
Simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpy7tafocp', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Traitez l' Estimator dérivé comme vous le feriez avec n'importe quel autre Estimator .

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

Pour vous entraîner, appelez la fonction train d'Estimator:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

Warning:absl:1738 images were corrupted and were skipped

Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteMUGW8X/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpy7tafocp/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpy7tafocp/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting from: /tmp/tmpy7tafocp/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting from: /tmp/tmpy7tafocp/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-started 158 variables.

INFO:tensorflow:Warm-started 158 variables.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpy7tafocp/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpy7tafocp/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 0.68286127, step = 0

INFO:tensorflow:loss = 0.68286127, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpy7tafocp/model.ckpt.

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpy7tafocp/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Loss for final step: 0.70231926.

INFO:tensorflow:Loss for final step: 0.70231926.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f1a6c4d4cf8>

De même, pour évaluer, appelez la fonction d'évaluation de l'estimateur:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Starting evaluation at 2020-10-15T01:26:26Z

INFO:tensorflow:Starting evaluation at 2020-10-15T01:26:26Z

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Restoring parameters from /tmp/tmpy7tafocp/model.ckpt-50

INFO:tensorflow:Restoring parameters from /tmp/tmpy7tafocp/model.ckpt-50

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Inference Time : 1.92025s

INFO:tensorflow:Inference Time : 1.92025s

INFO:tensorflow:Finished evaluation at 2020-10-15-01:26:28

INFO:tensorflow:Finished evaluation at 2020-10-15-01:26:28

INFO:tensorflow:Saving dict for global step 50: accuracy = 0.565625, global_step = 50, loss = 0.6713216

INFO:tensorflow:Saving dict for global step 50: accuracy = 0.565625, global_step = 50, loss = 0.6713216

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpy7tafocp/model.ckpt-50

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpy7tafocp/model.ckpt-50

{'accuracy': 0.565625, 'loss': 0.6713216, 'global_step': 50}

Pour plus de détails, veuillez vous référer à la documentation de tf.keras.estimator.model_to_estimator .

Enregistrement de points de contrôle basés sur des objets avec Estimator

Les estimateurs enregistrent par défaut les points de contrôle avec des noms de variables plutôt que le graphique d'objet décrit dans le guide des points de contrôle . tf.train.Checkpoint lira les points de contrôle basés sur les noms, mais les noms de variables peuvent changer lors du déplacement de parties d'un modèle en dehors de model_fn de l'estimateur. Pour la compatibilité ascendante, la sauvegarde des points de contrôle basés sur des objets facilite l'apprentissage d'un modèle dans un Estimator, puis son utilisation en dehors de celui-ci.

import tensorflow.compat.v1 as tf_compat
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.

INFO:tensorflow:Using default config.

INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 4.5633583, step = 0

INFO:tensorflow:loss = 4.5633583, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...

INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.

INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...

INFO:tensorflow:Loss for final step: 37.95615.

INFO:tensorflow:Loss for final step: 37.95615.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f1a6c477630>

tf.train.Checkpoint peut alors charger les points de contrôle de l'estimateur à partir de son model_dir .

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

Modèles enregistrés des estimateurs

Les estimateurs exportent SavedModels via tf.Estimator.export_saved_model .

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.

INFO:tensorflow:Using default config.

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmpn_8rzqza

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmpn_8rzqza

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpn_8rzqza', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpn_8rzqza', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpn_8rzqza/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpn_8rzqza/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 0.6931472, step = 0

INFO:tensorflow:loss = 0.6931472, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpn_8rzqza/model.ckpt.

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpn_8rzqza/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Loss for final step: 0.35876164.

INFO:tensorflow:Loss for final step: 0.35876164.

<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f1a6c448b00>

Pour enregistrer un Estimator vous devez créer un serving_input_receiver . Cette fonction crée une partie d'un tf.Graph qui analyse les données brutes reçues par SavedModel.

Le module tf.estimator.export contient des fonctions pour aider à construire ces receivers .

Le code suivant crée un récepteur, basé sur les feature_columns , qui accepte les tampons de protocole tf.Example sérialisés, qui sont souvent utilisés avec le service tf .

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']

INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']

INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']

INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Restoring parameters from /tmp/tmpn_8rzqza/model.ckpt-50

INFO:tensorflow:Restoring parameters from /tmp/tmpn_8rzqza/model.ckpt-50

INFO:tensorflow:Assets added to graph.

INFO:tensorflow:Assets added to graph.

INFO:tensorflow:No assets to write.

INFO:tensorflow:No assets to write.

INFO:tensorflow:SavedModel written to: /tmp/tmptcppevt7/from_estimator/temp-1602725189/saved_model.pb

INFO:tensorflow:SavedModel written to: /tmp/tmptcppevt7/from_estimator/temp-1602725189/saved_model.pb

Vous pouvez également charger et exécuter ce modèle, à partir de python:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.43590346, 0.5640965 ]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2578045]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.5640965]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>}
{'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7984398 , 0.20156018]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.3765715]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2015602]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}

tf.estimator.export.build_raw_serving_input_receiver_fn vous permet de créer des fonctions d'entrée qui prennent des tenseurs bruts plutôt que tf.train.Example s.

Utilisation de tf.distribute.Strategy avec Estimator (prise en charge limitée)

Consultez le guide de formation distribuée pour plus d'informations.

tf.estimator est une API TensorFlow de formation distribuée qui prenait à l'origine en charge l'approche de serveur de paramètres async. tf.estimator prend désormais en charge tf.distribute.Strategy . Si vous utilisez tf.estimator , vous pouvez passer à la formation distribuée avec très peu de changements dans votre code. Grâce à cela, les utilisateurs d'Estimator peuvent désormais effectuer une formation distribuée synchrone sur plusieurs GPU et plusieurs travailleurs, ainsi qu'utiliser des TPU. Cette prise en charge dans Estimator est cependant limitée. Voir la section Ce qui est pris en charge maintenant ci-dessous pour plus de détails.

L'utilisation de tf.distribute.Strategy avec Estimator est légèrement différente de celle du cas Keras. Au lieu d'utiliser strategy.scope , nous passons maintenant l'objet de stratégie dans le RunConfig pour l'estimateur.

Voici un extrait de code qui montre cela avec un Estimator LinearRegressor et MirroredStrategy LinearRegressor :

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

INFO:tensorflow:Initializing RunConfig with distribution strategies.

INFO:tensorflow:Initializing RunConfig with distribution strategies.

INFO:tensorflow:Not using Distribute Coordinator.

INFO:tensorflow:Not using Distribute Coordinator.

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmphb70j0wf

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmphb70j0wf

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmphb70j0wf', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f1b28cfd4e0>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f1b28cfd4e0>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmphb70j0wf', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f1b28cfd4e0>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f1b28cfd4e0>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

Nous utilisons ici un Estimator prédéfini, mais le même code fonctionne également avec un Estimator personnalisé. train_distribute détermine comment la formation sera distribuée, et eval_distribute détermine comment l'évaluation sera distribuée. C'est une autre différence par rapport à Keras où nous utilisons la même stratégie pour l'entraînement et l'évaluation.

Nous pouvons maintenant former et évaluer cet Estimator avec une fonction d'entrée:

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:339: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:339: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

Warning:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1b28ec8b70> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

Warning:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1b28ec8b70> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

Warning: AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1b28ec8b70> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmphb70j0wf/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmphb70j0wf/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 1.0, step = 0

INFO:tensorflow:loss = 1.0, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...

INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmphb70j0wf/model.ckpt.

INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmphb70j0wf/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...

INFO:tensorflow:Loss for final step: 2.877698e-13.

INFO:tensorflow:Loss for final step: 2.877698e-13.

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

Warning:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1ad8062d08> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

Warning:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1ad8062d08> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

Warning: AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1ad8062d08> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
INFO:tensorflow:Starting evaluation at 2020-10-15T01:26:34Z

INFO:tensorflow:Starting evaluation at 2020-10-15T01:26:34Z

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Restoring parameters from /tmp/tmphb70j0wf/model.ckpt-10

INFO:tensorflow:Restoring parameters from /tmp/tmphb70j0wf/model.ckpt-10

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Inference Time : 0.23888s

INFO:tensorflow:Inference Time : 0.23888s

INFO:tensorflow:Finished evaluation at 2020-10-15-01:26:34

INFO:tensorflow:Finished evaluation at 2020-10-15-01:26:34

INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994

INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmphb70j0wf/model.ckpt-10

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmphb70j0wf/model.ckpt-10

{'average_loss': 1.4210855e-14,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994,
 'global_step': 10}

Une autre différence à souligner ici entre Estimator et Keras est la gestion des entrées. Dans Keras, nous avons mentionné que chaque lot de l'ensemble de données est divisé automatiquement entre les multiples réplicas. Dans Estimator, cependant, nous n'effectuons pas de fractionnement automatique du lot, ni de partitionnement automatique des données entre différents travailleurs. Vous avez un contrôle total sur la façon dont vous souhaitez que vos données soient distribuées entre les travailleurs et les périphériques, et vous devez fournir un input_fn pour spécifier comment distribuer vos données.

Votre input_fn est appelé une fois par travailleur, donnant ainsi un ensemble de données par travailleur. Ensuite, un lot de cet ensemble de données est transmis à un réplica sur ce worker, consommant ainsi N lots pour N réplicas sur 1 worker. En d'autres termes, l'ensemble de données renvoyé par input_fn doit fournir des lots de taille PER_REPLICA_BATCH_SIZE . Et la taille globale du lot pour une étape peut être obtenue sous la forme PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync .

Lorsque vous effectuez une formation multi-ouvriers, vous devez soit diviser vos données entre les travailleurs, soit mélanger avec une graine aléatoire sur chacun. Vous pouvez voir un exemple de la façon de procéder dans la formation multi-ouvriers avec Estimator .

Et de la même manière, vous pouvez également utiliser des stratégies de serveurs multi-ouvriers et de paramètres. Le code reste le même, mais vous devez utiliser tf.estimator.train_and_evaluate et définir les variables d'environnement TF_CONFIG pour chaque binaire exécuté dans votre cluster.

Qu'est-ce qui est pris en charge maintenant?

La prise en charge de la formation avec Estimator en utilisant toutes les stratégies à l'exception de TPUStrategy . La formation et l'évaluation de base devraient fonctionner, mais un certain nombre de fonctionnalités avancées telles que v1.train.Scaffold ne le font pas. Il peut également y avoir un certain nombre de bogues dans cette intégration. Pour le moment, nous ne prévoyons pas d'améliorer activement cette prise en charge, mais nous nous concentrons plutôt sur Keras et la prise en charge de la boucle d'entraînement personnalisée. Dans la mesure du possible, vous devriez préférer utiliser tf.distribute avec ces API à la place.

API de formation MiroirStratégie TPUStratégie MultiWorkerMirroredStrategy CentralStorageStrategy ParameterServerStrategy
API Estimator Assistance limitée Non supporté Assistance limitée Assistance limitée Assistance limitée

Exemples et tutoriels

Voici quelques exemples qui montrent l'utilisation de bout en bout de diverses stratégies avec Estimator:

  1. Formation multi-travailleurs avec Estimator pour former MNIST avec plusieurs travailleurs à l'aide de MultiWorkerMirroredStrategy .
  2. Exemple de bout en bout de formation multi-ouvriers dans tensorflow / écosystème à l'aide de modèles Kubernetes. Cet exemple commence par un modèle Keras et le convertit en Estimator à l'aide de l'API tf.keras.estimator.model_to_estimator .
  3. Modèle officiel ResNet50 , qui peut être formé à l'aide de MirroredStrategy ou MultiWorkerMirroredStrategy .