Flusso tensoriale effettivo 2

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza su GitHub Scarica taccuino

Panoramica

Questa guida fornisce un elenco di best practice per scrivere codice utilizzando TensorFlow 2 (TF2). Fare riferimento alla sezione di migrazione della guida per ulteriori informazioni sulla migrazione il codice TF1.x per TF2.

Impostare

Importa TensorFlow e altre dipendenze per gli esempi in questa guida.

import tensorflow as tf
import tensorflow_datasets as tfds

Raccomandazioni per TensorFlow 2 . idiomatico

Refactoring il codice in moduli più piccoli

Una buona pratica consiste nel refactoring del codice in funzioni più piccole che vengono chiamate secondo necessità. Per ottenere prestazioni ottimali, si dovrebbe cercare di decorare le più grandi blocchi di calcolo che si può in un tf.function (nota che le funzioni Python nidificate chiamati da un tf.function non richiedono le proprie decorazioni separati, a meno che non si desidera utilizzare diversi jit_compile impostazioni per il tf.function ). A seconda del tuo caso d'uso, potrebbero trattarsi di più fasi di allenamento o persino dell'intero ciclo di allenamento. Per i casi d'uso di inferenza, potrebbe essere un singolo passaggio in avanti del modello.

Regolare la velocità di apprendimento di default per alcuni tf.keras.optimizer s

Alcuni ottimizzatori Keras hanno diversi tassi di apprendimento in TF2. Se noti un cambiamento nel comportamento di convergenza per i tuoi modelli, controlla i tassi di apprendimento predefiniti.

Non ci sono cambiamenti per optimizers.SGD , optimizers.Adam o optimizers.RMSprop .

I seguenti tassi di apprendimento predefiniti sono cambiati:

Usa tf.Module strati Keras s e per gestire le variabili

tf.Module s e tf.keras.layers.Layer s offrono le convenienti variables e trainable_variables proprietà, che ricorsivamente raccogliere tutte le variabili dipendenti. Ciò semplifica la gestione delle variabili a livello locale rispetto al luogo in cui vengono utilizzate.

Strati Keras / modelli ereditano da tf.train.Checkpointable e sono integrati con @tf.function , che rende possibile il punto di controllo direttamente o SavedModels esportazione da oggetti Keras. Non è necessariamente bisogno di usare Keras' Model.fit API per approfittare di queste integrazioni.

Leggere la sezione di apprendimento trasferimento e messa a punto nella guida Keras per imparare a raccogliere un sottoinsieme di variabili pertinenti utilizzando Keras.

Combinare tf.data.Dataset s e tf.function

Il tensorflow Datasets pacchetto ( tfds ) contiene utilità per caricare set di dati predefiniti come tf.data.Dataset oggetti. Per questo esempio, è possibile caricare il set di dati utilizzando MNIST tfds :

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

Quindi preparare i dati per l'allenamento:

  • Ridimensiona ogni immagine.
  • Rimescola l'ordine degli esempi.
  • Raccogli lotti di immagini ed etichette.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Per mantenere l'esempio breve, taglia il set di dati per restituire solo 5 batch:

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2021-09-22 22:13:17.284138: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Usa la normale iterazione di Python per scorrere i dati di training che si adattano alla memoria. In caso contrario, tf.data.Dataset è il modo migliore per trasmettere i dati di allenamento dal disco. Dataset sono iterabili (non iteratori) , e funzionano come altri iterables pitone nell'esecuzione ansioso. È possibile utilizzare appieno set di dati asincrona prefetching / funzioni di streaming avvolgendo il codice in tf.function , che sostituisce Python iterazione con le operazioni grafico equivalenti utilizzando autografo.

@tf.function
def train(model, dataset, optimizer):
  for x, y in dataset:
    with tf.GradientTape() as tape:
      # training=True is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      prediction = model(x, training=True)
      loss = loss_fn(prediction, y)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Se si utilizza l'Keras Model.fit API, non dovrete preoccuparvi di set di dati iterazione.

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)

Usa i cicli di allenamento Keras

Se non avete bisogno di controllo di basso livello del processo di formazione, utilizzando Keras' built-in fit , evaluate e predict i metodi è raccomandato. Questi metodi forniscono un'interfaccia uniforme per addestrare il modello indipendentemente dall'implementazione (sequenziale, funzionale o sottoclasse).

