Tensorflow 2 efficace

Voir sur TensorFlow.org Exécuter dans Google Colab Voir sur GitHub Télécharger le cahier

Aperçu

Ce guide fournit une liste des meilleures pratiques pour écrire du code à l'aide de TensorFlow 2 (TF2). Reportez - vous à la section migrate du guide pour plus d' informations sur la migration de votre code TF1.x à TF2.

Installer

Importez TensorFlow et d'autres dépendances pour les exemples de ce guide.

import tensorflow as tf
import tensorflow_datasets as tfds

Recommandations pour TensorFlow 2 idiomatique

Refactorisez votre code en modules plus petits

Une bonne pratique consiste à refactoriser votre code en fonctions plus petites qui sont appelées selon les besoins. Pour de meilleures performances, vous devriez essayer de décorer les plus grands blocs de calcul que vous pouvez dans un tf.function (notez que les fonctions de python imbriquées appelées par un tf.function ne nécessitent pas leurs propres décorations séparées, à moins que vous souhaitez utiliser différents jit_compile les paramètres du tf.function ). Selon votre cas d'utilisation, il peut s'agir de plusieurs étapes d'entraînement ou même de l'ensemble de votre boucle d'entraînement. Pour les cas d'utilisation d'inférence, il peut s'agir d'un modèle unique de passe avant.

Régler le taux d'apprentissage par défaut pour certains tf.keras.optimizer s

Certains optimiseurs Keras ont des taux d'apprentissage différents dans TF2. Si vous constatez un changement dans le comportement de convergence de vos modèles, vérifiez les taux d'apprentissage par défaut.

Il n'y a aucun changement pour optimizers.SGD , optimizers.Adam ou optimizers.RMSprop .

Les taux d'apprentissage par défaut suivants ont changé :

Utilisez tf.Module couches s et KERAS pour gérer les variables

tf.Module s et tf.keras.layers.Layer offre des pratiques de variables et trainable_variables propriétés, qui regroupent toutes récursive variables dépendantes. Cela facilite la gestion des variables localement à l'endroit où elles sont utilisées.

Couches KERAS / modèles héritent de tf.train.Checkpointable et sont intégrés avec @tf.function , ce qui permet de poste de contrôle directement ou SavedModels à l'exportation d'objets Keras. Vous ne devez pas nécessairement utiliser Keras de Model.fit API pour tirer profit de ces intégrations.

Lisez la section sur l' apprentissage de transfert et réglage fin dans le guide Keras pour apprendre à recueillir un sous - ensemble de variables pertinentes à l' aide Keras.

Combiner tf.data.Dataset s et tf.function

Le Datasets tensorflow paquet ( tfds ) contient des outils pour les jeux de données prédéfinis de chargement comme tf.data.Dataset objets. Pour cet exemple, vous pouvez charger le jeu de données à l' aide MNIST tfds :

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

Ensuite, préparez les données pour l'entraînement :

  • Redimensionnez chaque image.
  • Mélangez l'ordre des exemples.
  • Collectez des lots d'images et d'étiquettes.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Pour garder l'exemple court, coupez l'ensemble de données pour ne renvoyer que 5 lots :

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2021-09-22 22:13:17.284138: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Utilisez une itération Python régulière pour parcourir les données d'entraînement qui tiennent en mémoire. Dans le cas contraire, tf.data.Dataset est la meilleure façon de diffuser des données de formation à partir du disque. Jeux de données sont iterables (pas itérateurs) , et le travail , tout comme d' autres iterables Python dans l' exécution avide. Vous pouvez utiliser pleinement ensemble de données async préchargement / lecture en transit caractéristiques en enveloppant votre code dans tf.function , qui remplace l' itération Python avec les opérations équivalent graphique à l' aide AutoGraph.

@tf.function
def train(model, dataset, optimizer):
  for x, y in dataset:
    with tf.GradientTape() as tape:
      # training=True is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      prediction = model(x, training=True)
      loss = loss_fn(prediction, y)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Si vous utilisez le Keras Model.fit API, vous ne serez pas à vous soucier de l' itération ensemble de données.

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)

Utiliser les boucles d'entraînement Keras

Si vous n'avez pas besoin de contrôle de bas niveau de votre processus de formation, en utilisant Keras' intégré en fit , d' evaluate et de predict des méthodes est recommandée. Ces méthodes fournissent une interface uniforme pour entraîner le modèle quelle que soit l'implémentation (séquentielle, fonctionnelle ou sous-classée).

