Migrer d'Estimator vers les API Keras

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Ce guide explique comment migrer des API tf.keras de TensorFlow 1 vers les API tf.estimator.Estimator de TensorFlow 2. Tout d'abord, vous allez configurer et exécuter un modèle de base pour la formation et l'évaluation avec tf.estimator.Estimator . Ensuite, vous effectuerez les étapes équivalentes dans TensorFlow 2 avec les API tf.keras . Vous apprendrez également à personnaliser l'étape d'entraînement en sous- tf.keras.Model et en utilisant tf.GradientTape .

  • Dans TensorFlow 1, les API tf.estimator.Estimator de haut niveau vous permettent d'entraîner et d'évaluer un modèle, ainsi que d'effectuer des inférences et d'enregistrer votre modèle (pour la diffusion).
  • Dans TensorFlow 2, utilisez les API Keras pour effectuer les tâches susmentionnées, telles que la création de modèles , l'application de gradient, la formation , l'évaluation et la prédiction.

(Pour migrer des workflows d'enregistrement de modèles/points de contrôle vers TensorFlow 2, consultez les guides de migration SavedModel et Checkpoint .)

Installer

Commencez par des importations et un jeu de données simple :

import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]

TensorFlow 1 : entraîner et évaluer avec tf.estimator.Estimator

Cet exemple montre comment effectuer l'entraînement et l'évaluation avec tf.estimator.Estimator dans TensorFlow 1.

Commencez par définir quelques fonctions : une fonction d'entrée pour les données d'entraînement, une fonction d'entrée d'évaluation pour les données d'évaluation et une fonction de modèle qui indique à l' Estimator comment l'opération d'entraînement est définie avec les caractéristiques et les étiquettes :

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

Instanciez votre Estimator et entraînez le modèle :

estimator = tf1.estimator.Estimator(model_fn=_model_fn)
estimator.train(_input_fn)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpeovq622_
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpeovq622_', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeovq622_/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.0834494, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmp/tmpeovq622_/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Loss for final step: 9.88002.
<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7fbd06673350>

Évaluez le programme avec l'ensemble d'évaluation :

estimator.evaluate(_eval_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-10-26T01:32:58
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpeovq622_/model.ckpt-3
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 0.10194s
INFO:tensorflow:Finished evaluation at 2021-10-26-01:32:58
INFO:tensorflow:Saving dict for global step 3: global_step = 3, loss = 20.543152
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmp/tmpeovq622_/model.ckpt-3
{'loss': 20.543152, 'global_step': 3}

TensorFlow 2 : entraîner et évaluer avec les méthodes Keras intégrées

Cet exemple montre comment effectuer une formation et une évaluation avec Keras Model.fit et Model.evaluate dans TensorFlow 2. (Vous pouvez en savoir plus dans le guide Formation et évaluation avec le guide des méthodes intégrées .)

  • Commencez par préparer le pipeline de l'ensemble de données avec les API tf.data.Dataset .
  • Définissez un modèle Keras Sequential simple avec une couche linéaire ( tf.keras.layers.Dense ).
  • Instanciez un optimiseur Adagrad ( tf.keras.optimizers.Adagrad ).
  • Configurez le modèle pour l'entraînement en transmettant la variable d' optimizer et la perte d'erreur quadratique moyenne ( "mse" ) à Model.compile .
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)

model.compile(optimizer=optimizer, loss="mse")

Avec cela, vous êtes prêt à former le modèle en appelant Model.fit :

model.fit(dataset)
3/3 [==============================] - 0s 2ms/step - loss: 0.2785
<keras.callbacks.History at 0x7fbc4b320350>

Enfin, évaluez le modèle avec Model.evaluate :

model.evaluate(eval_dataset, return_dict=True)
3/3 [==============================] - 0s 1ms/step - loss: 0.0451
{'loss': 0.04510306194424629}

TensorFlow 2 : entraîner et évaluer avec une étape d'entraînement personnalisée et des méthodes Keras intégrées

Dans TensorFlow 2, vous pouvez également écrire votre propre fonction d'étape d'entraînement personnalisée avec tf.GradientTape pour effectuer des passes avant et arrière, tout en profitant du support d'entraînement intégré, tel que tf.keras.callbacks.Callback et tf.distribute.Strategy . (En savoir plus dans Personnalisation de ce qui se passe dans Model.fit et Rédaction de boucles d'entraînement personnalisées à partir de zéro .)

Dans cet exemple, commencez par créer un tf.keras.Model personnalisé en sous- tf.keras.Sequential qui remplace Model.train_step . (En savoir plus sur la sous-classe de tf.keras.Model ). Dans cette classe, définissez une fonction train_step personnalisée qui, pour chaque lot de données, effectue une passe avant et une passe arrière au cours d'une étape d'apprentissage.

class CustomModel(tf.keras.Sequential):
  """A custom sequential model that overrides `Model.train_step`."""

  def train_step(self, data):
    batch_data, labels = data

    with tf.GradientTape() as tape:
      predictions = self(batch_data, training=True)
      # Compute the loss value (the loss function is configured
      # in `Model.compile`).
      loss = self.compiled_loss(labels, predictions)

    # Compute the gradients of the parameters with respect to the loss.
    gradients = tape.gradient(loss, self.trainable_variables)
    # Perform gradient descent by updating the weights/parameters.
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    # Update the metrics (includes the metric that tracks the loss).
    self.compiled_metrics.update_state(labels, predictions)
    # Return a dict mapping metric names to the current values.
    return {m.name: m.result() for m in self.metrics}

Ensuite, comme avant :

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

model = CustomModel([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)

model.compile(optimizer=optimizer, loss="mse")

Appelez Model.fit pour entraîner le modèle :

model.fit(dataset)
3/3 [==============================] - 0s 2ms/step - loss: 0.0587
<keras.callbacks.History at 0x7fbc3873f1d0>

Et, enfin, évaluez le programme avec Model.evaluate :

model.evaluate(eval_dataset, return_dict=True)
3/3 [==============================] - 0s 2ms/step - loss: 0.0197
{'loss': 0.019738242030143738}

Prochaines étapes

Ressources Keras supplémentaires que vous pourriez trouver utiles :

Les guides suivants peuvent vous aider à migrer les workflows de stratégie de distribution à partir des API tf.estimator :