Sintonizar con las primeras mujeres en ML Simposio este martes, 19 de octubre a 09 a.m. PST Registrar ahora

Tensorflow 2 efectivo

Ver en TensorFlow.org Ejecutar en Google Colab Ver en GitHub Descargar cuaderno

Visión general

Esta guía proporciona una lista de las mejores prácticas para escribir código con TensorFlow 2 (TF2). Consulte la sección de migración de la guía para obtener más información sobre la migración de su código TF1.x a TF2.

Configuración

Importa TensorFlow y otras dependencias para los ejemplos de esta guía.

import tensorflow as tf
import tensorflow_datasets as tfds

Recomendaciones para TensorFlow 2 idiomático

Refactorice su código en módulos más pequeños

Una buena práctica es refactorizar su código en funciones más pequeñas que se llaman según sea necesario. Para un mejor rendimiento, usted debe tratar de decorar los bloques más grandes de computación que pueda en un tf.function (nota que las funciones anidadas pitón llamados por un tf.function no requieren que sus propias decoraciones diferentes, a menos que desee utilizar diferentes jit_compile ajustes para el tf.function ). Dependiendo de su caso de uso, esto podría ser varios pasos de entrenamiento o incluso todo su ciclo de entrenamiento. Para casos de uso de inferencia, podría ser un pase directo de modelo único.

Ajustar la velocidad de aprendizaje por defecto para algunos tf.keras.optimizer s

Algunos optimizadores de Keras tienen diferentes tasas de aprendizaje en TF2. Si observa un cambio en el comportamiento de convergencia de sus modelos, verifique las tasas de aprendizaje predeterminadas.

No hay cambios para optimizers.SGD , optimizers.Adam o optimizers.RMSprop .

Las siguientes tasas de aprendizaje predeterminadas han cambiado:

Uso tf.Module s y capas Keras y administrar variables

tf.Module s y tf.keras.layers.Layer s oferta de los convenientes variables y trainable_variables propiedades, que se reúnen de forma recursiva hasta que todas las variables dependientes. Esto facilita la gestión de variables de forma local en el lugar donde se utilizan.

Capas Keras / modelos heredan de tf.train.Checkpointable y se integran con @tf.function , lo que hace posible en checkpoint directa o SavedModels exportación de objetos Keras. Usted no necesariamente tiene que usar Keras' Model.fit API para tomar ventaja de estas integraciones.

Lea la sección sobre el aprendizaje de transferencia y puesta a punto de la guía Keras para aprender a recoger un subconjunto de variables relevantes utilizando Keras.

Combinar tf.data.Dataset s y tf.function

El TensorFlow Conjuntos de datos de paquete ( tfds ) contiene utilidades para los conjuntos de datos predefinidos de carga como tf.data.Dataset objetos. Para este ejemplo, se puede cargar el conjunto de datos MNIST utilizando tfds :

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

Luego, prepare los datos para el entrenamiento:

  • Vuelva a escalar cada imagen.
  • Mezcla el orden de los ejemplos.
  • Recopile lotes de imágenes y etiquetas.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Para que el ejemplo sea breve, recorte el conjunto de datos para que solo devuelva 5 lotes:

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2021-09-22 22:13:17.284138: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Utilice la iteración regular de Python para iterar sobre los datos de entrenamiento que quepan en la memoria. De lo contrario, tf.data.Dataset es la mejor manera de transmitir datos de entrenamiento desde el disco. Los conjuntos de datos son iterables (no iteradores) , y funcionan igual que otros iterables Python en ejecución ansiosos. Se puede utilizar por completo conjunto de datos asíncrona captura previa / características de streaming envolviendo su código en tf.function , que sustituye Python iteración con las operaciones con gráficos equivalentes usando un autógrafo.

@tf.function
def train(model, dataset, optimizer):
  for x, y in dataset:
    with tf.GradientTape() as tape:
      # training=True is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      prediction = model(x, training=True)
      loss = loss_fn(prediction, y)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Si se utiliza el Keras Model.fit API, usted no tendrá que preocuparse por conjunto de datos iteración.

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)

Usa bucles de entrenamiento de Keras

Si no es necesario el control de bajo nivel de su proceso de formación, utilizando Keras integrado en un fit , evaluate y predict métodos se recomienda. Estos métodos proporcionan una interfaz uniforme para entrenar el modelo independientemente de la implementación (secuencial, funcional o subclasificada).

