効果的な TensorFlow 2

TensorFlow.org で表示 Google Colab で実行 GitHub で表示 ノートブックをダウンロード

概要

このガイドでは、TensorFlow 2(TF2)を使ってコードを記述する際のベストプラクティスを紹介しています。最近 TensorFlow 1(TF1)から切り替えたユーザーを対象としています。TF1 コードから TF2 への移行についての詳細は、このガイドの移行セクションをご覧ください。

セットアップ

このガイドの例に使用する TensorFlow とその他の依存関係をインポートします。

import tensorflow as tf
import tensorflow_datasets as tfds
2022-12-14 21:55:25.563282: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 21:55:25.563388: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 21:55:25.563398: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

慣用的な TensorFlow 2 の推奨事項

コードを小さなモジュールにリファクタリングする

コードを、必要に応じて呼び出せるより小さな関数にリファクタリングすることをお勧めします。最高のパフォーマンスを得るには、tf.function で行える最も大きな計算ブロックをデコレートするとよいでしょう(tf.function が呼び出すネストされた Python 関数には、tf.function に異なる jit_compile 設定を使用しない限り、別途独自のデコレーションは不要であることに注意してください)。ユースケースに応じて、複数のトレーニングステップであったり、トレーニングループ全体である場合があります。推論のユースケースについては、単一モデルのフォワードパスである場合があります。

一部の tf.keras.optimizer のデフォルトの学習速度を調整する

TF2 では、一部の Keras オプティマイザの学習速度が異なります。モデルの収束の動作に変化がある場合は、デフォルトの学習速度を確認してください。

optimizers.SGDoptimizers.Adam、または optimizers.RMSprop に変更はありません。

デフォルトの学習率は次のように変更されました。

tf.Module と Keras レイヤーを使用して変数を管理する

tf.Moduletf.keras.layers.Layer には、すべての従属変数を帰属的に収集する便利な variablestrainable_variables プロパティがあります。このため、変数が使用されている場所での変数の管理を簡単に行うことができます。

Keras レイヤー/モデルは tf.train.Checkpointable から継承し、@tf.function と統合されています。このため、Keras オブジェクトに直接チェックポイントを設定したり、SavedModels をエクスポートしたりすることができます。この統合を利用するために、Keras の Model.fit API を必ずしも使用する必要はありません。

Keras を使用して関連する変数のサブセットを収集する方法については、Keras ガイドの転移学習とファインチューニングに関するセクションをご覧ください。

tf.data.Datasettf.function を組み合わせる

TensorFlow Datasets パッケージ(tfds)には、事前定義済みのデータセットを tf.data.Dataset オブジェクトとして読み込むためのユーティリティが含まれます。この例では、tfds を使用して MNIST データセットを読み込めます。

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

次に、トレーニングのためのデータを準備します。

  • 各画像をリスケールする。
  • 例の順序をシャッフルする。
  • 画像とラベルのバッチを集める。
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

例を短く保つために、データセットを 5 バッチだけ返すようにトリミングします。

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2022-12-14 21:55:31.576520: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

メモリに収まるトレーニングデータは、通常の Python イテレーションでイテレートしますが、そうでない場合は tf.data.Dataset を使ってディスクからトレーニングをストリーミングするのが最適です。データセットはイテラブル(イテレータではない)であり、Eager モードの Python インテラブルとまったく同様に機能します。コードを tf.function でラップすることで、データセットの非同期プリフェッチ/ストリーム機能をそのまま利用することができます。この方法は、Python イテレーションを、同等の、AutoGraph を使用したグラフ演算に置き換えます。

@tf.function
def train(model, dataset, optimizer):
  for x, y in dataset:
    with tf.GradientTape() as tape:
      # training=True is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      prediction = model(x, training=True)
      loss = loss_fn(prediction, y)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Keras の Model.fit API を使用する場合、データセットのイテレーションを気にする必要はありません。

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)

Keras トレーニングループを使用する

トレーニングプロセスの低レベル制御が不要な場合は、Keras 組み込みの fitevaluate、および predict メソッドの使用が推奨されます。これらのメソッドは(シーケンシャル、関数型、またはサブクラス化)実装を問わず、モデルをトレーニングするための統一インターフェースを提供します。