I vantaggi di questi metodi includono:

  • Essi accettano array NumPy, generatori Python e, tf.data.Datasets .
  • Applicano la regolarizzazione e le perdite di attivazione automaticamente.
  • Sostengono tf.distribute in cui il codice di formazione rimane lo stesso indipendentemente dalla configurazione hardware .
  • Supportano callable arbitrari come perdite e metriche.
  • Sostengono callback come tf.keras.callbacks.TensorBoard , e callback personalizzati.
  • Sono performanti, utilizzando automaticamente i grafici TensorFlow.

Ecco un esempio di formazione di un modello usando un Dataset . Per i dettagli su come funziona, controllare le esercitazioni .

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 9s 9ms/step - loss: 1.5774 - accuracy: 0.5063
Epoch 2/5
2021-09-22 22:13:26.932626: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.4498 - accuracy: 0.9125
Epoch 3/5
2021-09-22 22:13:27.323101: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.2929 - accuracy: 0.9563
Epoch 4/5
2021-09-22 22:13:27.717803: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.2055 - accuracy: 0.9875
Epoch 5/5
2021-09-22 22:13:28.088985: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.1669 - accuracy: 0.9937
2021-09-22 22:13:28.458529: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 3ms/step - loss: 1.6056 - accuracy: 0.6500
Loss 1.6056102514266968, Accuracy 0.6499999761581421
2021-09-22 22:13:28.956635: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Personalizza l'allenamento e scrivi il tuo loop

Se i modelli Keras funzionano per te, ma hai bisogno di maggiore flessibilità e controllo della fase di allenamento o dei cicli di allenamento esterni, puoi implementare i tuoi passaggi di allenamento o anche interi cicli di allenamento. Consultare la guida Keras sulla personalizzazione fit per saperne di più.

È inoltre possibile implementare molte cose come tf.keras.callbacks.Callback .

Questo metodo ha molti dei vantaggi citati in precedenza , ma vi dà il controllo della fase di treno e anche il ciclo esterno.

Ci sono tre passaggi per un ciclo di allenamento standard:

  1. Iterare su un generatore di Python o tf.data.Dataset per ottenere lotti di esempi.
  2. Utilizzare tf.GradientTape ai gradienti Raccogliere.
  3. Utilizzare uno dei tf.keras.optimizers per applicare gli aggiornamenti di peso alle variabili del modello.

Ricordare:

  • Sempre includere una training argomento sulla call metodo di strati e modelli sottoclasse.
  • Assicurarsi di chiamare il modello con la training correttamente impostato argomento.
  • A seconda dell'utilizzo, le variabili del modello potrebbero non esistere finché il modello non viene eseguito su un batch di dati.
  • Devi gestire manualmente cose come le perdite di regolarizzazione per il modello.

Non è necessario eseguire inizializzatori di variabili o aggiungere dipendenze di controllo manuali. tf.function gestisce le dipendenze di controllo automatico e inizializzazione delle variabili sulla creazione per voi.

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
2021-09-22 22:13:29.878252: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 0
2021-09-22 22:13:30.266807: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 1
2021-09-22 22:13:30.626589: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 2
2021-09-22 22:13:31.040058: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 3
Finished epoch 4
2021-09-22 22:13:31.417637: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Approfittate di tf.function con flusso di controllo Python

tf.function fornisce un modo per convertire il flusso di controllo dati dipendenti dalla in equivalenti grafico-mode come tf.cond e tf.while_loop .

Un luogo comune in cui appare il flusso di controllo dipendente dai dati è nei modelli di sequenza. tf.keras.layers.RNN avvolge una cella RNN, che consente di in modo statico o dinamico Srotolare la ricorrenza. Ad esempio, puoi reimplementare lo srotolamento dinamico come segue.

class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super(DynamicRNN, self).__init__(self)
    self.cell = rnn_cell

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
  def call(self, input_data):

    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    timesteps =  tf.shape(input_data)[0]
    batch_size = tf.shape(input_data)[1]
    outputs = tf.TensorArray(tf.float32, timesteps)
    state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
    for i in tf.range(timesteps):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)

my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)

Leggi la tf.function guida per una più informazioni.

Metriche e perdite di nuovo stile

Le metriche e le perdite sono entrambi gli oggetti che il lavoro con entusiasmo e in tf.function s.

Un oggetto perdita è richiamabile, e si aspetta ( y_true , y_pred ) come argomenti:

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815

Utilizza le metriche per raccogliere e visualizzare i dati

È possibile utilizzare tf.metrics ai dati aggregati e tf.summary per accedere sommari e destinarli a uno scrittore attraverso un gestore di contesto. Le sintesi sono emessi direttamente allo scrittore che significa che è necessario fornire il step valore alla callsite.

summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
  tf.summary.scalar('loss', 0.1, step=42)

Utilizzare tf.metrics ai dati aggregati prima di accedere come sommari. Le metriche sono stateful; si accumulano i valori e restituiscono un risultato cumulativo quando si chiama il result metodo (ad esempio Mean.result ). Cancella valori con accumulato Model.reset_states .

def train(model, optimizer, dataset, log_freq=10):
  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss.update_state(loss)
    if tf.equal(optimizer.iterations % log_freq, 0):
      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      avg_loss.reset_states()

def test(model, test_x, test_y, step_num):
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  loss = loss_fn(model(test_x, training=False), test_y)
  tf.summary.scalar('loss', loss, step=step_num)

train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')

with train_summary_writer.as_default():
  train(model, optimizer, dataset)

with test_summary_writer.as_default():
  test(model, test_x, test_y, optimizer.iterations)

Visualizza i riepiloghi generati puntando TensorBoard alla directory del registro di riepilogo:

tensorboard --logdir /tmp/summaries

Utilizzare il tf.summary API per i dati di riepilogo di scrittura per la visualizzazione in TensorBoard. Per ulteriori informazioni, leggere la tf.summary guida .

# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
2021-09-22 22:13:32.370558: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  0
  loss:     0.143
  accuracy: 0.997
2021-09-22 22:13:32.752675: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  1
  loss:     0.119
  accuracy: 0.997
2021-09-22 22:13:33.122889: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  2
  loss:     0.106
  accuracy: 0.997
2021-09-22 22:13:33.522935: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  3
  loss:     0.089
  accuracy: 1.000
Epoch:  4
  loss:     0.079
  accuracy: 1.000
2021-09-22 22:13:33.899024: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Nomi delle metriche Keras

I modelli Keras sono coerenti nella gestione dei nomi delle metriche. Quando si passa una stringa nella lista di metriche, la stringa esatta viene utilizzata come metrica name . Questi nomi sono visibili nell'oggetto storia restituito da model.fit , e nei registri passati al keras.callbacks . è impostato sulla stringa passata nell'elenco delle metriche.

model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.0962 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969
2021-09-22 22:13:34.802566: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])

Debug

Usa l'esecuzione entusiasta per eseguire il codice passo dopo passo per ispezionare forme, tipi di dati e valori. Alcune API, come tf.function , tf.keras , ecc sono progettati per utilizzare l'esecuzione grafico, per le prestazioni e portabilità. Quando il debug, uso tf.config.run_functions_eagerly(True) per usare esecuzione ansiosi all'interno di questo codice.

Per esempio:

@tf.function
def f(x):
  if x > 0:
    import pdb
    pdb.set_trace()
    x = x + 1
  return x

tf.config.run_functions_eagerly(True)
f(tf.constant(1))
f()
-> x = x + 1
(Pdb) l
  6     @tf.function
  7     def f(x):
  8       if x > 0:
  9         import pdb
 10         pdb.set_trace()
 11  ->     x = x + 1
 12       return x
 13
 14     tf.config.run_functions_eagerly(True)
 15     f(tf.constant(1))
[EOF]

Funziona anche all'interno dei modelli Keras e di altre API che supportano l'esecuzione desiderosa:

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      import pdb
      pdb.set_trace()
      return input_data // 2


tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
call()
-> return input_data // 2
(Pdb) l
 10         if tf.reduce_mean(input_data) > 0:
 11           return input_data
 12         else:
 13           import pdb
 14           pdb.set_trace()
 15  ->       return input_data // 2
 16
 17
 18     tf.config.run_functions_eagerly(True)
 19     model = CustomModel()
 20     model(tf.constant([-2, -4]))

Appunti:

Non tenere tf.Tensors nei vostri oggetti

Questi oggetti tensore potrebbe avere creato o in un tf.function o nel contesto ansioso, e questi tensori comportarsi in modo diverso. Utilizzare sempre tf.Tensor s solo per valori intermedi.

Per tenere traccia dello stato, usare tf.Variable s in quanto sono sempre utilizzabile da entrambi i contesti. Leggi la tf.Variable guida per saperne di più.

Risorse e ulteriori letture

  • Leggere le TF2 guide e tutorial per imparare di più su come utilizzare TF2.

  • Se in precedenza hai utilizzato TF1.x, ti consigliamo vivamente di eseguire la migrazione del codice a TF2. Leggere le migrazione guide per saperne di più.