Tensorflow 2 efficace

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher sur GitHub Télécharger le cahier

Aperçu

Ce guide fournit une liste des meilleures pratiques pour écrire du code à l'aide de TensorFlow 2 (TF2). Il est destiné aux utilisateurs qui ont récemment basculé depuis TensorFlow 1 (TF1). Reportez-vous à la section migration du guide pour plus d'informations sur la migration de votre code TF1 vers TF2.

Installer

Importez TensorFlow et d'autres dépendances pour les exemples de ce guide.

import tensorflow as tf
import tensorflow_datasets as tfds

Recommandations pour TensorFlow 2 idiomatique

Refactorisez votre code en modules plus petits

Une bonne pratique consiste à refactoriser votre code en fonctions plus petites qui sont appelées selon les besoins. Pour de meilleures performances, vous devez essayer de décorer les plus grands blocs de calcul que vous pouvez dans un tf.function (notez que les fonctions python imbriquées appelées par un tf.function ne nécessitent pas leurs propres décorations séparées, sauf si vous souhaitez utiliser différents jit_compile paramètres de la tf.function .). Selon votre cas d'utilisation, il peut s'agir de plusieurs étapes d'entraînement ou même de toute votre boucle d'entraînement. Pour les cas d'utilisation d'inférence, il peut s'agir d'une seule passe avant de modèle.

Ajuster le taux d'apprentissage par défaut pour certains tf.keras.optimizer s

Certains optimiseurs Keras ont des taux d'apprentissage différents dans TF2. Si vous constatez un changement dans le comportement de convergence de vos modèles, vérifiez les taux d'apprentissage par défaut.

Il n'y a aucun changement pour les optimizers.SGD , les optimizers.Adam ou les optimizers.RMSprop .

Les taux d'apprentissage par défaut suivants ont changé :

Utiliser les tf.Module s et Keras pour gérer les variables

tf.Module s et tf.keras.layers.Layer s offrent les variables pratiques et les propriétés trainable_variables , qui rassemblent de manière récursive toutes les variables dépendantes. Cela facilite la gestion des variables localement là où elles sont utilisées.

Les couches/modèles Keras héritent de tf.train.Checkpointable et sont intégrés à @tf.function , ce qui permet de contrôler directement ou d'exporter des SavedModels à partir d'objets Keras. Vous n'avez pas nécessairement besoin d'utiliser l'API Model.fit de Keras pour tirer parti de ces intégrations.

Lisez la section sur l'apprentissage par transfert et le réglage fin du guide Keras pour savoir comment collecter un sous-ensemble de variables pertinentes à l'aide de Keras.

Combinez tf.data.Dataset s et tf.function

Le package TensorFlow Datasets ( tfds ) contient des utilitaires permettant de charger des ensembles de données prédéfinis en tant tf.data.Dataset . Pour cet exemple, vous pouvez charger le jeu de données MNIST à l'aide tfds :

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

Préparez ensuite les données pour la formation :

  • Redimensionnez chaque image.
  • Mélangez l'ordre des exemples.
  • Collectez des lots d'images et d'étiquettes.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Pour que l'exemple reste court, découpez l'ensemble de données pour ne renvoyer que 5 lots :

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2021-12-08 17:15:01.637157: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Utilisez l'itération Python régulière pour itérer sur les données d'entraînement qui tiennent dans la mémoire. Sinon, tf.data.Dataset est le meilleur moyen de diffuser des données d'entraînement à partir du disque. Les ensembles de données sont des itérables (et non des itérateurs) et fonctionnent comme les autres itérables Python dans une exécution hâtive. Vous pouvez utiliser pleinement les fonctionnalités de prélecture/diffusion asynchrones des ensembles de données en enveloppant votre code dans tf.function , qui remplace l'itération Python par les opérations de graphe équivalentes à l'aide d'AutoGraph.

@tf.function
def train(model, dataset, optimizer):
  for x, y in dataset:
    with tf.GradientTape() as tape:
      # training=True is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      prediction = model(x, training=True)
      loss = loss_fn(prediction, y)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Si vous utilisez l'API Keras Model.fit , vous n'aurez pas à vous soucier de l'itération de l'ensemble de données.

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)

Utiliser les boucles d'entraînement Keras

Si vous n'avez pas besoin d'un contrôle de bas niveau de votre processus d'entraînement, il est recommandé d'utiliser les méthodes d' fit , d' evaluate et de predict intégrées de Keras. Ces méthodes fournissent une interface uniforme pour former le modèle quelle que soit l'implémentation (séquentielle, fonctionnelle ou sous-classée).