これらのメソッドには次のような優位点があります。

  • Numpy 配列、Python ジェネレータ、tf.data.Datasets を受け取ります。
  • これらは正則化と活性化損失を自動的に適用します。
  • ハードウェア構成に関係なくトレーニングコードが変化しない tf.distribute をサポートします。
  • 任意の callable は損失とメトリクスとしてサポートします。
  • tf.data.Datasets のようなコールバックとカスタムコールバックをサポートします。
  • 自動的に TensorFlow グラフを使用し、高性能です。

ここに Dataset を使用したモデルのトレーニング例を示します。この仕組みについての詳細は、チュートリアルをご覧ください。

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 3s 6ms/step - loss: 1.5985 - accuracy: 0.5063
Epoch 2/5
2022-12-14 21:55:35.111311: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.5048 - accuracy: 0.9125
Epoch 3/5
2022-12-14 21:55:35.399474: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.3562 - accuracy: 0.9563
Epoch 4/5
2022-12-14 21:55:35.703960: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 6ms/step - loss: 0.2913 - accuracy: 0.9656
Epoch 5/5
2022-12-14 21:55:36.008689: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.2378 - accuracy: 0.9719
2022-12-14 21:55:36.343486: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 3ms/step - loss: 1.5050 - accuracy: 0.6156
Loss 1.5049774646759033, Accuracy 0.6156250238418579
2022-12-14 21:55:36.682374: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

トレーニングをカスタマイズして独自のループを記述する

Keras モデルは機能しても、トレーニングステップまたは外側のトレーニングループに柔軟性と制御がさらに必要な場合は、独自のトレーニングステップやトレーニングループ全体を実装することができます。詳細については、Keras ガイドのfit のカスタマイズをご覧ください。

様々な機能を tf.keras.callbacks.Callback として実装することもできます。

この方法には、前述した多数のメリットがありますが、トレーニングステップだけでなく、外側のループを制御することができます。

標準のトレーニングループには、以下の 3 つのステップがあります。

  1. Python ジェネレータか tf.data.Datasets をイテレーションして例のバッチを作成します。
  2. tf.GradientTape を使用して勾配を集めます。
  3. tf.keras.optimizers の 1 つを使用して、モデルの変数に重み更新を適用します。

覚えておきましょう:

  • サブクラス化されたレイヤーとモデルの call メソッドには、常に training 引数を含めます。
  • training 引数を確実に正しくセットしてモデルを呼び出します。
  • 使用方法によっては、モデルがデータのバッチ上で実行されるまでモデル変数は存在しないかもしれません。
  • モデルの正則化損失などを手動で処理する必要があります。

変数イニシャライザを実行したり、手動制御の依存関係を追加したりする必要はありません。自動制御依存関係と変数の初期化は、作成時に tf.function によって処理されます。

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
2022-12-14 21:55:38.581615: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 0
2022-12-14 21:55:38.829348: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 1
2022-12-14 21:55:39.109153: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 2
2022-12-14 21:55:39.409925: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 3
Finished epoch 4
2022-12-14 21:55:39.650521: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Python 制御フローで tf.function を利用する

tf.function は、データに依存する制御フローを tf.condtf.while_loop といったグラフモード相当のフローに変換する方法を提供しています。

データ依存の制御フローがよく見られる場所に、シーケンスモデルが挙げられます。tf.keras.layers.RNN は RNN セルをラップするため、静的または動的にリカレンスを展開することができます。例として、動的な展開を次のように実装しなおすことができます。

class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super(DynamicRNN, self).__init__(self)
    self.cell = rnn_cell

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
  def call(self, input_data):

    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    timesteps =  tf.shape(input_data)[0]
    batch_size = tf.shape(input_data)[1]
    outputs = tf.TensorArray(tf.float32, timesteps)
    state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
    for i in tf.range(timesteps):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)

my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)

詳細は、tf.function ガイドをご覧ください。

新しいスタイルのメトリクスと損失

メトリクスと損失は、Eager と tf.function で動作するオブジェクトです。

損失オブジェクトは呼び出し可能で、(y_true, y_pred) を引数として期待します。

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815

メトリクスを使用してデータの収集と表示を行う

tf.metrics を使ってデータを集計し、tf.summary を使ってサマリーをログに記録してから、コンテキストマネージャーを使ってライターにリダイレクトすることができます。サマリーはライターに直接送信されるため、コールサイトにstep 値を提供する必要があります。

summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
  tf.summary.scalar('loss', 0.1, step=42)

サマリーとしてデータをログに記録する前にデータを集計するには、tf.metrics を使用します。メトリクスはステートフルです。つまり、値を蓄積し、result メソッド(Mean.result など)が呼び出されたときに累積結果を返します。累積された値は、Model.reset_states を使用すると消去されます。

def train(model, optimizer, dataset, log_freq=10):
  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss.update_state(loss)
    if tf.equal(optimizer.iterations % log_freq, 0):
      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      avg_loss.reset_states()

def test(model, test_x, test_y, step_num):
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  loss = loss_fn(model(test_x, training=False), test_y)
  tf.summary.scalar('loss', loss, step=step_num)

train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')

with train_summary_writer.as_default():
  train(model, optimizer, dataset)

with test_summary_writer.as_default():
  test(model, test_x, test_y, optimizer.iterations)

TensorBoard をサマリーログのディレクトリにポイントし、生成されたサマリーを可視化します。

tensorboard --logdir /tmp/summaries

tf.summary API を使用して、TensorBoard での可視化に使用するサマリーデータを記述します。詳細については、tf.summary ガイドをご覧ください。

# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
2022-12-14 21:55:40.582380: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  0
  loss:     0.129
  accuracy: 0.997
2022-12-14 21:55:40.864762: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  1
  loss:     0.113
  accuracy: 0.997
2022-12-14 21:55:41.171972: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  2
  loss:     0.093
  accuracy: 1.000
2022-12-14 21:55:41.429435: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  3
  loss:     0.090
  accuracy: 0.997
Epoch:  4
  loss:     0.078
  accuracy: 1.000
2022-12-14 21:55:41.668782: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Keras メトリクス名

Keras モデルはメトリクス名の処理を一貫して行います。メトリクスリストで文字列を渡すと、まさにその文字列がメトリクスの name として使用されます。これらの名前は model.fit によって返される履歴オブジェクトと、keras.callbacks に渡されるログに表示されます。これはメトリクスリストで渡した文字列に設定されています。**

model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 2s 6ms/step - loss: 0.0940 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969
2022-12-14 21:55:43.368318: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])

デバッグ

Eager execution を使用してコードをステップごとに実行すると、形状、データ型、および値を検査することができます。tf.functiontf.keras などの特定の API は、パフォーマンスや移植性の目的で、Graph execution を使用するように設計されていますが、デバッグの際は、tf.config.run_functions_eagerly(True) を使って、このコード内で Eager execution を使用することができます。

以下に例を示します。

@tf.function
def f(x):
  if x > 0:
    import pdb
    pdb.set_trace()
    x = x + 1
  return x

tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
  6     @tf.function
  7     def f(x):
  8       if x > 0:
  9         import pdb
 10         pdb.set_trace()
 11  ->     x = x + 1
 12       return x
 13
 14     tf.config.run_functions_eagerly(True)
 15     f(tf.constant(1))
[EOF]

これは Keras モデルや、Eager execution をサポートするほかの API 内でも機能します。

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      import pdb
      pdb.set_trace()
      return input_data // 2


tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
 10         if tf.reduce_mean(input_data) > 0:
 11           return input_data
 12         else:
 13           import pdb
 14           pdb.set_trace()
 15  ->       return input_data // 2
 16
 17
 18     tf.config.run_functions_eagerly(True)
 19     model = CustomModel()
 20     model(tf.constant([-2, -4]))

注意:

オブジェクトに tf.Tensors を保持しないこと

これらのテンソルオブジェクトは、tf.function または Eager のコンテキストで作成される可能性があり、これらのテンソルは異なった振る舞いをします。tf.Tensor は必ず中間値のみに使用してください。

状態を追跡するには、tf.Variable を使用してください。これらはいずれのコンテキストからも常に使用可能です。詳細については、tf.Variable ガイドをご覧ください。

リソースとその他の文献

  • TF2 の使用方法についての詳細は、TF2 のガイドチュートリアルをご覧ください。

  • 前に TF1.x を使用していた場合は、コードを TF2 に移行することを強くお勧めします。詳細は、移行ガイドをご覧ください。