Las ventajas de estos métodos incluyen:

  • Ellos aceptan matrices numpy, generadores y Python, tf.data.Datasets .
  • Aplican pérdidas de regularización y activación de forma automática.
  • Apoyan tf.distribute donde el código de la formación sigue siendo el mismo , independientemente de la configuración de hardware .
  • Admiten reclamaciones arbitrarias como pérdidas y métricas.
  • Apoyan devoluciones de llamada como tf.keras.callbacks.TensorBoard y devoluciones de llamada personalizados.
  • Son eficaces y utilizan automáticamente los gráficos de TensorFlow.

Aquí está un ejemplo de la formación de un modelo utilizando un Dataset . Para más detalles sobre cómo funciona esto, echa un vistazo a los tutoriales .

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 9s 9ms/step - loss: 1.5774 - accuracy: 0.5063
Epoch 2/5
2021-09-22 22:13:26.932626: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.4498 - accuracy: 0.9125
Epoch 3/5
2021-09-22 22:13:27.323101: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.2929 - accuracy: 0.9563
Epoch 4/5
2021-09-22 22:13:27.717803: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.2055 - accuracy: 0.9875
Epoch 5/5
2021-09-22 22:13:28.088985: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.1669 - accuracy: 0.9937
2021-09-22 22:13:28.458529: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 3ms/step - loss: 1.6056 - accuracy: 0.6500
Loss 1.6056102514266968, Accuracy 0.6499999761581421
2021-09-22 22:13:28.956635: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Personaliza el entrenamiento y escribe tu propio bucle

Si los modelos de Keras funcionan para usted, pero necesita más flexibilidad y control del paso de entrenamiento o de los bucles de entrenamiento externos, puede implementar sus propios pasos de entrenamiento o incluso bucles de entrenamiento completos. Consulte la guía Keras en la personalización de fit para aprender más.

También puede implementar muchas cosas como tf.keras.callbacks.Callback .

Este método tiene muchas de las ventajas mencionadas anteriormente , pero le da el control del paso de trenes e incluso el bucle externo.

Hay tres pasos para un ciclo de entrenamiento estándar:

  1. Iterar sobre un generador de Python o tf.data.Dataset para obtener lotes de ejemplos.
  2. Utilice tf.GradientTape a los gradientes de cobro revertido.
  3. Utilice uno de los tf.keras.optimizers para aplicar las actualizaciones de peso a las variables del modelo.

Recordar:

  • Siempre incluya una training argumento en la call método de capas y modelos subclases.
  • Asegúrese de llamar el modelo con la training correcta conjunto argumento.
  • Dependiendo del uso, es posible que las variables del modelo no existan hasta que el modelo se ejecute en un lote de datos.
  • Necesita manejar manualmente cosas como pérdidas de regularización para el modelo.

No es necesario ejecutar inicializadores de variables ni agregar dependencias de control manual. tf.function maneja dependencias de control automático y la inicialización de variables en la creación para usted.

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
2021-09-22 22:13:29.878252: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 0
2021-09-22 22:13:30.266807: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 1
2021-09-22 22:13:30.626589: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 2
2021-09-22 22:13:31.040058: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 3
Finished epoch 4
2021-09-22 22:13:31.417637: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Tome ventaja de tf.function con flujo de control Python

tf.function proporciona una manera de convertir el flujo de control dependiente de los datos en equivalentes en modo gráfico como tf.cond y tf.while_loop .

Un lugar común donde aparece el flujo de control dependiente de los datos es en los modelos de secuencia. tf.keras.layers.RNN envuelve una célula RNN, lo que le permite ya sea estática o dinámica desenrollar la recurrencia. Como ejemplo, puede volver a implementar el desenrollado dinámico de la siguiente manera.

class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super(DynamicRNN, self).__init__(self)
    self.cell = rnn_cell

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
  def call(self, input_data):

    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    timesteps =  tf.shape(input_data)[0]
    batch_size = tf.shape(input_data)[1]
    outputs = tf.TensorArray(tf.float32, timesteps)
    state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
    for i in tf.range(timesteps):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)

my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)

Lea la tf.function guía para una mayor información.

Pérdidas y métricas de nuevo estilo

Métricas y las pérdidas son ambos objetos que el trabajo con entusiasmo y en tf.function s.

Un objeto pérdida es exigible, y espera ( y_true , y_pred ) como argumentos:

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815

Utilice métricas para recopilar y mostrar datos