Les avantages de ces méthodes incluent :

  • Ils acceptent les tableaux Numpy, les générateurs Python et tf.data.Datasets .
  • Ils appliquent automatiquement la régularisation et les pertes d'activation.
  • Ils prennent en charge tf.distribute où le code de formation reste le même quelle que soit la configuration matérielle .
  • Ils prennent en charge les callables arbitraires comme les pertes et les métriques.
  • Ils prennent en charge les rappels tels que tf.keras.callbacks.TensorBoard et les rappels personnalisés.
  • Ils sont performants, utilisant automatiquement les graphes TensorFlow.

Voici un exemple d'entraînement d'un modèle à l'aide d'un Dataset . Pour plus de détails sur la façon dont cela fonctionne, consultez les didacticiels .

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 9s 7ms/step - loss: 1.5762 - accuracy: 0.4938
Epoch 2/5
2021-12-08 17:15:11.145429: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 6ms/step - loss: 0.5087 - accuracy: 0.8969
Epoch 3/5
2021-12-08 17:15:11.559374: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 2s 5ms/step - loss: 0.3348 - accuracy: 0.9469
Epoch 4/5
2021-12-08 17:15:13.860407: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.2445 - accuracy: 0.9688
Epoch 5/5
2021-12-08 17:15:14.269850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 6ms/step - loss: 0.2006 - accuracy: 0.9719
2021-12-08 17:15:14.717552: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 1s 4ms/step - loss: 1.4553 - accuracy: 0.5781
Loss 1.4552843570709229, Accuracy 0.578125
2021-12-08 17:15:15.862684: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Personnalisez la formation et écrivez votre propre boucle

Si les modèles Keras fonctionnent pour vous, mais que vous avez besoin de plus de flexibilité et de contrôle de l'étape d'entraînement ou des boucles d'entraînement externes, vous pouvez mettre en œuvre vos propres étapes d'entraînement ou même des boucles d'entraînement entières. Consultez le guide Keras sur la personnalisation de l' fit pour en savoir plus.

Vous pouvez également implémenter de nombreuses choses en tant que tf.keras.callbacks.Callback .

Cette méthode présente de nombreux avantages mentionnés précédemment , mais vous donne le contrôle de l'étape du train et même de la boucle extérieure.

Une boucle d'entraînement standard comporte trois étapes :

  1. Itérez sur un générateur Python ou tf.data.Dataset pour obtenir des lots d'exemples.
  2. Utilisez tf.GradientTape pour collecter les dégradés.
  3. Utilisez l'un des tf.keras.optimizers pour appliquer des mises à jour de poids aux variables du modèle.

Rappelles toi:

  • Incluez toujours un argument de training sur la méthode d' call des couches et modèles sous-classés.
  • Assurez-vous d'appeler le modèle avec l'argument d' training défini correctement.
  • Selon l'utilisation, les variables de modèle peuvent ne pas exister tant que le modèle n'est pas exécuté sur un lot de données.
  • Vous devez gérer manuellement des choses comme les pertes de régularisation pour le modèle.

Il n'est pas nécessaire d'exécuter des initialiseurs de variables ou d'ajouter des dépendances de contrôle manuel. tf.function gère pour vous les dépendances de contrôle automatique et l'initialisation des variables lors de la création.

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
2021-12-08 17:15:16.714849: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 0
2021-12-08 17:15:17.097043: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 1
2021-12-08 17:15:17.502480: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 2
2021-12-08 17:15:17.873701: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 3
Finished epoch 4
2021-12-08 17:15:18.344196: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Tirez parti de tf.function avec le flux de contrôle Python

tf.function fournit un moyen de convertir le flux de contrôle dépendant des données en équivalents en mode graphique comme tf.cond et tf.while_loop .

Un endroit commun où le flux de contrôle dépendant des données apparaît est dans les modèles de séquence. tf.keras.layers.RNN enveloppe une cellule RNN, vous permettant de dérouler la récurrence de manière statique ou dynamique. Par exemple, vous pouvez réimplémenter le déroulement dynamique comme suit.

class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super(DynamicRNN, self).__init__(self)
    self.cell = rnn_cell

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
  def call(self, input_data):

    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    timesteps =  tf.shape(input_data)[0]
    batch_size = tf.shape(input_data)[1]
    outputs = tf.TensorArray(tf.float32, timesteps)
    state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
    for i in tf.range(timesteps):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)

my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)

Lisez le guide des tf.function pour plus d'informations.

Métriques et pertes de style nouveau

Les métriques et les pertes sont à la fois des objets qui fonctionnent avec impatience et dans tf.function s.

Un objet loss est appelable et attend ( y_true , y_pred ) comme arguments :

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815

