Efektywny przepływ tensorowy 2

Zobacz na TensorFlow.org Uruchom w Google Colab Zobacz na GitHub Pobierz notatnik

Przegląd

Ten przewodnik zawiera listę najlepszych praktyk dotyczących pisania kodu przy użyciu TensorFlow 2 (TF2), jest napisany dla użytkowników, którzy niedawno przeszli z TensorFlow 1 (TF1). Zapoznaj się z sekcją dotyczącą migracji w przewodniku, aby uzyskać więcej informacji na temat migracji kodu TF1 do TF2.

Ustawiać

Zaimportuj TensorFlow i inne zależności dla przykładów w tym przewodniku.

import tensorflow as tf
import tensorflow_datasets as tfds

Zalecenia dotyczące idiomatycznego TensorFlow 2

Refaktoryzuj swój kod na mniejsze moduły

Dobrą praktyką jest refaktoryzacja kodu na mniejsze funkcje, które są wywoływane w razie potrzeby. Aby uzyskać najlepszą wydajność, powinieneś spróbować udekorować największe bloki obliczeń, które możesz w tf.function (zauważ, że zagnieżdżone funkcje Pythona wywoływane przez tf.function nie wymagają oddzielnych dekoracji, chyba że chcesz użyć innego jit_compile ustawienia funkcji tf.function . W zależności od przypadku użycia może to być wiele kroków treningowych lub nawet cała pętla treningowa. W przypadku użycia wnioskowania może to być pojedynczy przebieg modelu.

Dostosuj domyślną szybkość uczenia się dla niektórych tf.keras.optimizer s

Niektóre optymalizatory Keras mają różne szybkości uczenia się w TF2. Jeśli zauważysz zmianę w zachowaniu zbieżności modeli, sprawdź domyślne współczynniki uczenia się.

Nie wprowadzono żadnych zmian w optimizers.RMSprop , optimizers.SGD optimizers.Adam .

Zmieniły się następujące domyślne współczynniki uczenia się:

Użyj tf.Module s i Keras do zarządzania zmiennymi

tf.Module s i tf.keras.layers.Layer s oferują wygodne variables i właściwości trainable_variables , które rekurencyjnie gromadzą wszystkie zmienne zależne. Ułatwia to lokalne zarządzanie zmiennymi tam, gdzie są używane.

Warstwy/modele Keras dziedziczą po tf.train.Checkpointable i są zintegrowane z @tf.function , co umożliwia bezpośrednie sprawdzanie punktu kontrolnego lub eksportowanie SavedModels z obiektów Keras. Nie musisz koniecznie korzystać z API Keras Model.fit , aby skorzystać z tych integracji.

Przeczytaj sekcję o transferze uczenia się i dostrajaniu w przewodniku Keras, aby dowiedzieć się, jak zebrać podzbiór odpowiednich zmiennych za pomocą Keras.

Połącz tf.data.Dataset s i tf.function

Pakiet TensorFlow Datasets ( tfds ) zawiera narzędzia do ładowania predefiniowanych zestawów danych jako obiektów tf.data.Dataset . W tym przykładzie możesz załadować zestaw danych MNIST za pomocą tfds :

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

Następnie przygotuj dane do treningu:

  • Skaluj ponownie każdy obraz.
  • Potasuj kolejność przykładów.
  • Zbieraj partie obrazów i etykiet.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Aby przykład był krótki, przytnij zbiór danych, aby zwracał tylko 5 partii:

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2021-12-08 17:15:01.637157: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Użyj zwykłej iteracji Pythona do iteracji danych uczących, które mieszczą się w pamięci. W przeciwnym razie tf.data.Dataset to najlepszy sposób na przesyłanie strumieniowe danych szkoleniowych z dysku. Zbiory danych to iterable (nie iteratory) i działają tak samo jak inne iterable Pythona w gorliwym wykonywaniu. Możesz w pełni wykorzystać asynchroniczne funkcje wstępnego pobierania/przesyłania strumieniowego zestawu danych, opakowując swój kod w tf.function , który zastępuje iterację Pythona równoważnymi operacjami wykresu przy użyciu AutoGraph.

@tf.function
def train(model, dataset, optimizer):
  for x, y in dataset:
    with tf.GradientTape() as tape:
      # training=True is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      prediction = model(x, training=True)
      loss = loss_fn(prediction, y)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Jeśli korzystasz z interfejsu API Keras Model.fit , nie musisz się martwić o iterację zestawu danych.

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)

Użyj pętli treningowych Keras

Jeśli nie potrzebujesz niskopoziomowej kontroli nad procesem treningowym, zalecane jest użycie wbudowanych metod Keras fit , evaluate i predict . Te metody zapewniają jednolity interfejs do uczenia modelu niezależnie od implementacji (sekwencyjnej, funkcjonalnej lub podklasy).