Les avantages de ces méthodes incluent :

  • Ils acceptent des réseaux NumPy, générateurs Python et, tf.data.Datasets .
  • Ils appliquent la régularisation et les pertes d'activation automatiquement.
  • Ils soutiennent tf.distribute où le code de formation reste la même quelle que soit la configuration matérielle .
  • Ils prennent en charge les callables arbitraires en tant que pertes et métriques.
  • Ils soutiennent callbacks comme tf.keras.callbacks.TensorBoard et callbacks personnalisés.
  • Ils sont performants et utilisent automatiquement des graphiques TensorFlow.

Voici un exemple de la formation d' un modèle à l' aide d' un Dataset . Pour plus de détails sur la façon dont cela fonctionne, consultez les tutoriels .

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 9s 9ms/step - loss: 1.5774 - accuracy: 0.5063
Epoch 2/5
2021-09-22 22:13:26.932626: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.4498 - accuracy: 0.9125
Epoch 3/5
2021-09-22 22:13:27.323101: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.2929 - accuracy: 0.9563
Epoch 4/5
2021-09-22 22:13:27.717803: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.2055 - accuracy: 0.9875
Epoch 5/5
2021-09-22 22:13:28.088985: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.1669 - accuracy: 0.9937
2021-09-22 22:13:28.458529: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 3ms/step - loss: 1.6056 - accuracy: 0.6500
Loss 1.6056102514266968, Accuracy 0.6499999761581421
2021-09-22 22:13:28.956635: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Personnalisez la formation et écrivez votre propre boucle

Si les modèles Keras fonctionnent pour vous, mais que vous avez besoin de plus de flexibilité et de contrôle de l'étape d'entraînement ou des boucles d'entraînement externes, vous pouvez mettre en œuvre vos propres étapes d'entraînement ou même des boucles d'entraînement entières. Consultez le guide Keras sur la personnalisation en fit pour en savoir plus.

Vous pouvez également mettre en œuvre beaucoup de choses en tant que tf.keras.callbacks.Callback .

Cette méthode a de nombreux avantages mentionnés précédemment , mais vous permet de contrôler l'étape de train et même la boucle extérieure.

Une boucle d'entraînement standard comporte trois étapes :

  1. Itérer sur un générateur Python ou tf.data.Dataset pour obtenir des lots d'exemples.
  2. Utilisez tf.GradientTape à des gradients Collect.
  3. Utilisez l' une des tf.keras.optimizers pour appliquer les mises à jour de poids aux variables du modèle.

Rappelles toi:

  • Toujours inclure une training argumentation sur l' call méthode des couches et des modèles sous - classées.
  • Assurez - vous d'appeler le modèle avec la training argumentation correctement.
  • Selon l'utilisation, les variables de modèle peuvent ne pas exister tant que le modèle n'est pas exécuté sur un lot de données.
  • Vous devez gérer manuellement des éléments tels que les pertes de régularisation pour le modèle.

Il n'est pas nécessaire d'exécuter des initialiseurs de variables ou d'ajouter des dépendances de contrôle manuel. tf.function gère les dépendances de contrôle automatique et l' initialisation des variables sur la création pour vous.

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
2021-09-22 22:13:29.878252: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 0
2021-09-22 22:13:30.266807: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 1
2021-09-22 22:13:30.626589: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 2
2021-09-22 22:13:31.040058: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 3
Finished epoch 4
2021-09-22 22:13:31.417637: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Profitez de tf.function avec le flux de commande Python

tf.function fournit un moyen pour convertir le flux de contrôle de dépendance de données en équivalents en mode graphique comme tf.cond et tf.while_loop .

Un endroit commun où le flux de contrôle dépendant des données apparaît est dans les modèles de séquence. tf.keras.layers.RNN enveloppe une cellule RNN, vous permettant de manière statique ou dynamique Déroulez la récurrence. À titre d'exemple, vous pouvez réimplémenter le déroulement dynamique comme suit.

class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super(DynamicRNN, self).__init__(self)
    self.cell = rnn_cell

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
  def call(self, input_data):

    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    timesteps =  tf.shape(input_data)[0]
    batch_size = tf.shape(input_data)[1]
    outputs = tf.TensorArray(tf.float32, timesteps)
    state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
    for i in tf.range(timesteps):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)

my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)

Lire le tf.function Guide pour plus d' informations.

Des mesures et des pertes de style nouveau

Les mesures et les pertes sont que le travail avec impatience les deux objets et tf.function s.

Un objet de la perte est appelable, et attend ( y_true , y_pred ) comme arguments:

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815

Utiliser des métriques pour collecter et afficher des données

Vous pouvez utiliser tf.metrics aux données agrégées et tf.summary pour connecter des résumés et le rediriger vers un écrivain en utilisant un gestionnaire de contexte. Les résumés sont émis directement à l'écrivain qui signifie que vous devez fournir à l' step valeur à la callsite.

summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
  tf.summary.scalar('loss', 0.1, step=42)

Utilisez tf.metrics pour agréger les données avant de les enregistrer sous forme de résumés. Les métriques sont avec état ; ils accumulent les valeurs et renvoient un résultat cumulatif lorsque vous appelez le result méthode (comme Mean.result ). Effacer les valeurs accumulées avec Model.reset_states .

def train(model, optimizer, dataset, log_freq=10):
  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss.update_state(loss)
    if tf.equal(optimizer.iterations % log_freq, 0):
      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      avg_loss.reset_states()

def test(model, test_x, test_y, step_num):
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  loss = loss_fn(model(test_x, training=False), test_y)
  tf.summary.scalar('loss', loss, step=step_num)

train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')

with train_summary_writer.as_default():
  train(model, optimizer, dataset)

with test_summary_writer.as_default():
  test(model, test_x, test_y, optimizer.iterations)

Visualisez les résumés générés en pointant TensorBoard vers le répertoire du journal des résumés :

tensorboard --logdir /tmp/summaries

Utilisez le tf.summary API pour les données de synthèse d'écriture pour la visualisation en TensorBoard. Pour plus d' informations, lire le tf.summary Guide .

# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
2021-09-22 22:13:32.370558: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  0
  loss:     0.143
  accuracy: 0.997
2021-09-22 22:13:32.752675: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  1
  loss:     0.119
  accuracy: 0.997
2021-09-22 22:13:33.122889: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  2
  loss:     0.106
  accuracy: 0.997
2021-09-22 22:13:33.522935: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  3
  loss:     0.089
  accuracy: 1.000
Epoch:  4
  loss:     0.079
  accuracy: 1.000
2021-09-22 22:13:33.899024: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Noms des métriques Keras

Les modèles Keras sont cohérents en ce qui concerne la gestion des noms de métriques. Lorsque vous passez une chaîne dans la liste des paramètres, cette chaîne exacte est utilisée comme indicateur du name . Ces noms sont visibles dans l'objet retourné par l' histoire model.fit , et dans les journaux passés à keras.callbacks . est défini sur la chaîne que vous avez transmise dans la liste de métriques.

model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.0962 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969
2021-09-22 22:13:34.802566: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])

Débogage

Utilisez une exécution rapide pour exécuter votre code étape par étape afin d'inspecter les formes, les types de données et les valeurs. Certaines API, comme tf.function , tf.keras , etc. sont conçus pour utiliser l' exécution graphique, pour des performances et de portabilité. Lors du débogage, utilisez tf.config.run_functions_eagerly(True) pour utiliser l' exécution de ce code à l' intérieur désireux.

Par exemple:

@tf.function
def f(x):
  if x > 0:
    import pdb
    pdb.set_trace()
    x = x + 1
  return x

tf.config.run_functions_eagerly(True)
f(tf.constant(1))
f()
-> x = x + 1
(Pdb) l
  6     @tf.function
  7     def f(x):
  8       if x > 0:
  9         import pdb
 10         pdb.set_trace()
 11  ->     x = x + 1
 12       return x
 13
 14     tf.config.run_functions_eagerly(True)
 15     f(tf.constant(1))
[EOF]

Cela fonctionne également dans les modèles Keras et d'autres API qui prennent en charge l'exécution rapide :

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      import pdb
      pdb.set_trace()
      return input_data // 2


tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
call()
-> return input_data // 2
(Pdb) l
 10         if tf.reduce_mean(input_data) > 0:
 11           return input_data
 12         else:
 13           import pdb
 14           pdb.set_trace()
 15  ->       return input_data // 2
 16
 17
 18     tf.config.run_functions_eagerly(True)
 19     model = CustomModel()
 20     model(tf.constant([-2, -4]))

Remarques:

Ne gardez pas tf.Tensors dans vos objets

Ces objets pourraient se tenseur soit créé dans un tf.function ou dans le contexte avide, et ces tenseurs se comporter différemment. Toujours utiliser tf.Tensor seulement pour des valeurs intermédiaires s.

Pour suivre l' état, utilisez tf.Variable s comme ils sont toujours utilisables des deux contextes. Lire le tf.Variable Guide pour en savoir plus.

Ressources et lectures complémentaires

  • Lire les TF2 guides et tutoriels pour en savoir plus sur l'utilisation de TF2.

  • Si vous avez déjà utilisé TF1.x, il est fortement recommandé de migrer votre code vers TF2. Lire les migrations guides pour en savoir plus.