Utiliser des métriques pour collecter et afficher des données

Vous pouvez utiliser tf.metrics pour agréger les données et tf.summary pour consigner les résumés et les rediriger vers un rédacteur à l'aide d'un gestionnaire de contexte. Les résumés sont émis directement au rédacteur, ce qui signifie que vous devez fournir la valeur du step au site d'appel.

summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
  tf.summary.scalar('loss', 0.1, step=42)

Utilisez tf.metrics pour agréger les données avant de les enregistrer sous forme de résumés. Les métriques sont avec état ; ils accumulent des valeurs et renvoient un résultat cumulé lorsque vous appelez la méthode result (comme Mean.result ). Effacez les valeurs accumulées avec Model.reset_states .

def train(model, optimizer, dataset, log_freq=10):
  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss.update_state(loss)
    if tf.equal(optimizer.iterations % log_freq, 0):
      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      avg_loss.reset_states()

def test(model, test_x, test_y, step_num):
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  loss = loss_fn(model(test_x, training=False), test_y)
  tf.summary.scalar('loss', loss, step=step_num)

train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')

with train_summary_writer.as_default():
  train(model, optimizer, dataset)

with test_summary_writer.as_default():
  test(model, test_x, test_y, optimizer.iterations)

Visualisez les résumés générés en faisant pointer TensorBoard vers le répertoire du journal des résumés :

tensorboard --logdir /tmp/summaries

Utilisez l'API tf.summary pour écrire des données récapitulatives à visualiser dans TensorBoard. Pour plus d'informations, lisez le guide tf.summary .

# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
2021-12-08 17:15:19.339736: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  0
  loss:     0.142
  accuracy: 0.991
2021-12-08 17:15:19.781743: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  1
  loss:     0.125
  accuracy: 0.997
2021-12-08 17:15:20.219033: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  2
  loss:     0.110
  accuracy: 0.997
2021-12-08 17:15:20.598085: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  3
  loss:     0.099
  accuracy: 0.997
Epoch:  4
  loss:     0.085
  accuracy: 1.000
2021-12-08 17:15:20.981787: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Noms des métriques Keras

Les modèles Keras sont cohérents quant à la gestion des noms de métriques. Lorsque vous transmettez une chaîne dans la liste des métriques, cette chaîne exacte est utilisée comme name de la métrique . Ces noms sont visibles dans l'objet historique renvoyé par model.fit et dans les journaux transmis à keras.callbacks . est défini sur la chaîne que vous avez transmise dans la liste des métriques.

model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.0963 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969
2021-12-08 17:15:21.942940: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])

Débogage

Utilisez une exécution rapide pour exécuter votre code étape par étape afin d'inspecter les formes, les types de données et les valeurs. Certaines API, comme tf.function , tf.keras , etc. sont conçues pour utiliser l'exécution de Graph, pour les performances et la portabilité. Lors du débogage, utilisez tf.config.run_functions_eagerly(True) pour utiliser une exécution rapide dans ce code.

Par example:

@tf.function
def f(x):
  if x > 0:
    import pdb
    pdb.set_trace()
    x = x + 1
  return x

tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
  6     @tf.function
  7     def f(x):
  8       if x > 0:
  9         import pdb
 10         pdb.set_trace()
 11  ->     x = x + 1
 12       return x
 13
 14     tf.config.run_functions_eagerly(True)
 15     f(tf.constant(1))
[EOF]

Cela fonctionne également à l'intérieur des modèles Keras et d'autres API qui prennent en charge l'exécution rapide :

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      import pdb
      pdb.set_trace()
      return input_data // 2


tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
 10         if tf.reduce_mean(input_data) > 0:
 11           return input_data
 12         else:
 13           import pdb
 14           pdb.set_trace()
 15  ->       return input_data // 2
 16
 17
 18     tf.config.run_functions_eagerly(True)
 19     model = CustomModel()
 20     model(tf.constant([-2, -4]))

Remarques:

Ne gardez pas tf.Tensors dans vos objets

Ces objets tenseurs peuvent être créés soit dans une tf.function soit dans le contexte impatient, et ces tenseurs se comportent différemment. Utilisez toujours tf.Tensor s uniquement pour les valeurs intermédiaires.

Pour suivre l'état, utilisez tf.Variable s car ils sont toujours utilisables dans les deux contextes. Lisez le guide tf.Variable pour en savoir plus.

Ressources et lectures complémentaires

  • Lisez les guides et tutoriels TF2 pour en savoir plus sur l'utilisation de TF2.

  • Si vous utilisiez auparavant TF1.x, il est fortement recommandé de migrer votre code vers TF2. Lisez les guides de migration pour en savoir plus.