Zaletami tych metod są:

  • Akceptują tablice Numpy, generatory Pythona i tf.data.Datasets .
  • Stosują regularyzację i straty aktywacyjne automatycznie.
  • Obsługują tf.distribute , gdzie kod szkolenia pozostaje taki sam niezależnie od konfiguracji sprzętu .
  • Obsługują arbitralne nabytki jako straty i metryki.
  • Obsługują one wywołania zwrotne, takie jak tf.keras.callbacks.TensorBoard i niestandardowe wywołania zwrotne.
  • Są wydajne, automatycznie przy użyciu wykresów TensorFlow.

Oto przykład uczenia modelu przy użyciu Dataset . Aby dowiedzieć się, jak to działa, zapoznaj się z samouczkami .

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 9s 7ms/step - loss: 1.5762 - accuracy: 0.4938
Epoch 2/5
2021-12-08 17:15:11.145429: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 6ms/step - loss: 0.5087 - accuracy: 0.8969
Epoch 3/5
2021-12-08 17:15:11.559374: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 2s 5ms/step - loss: 0.3348 - accuracy: 0.9469
Epoch 4/5
2021-12-08 17:15:13.860407: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.2445 - accuracy: 0.9688
Epoch 5/5
2021-12-08 17:15:14.269850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 6ms/step - loss: 0.2006 - accuracy: 0.9719
2021-12-08 17:15:14.717552: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 1s 4ms/step - loss: 1.4553 - accuracy: 0.5781
Loss 1.4552843570709229, Accuracy 0.578125
2021-12-08 17:15:15.862684: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Dostosuj trening i napisz własną pętlę

Jeśli modele Keras działają dla Ciebie, ale potrzebujesz większej elastyczności i kontroli nad krokiem treningowym lub zewnętrznymi pętlami treningowymi, możesz wdrożyć własne kroki treningowe lub nawet całe pętle treningowe. Zobacz przewodnik Keras dotyczący dostosowywania fit , aby dowiedzieć się więcej.

Możesz także zaimplementować wiele rzeczy jako tf.keras.callbacks.Callback .

Ta metoda ma wiele zalet wspomnianych wcześniej , ale daje kontrolę nad krokiem pociągu, a nawet zewnętrzną pętlą.

Standardowa pętla treningowa składa się z trzech kroków:

  1. Wykonaj iterację przez generator Pythona lub tf.data.Dataset , aby uzyskać partie przykładów.
  2. Użyj tf.GradientTape do zbierania gradientów.
  3. Użyj jednego z tf.keras.optimizers , aby zastosować aktualizacje wagi do zmiennych modelu.

Pamiętać:

  • Zawsze dołączaj argument training w metodzie call podklas warstw i modeli.
  • Upewnij się, że wywołałeś model z poprawnie ustawionym argumentem training .
  • W zależności od użycia zmienne modelu mogą nie istnieć, dopóki model nie zostanie uruchomiony na partii danych.
  • Musisz ręcznie poradzić sobie z takimi rzeczami, jak straty związane z regularyzacją modelu.

Nie ma potrzeby uruchamiania inicjatorów zmiennych ani dodawania ręcznych zależności sterowania. tf.function obsługuje automatyczne zależności sterowania i inicjalizację zmiennych podczas tworzenia.

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
2021-12-08 17:15:16.714849: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 0
2021-12-08 17:15:17.097043: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 1
2021-12-08 17:15:17.502480: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 2
2021-12-08 17:15:17.873701: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 3
Finished epoch 4
2021-12-08 17:15:18.344196: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Skorzystaj z tf.function z przepływem sterowania w Pythonie

tf.function umożliwia konwersję zależnego od danych przepływu sterowania na jego odpowiedniki w trybie wykresu, takie jak tf.cond i tf.while_loop .

Jednym z powszechnych miejsc, w których pojawia się przepływ sterowania zależny od danych, są modele sekwencyjne. tf.keras.layers.RNN otacza komórkę RNN, umożliwiając statyczne lub dynamiczne rozwijanie cyklu. Na przykład możesz ponownie wdrożyć dynamiczne rozwijanie w następujący sposób.

class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super(DynamicRNN, self).__init__(self)
    self.cell = rnn_cell

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
  def call(self, input_data):

    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    timesteps =  tf.shape(input_data)[0]
    batch_size = tf.shape(input_data)[1]
    outputs = tf.TensorArray(tf.float32, timesteps)
    state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
    for i in tf.range(timesteps):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)

my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)

Przeczytaj przewodnik tf.function , aby uzyskać więcej informacji.

Wskaźniki i straty w nowym stylu

Metryki i straty to zarówno obiekty, które chętnie pracują, jak i w tf.function s.

Obiekt straty można wywołać i oczekuje ( y_true , y_pred ) jako argumentów:

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815