Puede utilizar tf.metrics a los datos agregados y tf.summary para registrar resúmenes y redirigirlo a un escritor usando un gestor de contexto. Los resúmenes se emiten directamente a la escritora que significa que debe proporcionar el step valor en el callsite.

summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
  tf.summary.scalar('loss', 0.1, step=42)

Utilice tf.metrics a datos agregados antes de iniciar sesión como resúmenes. Las métricas tienen estado; se acumulan los valores y devuelven un resultado acumulativo cuando se llama al result método (como Mean.result ). Claro valores con acumuló Model.reset_states .

def train(model, optimizer, dataset, log_freq=10):
  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss.update_state(loss)
    if tf.equal(optimizer.iterations % log_freq, 0):
      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      avg_loss.reset_states()

def test(model, test_x, test_y, step_num):
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  loss = loss_fn(model(test_x, training=False), test_y)
  tf.summary.scalar('loss', loss, step=step_num)

train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')

with train_summary_writer.as_default():
  train(model, optimizer, dataset)

with test_summary_writer.as_default():
  test(model, test_x, test_y, optimizer.iterations)

Visualice los resúmenes generados apuntando a TensorBoard al directorio de registro de resumen:

tensorboard --logdir /tmp/summaries

Usar la tf.summary API para escribir los datos de resumen para la visualización en TensorBoard. Para obtener más información, lea la tf.summary guía .

# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
2021-09-22 22:13:32.370558: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  0
  loss:     0.143
  accuracy: 0.997
2021-09-22 22:13:32.752675: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  1
  loss:     0.119
  accuracy: 0.997
2021-09-22 22:13:33.122889: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  2
  loss:     0.106
  accuracy: 0.997
2021-09-22 22:13:33.522935: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  3
  loss:     0.089
  accuracy: 1.000
Epoch:  4
  loss:     0.079
  accuracy: 1.000
2021-09-22 22:13:33.899024: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Nombres de métricas de Keras

Los modelos de Keras son consistentes en el manejo de nombres de métricas. Cuando se pasa una cadena en la lista de métricas, exactamente esa cadena se utiliza como métrica del name . Estos nombres son visibles en la historia del objeto devuelto por model.fit , y en los registros pasan a keras.callbacks . se establece en la cadena que pasó en la lista de métricas.

model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.0962 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969
2021-09-22 22:13:34.802566: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])

Depuración

Utilice una ejecución ávida para ejecutar su código paso a paso para inspeccionar formas, tipos de datos y valores. Determinadas API, como tf.function , tf.keras , etc. están diseñados para utilizar Ejecución de gráfico, para un rendimiento y portabilidad. Al depurar, el uso tf.config.run_functions_eagerly(True) para utilizar la ejecución ansiosos dentro de este código.

Por ejemplo:

@tf.function
def f(x):
  if x > 0:
    import pdb
    pdb.set_trace()
    x = x + 1
  return x

tf.config.run_functions_eagerly(True)
f(tf.constant(1))
f()
-> x = x + 1
(Pdb) l
  6     @tf.function
  7     def f(x):
  8       if x > 0:
  9         import pdb
 10         pdb.set_trace()
 11  ->     x = x + 1
 12       return x
 13
 14     tf.config.run_functions_eagerly(True)
 15     f(tf.constant(1))
[EOF]

Esto también funciona dentro de los modelos de Keras y otras API que admiten una ejecución ávida:

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      import pdb
      pdb.set_trace()
      return input_data // 2


tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
call()
-> return input_data // 2
(Pdb) l
 10         if tf.reduce_mean(input_data) > 0:
 11           return input_data
 12         else:
 13           import pdb
 14           pdb.set_trace()
 15  ->       return input_data // 2
 16
 17
 18     tf.config.run_functions_eagerly(True)
 19     model = CustomModel()
 20     model(tf.constant([-2, -4]))

Notas:

No mantenga tf.Tensors en sus objetos

Estos objetos tensor puede ser que consiga creado ya sea en un tf.function o en el contexto ansiosos, y estos tensores se comportan de manera diferente. Siempre use tf.Tensor s solamente para los valores intermedios.

Para realizar el seguimiento del estado, utilice tf.Variable s, ya que siempre se pueden utilizar desde ambos contextos. Lea la tf.Variable guía para aprender más.

Recursos y lectura adicional

  • Lea las TF2 guías y tutoriales para aprender más acerca de cómo usar TF2.

  • Si utilizó TF1.x anteriormente, se recomienda encarecidamente que migre su código a TF2. Lea las migraciones guías para aprender más.