Używaj metryk do zbierania i wyświetlania danych

Możesz użyć tf.metrics do agregowania danych i tf.summary do dzienników podsumowań i przekierować je do autora za pomocą menedżera kontekstu. Podsumowania są emitowane bezpośrednio do piszącego, co oznacza, że ​​musisz podać wartość step w miejscu wywołania.

summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
  tf.summary.scalar('loss', 0.1, step=42)

Użyj tf.metrics do agregowania danych przed zarejestrowaniem ich jako podsumowań. Metryki są stanowe; gromadzą wartości i zwracają skumulowany wynik po wywołaniu metody result (takiej jak Mean.result ). Wyczyść skumulowane wartości za pomocą Model.reset_states .

def train(model, optimizer, dataset, log_freq=10):
  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss.update_state(loss)
    if tf.equal(optimizer.iterations % log_freq, 0):
      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      avg_loss.reset_states()

def test(model, test_x, test_y, step_num):
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  loss = loss_fn(model(test_x, training=False), test_y)
  tf.summary.scalar('loss', loss, step=step_num)

train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')

with train_summary_writer.as_default():
  train(model, optimizer, dataset)

with test_summary_writer.as_default():
  test(model, test_x, test_y, optimizer.iterations)

Wizualizuj wygenerowane podsumowania, wskazując TensorBoard na katalog dziennika podsumowań:

tensorboard --logdir /tmp/summaries

Użyj interfejsu API tf.summary , aby zapisać dane podsumowujące do wizualizacji w TensorBoard. Aby uzyskać więcej informacji, przeczytaj przewodnik tf.summary .

# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
2021-12-08 17:15:19.339736: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  0
  loss:     0.142
  accuracy: 0.991
2021-12-08 17:15:19.781743: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  1
  loss:     0.125
  accuracy: 0.997
2021-12-08 17:15:20.219033: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  2
  loss:     0.110
  accuracy: 0.997
2021-12-08 17:15:20.598085: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  3
  loss:     0.099
  accuracy: 0.997
Epoch:  4
  loss:     0.085
  accuracy: 1.000
2021-12-08 17:15:20.981787: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Keras nazwy metryczne

Modele Keras są spójne pod względem obsługi nazw metryk. Gdy przekazujesz ciąg na liście metryk, ten właśnie ciąg jest używany jako name metryki . Nazwy te są widoczne w obiekcie historii zwróconym przez model.fit oraz w logach przekazanych do keras.callbacks . jest ustawiony na ciąg znaków przekazany na liście metryk.

model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.0963 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969
2021-12-08 17:15:21.942940: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])

Debugowanie

Użyj szybkiego wykonania, aby uruchomić kod krok po kroku, aby sprawdzić kształty, typy danych i wartości. Niektóre interfejsy API, takie jak tf.function , tf.keras itp., są zaprojektowane do korzystania z wykonywania wykresów w celu zapewnienia wydajności i przenośności. Podczas debugowania użyj tf.config.run_functions_eagerly(True) , aby użyć szybkiego wykonywania wewnątrz tego kodu.

Na przykład:

@tf.function
def f(x):
  if x > 0:
    import pdb
    pdb.set_trace()
    x = x + 1
  return x

tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
  6     @tf.function
  7     def f(x):
  8       if x > 0:
  9         import pdb
 10         pdb.set_trace()
 11  ->     x = x + 1
 12       return x
 13
 14     tf.config.run_functions_eagerly(True)
 15     f(tf.constant(1))
[EOF]

Działa to również w modelach Keras i innych interfejsach API, które obsługują szybkie wykonanie:

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      import pdb
      pdb.set_trace()
      return input_data // 2


tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
 10         if tf.reduce_mean(input_data) > 0:
 11           return input_data
 12         else:
 13           import pdb
 14           pdb.set_trace()
 15  ->       return input_data // 2
 16
 17
 18     tf.config.run_functions_eagerly(True)
 19     model = CustomModel()
 20     model(tf.constant([-2, -4]))

Uwagi:

Nie trzymaj tf.Tensors w swoich obiektach

Te tensory mogą zostać utworzone w funkcji tf.function lub w gorliwym kontekście, a te tensory zachowują się inaczej. Zawsze używaj tf.Tensor s tylko dla wartości pośrednich.

Aby śledzić stan, użyj tf.Variable s, ponieważ zawsze można ich używać z obu kontekstów. Przeczytaj przewodnik tf.Variable , aby dowiedzieć się więcej.

Zasoby i dalsze czytanie

  • Przeczytaj przewodniki i samouczki TF2, aby dowiedzieć się więcej o korzystaniu z TF2.

  • Jeśli wcześniej używałeś TF1.x, zdecydowanie zaleca się migrację kodu do TF2. Przeczytaj przewodniki po migracji, aby dowiedzieć się więcej.