![]() |
![]() |
![]() |
![]() |
はじめに
教師あり学習をする場合、fit()
を使用すると全てがスムーズに動作します。
独自のトレーニングループを新規で書く必要がある場合には、GradientTape
を使用すると、細部までコントロールすることができます。
しかし、カスタムトレーニングアルゴリズムが必要で、なおかつコールバック、組み込み分散サポート、ステップ結合など、fit()
の便利な機能を利用したい場合には、どうしたらよいでしょうか?
Keras の核となる原則は、複雑性のプログレッシブディスクロージャ―です。常に低レベルのワークフローに段階的に入ることが可能です。高レベルの機能性がユースケースと完全に一致しない場合でも、急激に性能が落ちるようなことはありません。相応の高レベルの利便性を維持しながら細部をよりコントロールできるはずです。
fit()
の動作をカスタマイズする必要がある場合は、Model
クラスのトレーニングステップ関数をオーバーライドする必要があります。これはデータのバッチごとにfit()
に呼び出される関数です。これによって、通常通りのfit()
の呼び出しが可能になり、独自の学習アルゴリズムが実行されます。
このパターンは Functional API を使用したモデル構築を妨げるものではないことに注意してください。これは、Sequential
モデル、Functional API モデル、サブクラス化されたモデルのいずれを構築する場合にも適用可能です。
では、その仕組みを見ていきましょう。
セットアップ
TensorFlow 2.2 以降が必要です。
import tensorflow as tf
from tensorflow import keras
最初の簡単な例
簡単な例から始めてみましょう。
keras.Model
をサブクラス化する新しいクラスを作成します。train_step(self, data)
メソッドをオーバーライドするだけです。- メトリック名(損失を含む)をマッピングするディクショナリを現在の値に返します。
トレーニングデータとして fit() に渡されるのが、入力引数data
です。
fit(x, y, ...)
を呼び出して Numpy 配列を渡すと、data
はタプル(x, y)
になります。fit(dataset, ...)
を呼び出してtf.data.Dataset
を渡すと、data
は各バッチでdataset
により生成されたものになります。
train_step
メソッドの本体には、お馴染みの定期的なトレーニング更新を実装します。重要なのは、損失をself.compiled_loss
で計算するため、compile()
に渡された損失関数をラップしていることです。
同様に、self.compiled_metrics.update_state(y, y_pred)
を呼び出してcompile()
で渡されたメトリクスの状態を更新し、最後にself.metrics
の結果を照会して現在の値を取得します。
class CustomModel(keras.Model):
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value
# (the loss function is configured in `compile()`)
loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update metrics (includes the metric that tracks the loss)
self.compiled_metrics.update_state(y, y_pred)
# Return a dict mapping metric names to current value
return {m.name: m.result() for m in self.metrics}
これを試してみましょう。
import numpy as np
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3 32/32 [==============================] - 1s 2ms/step - loss: 1.2411 - mae: 1.0207 Epoch 2/3 32/32 [==============================] - 0s 2ms/step - loss: 0.4799 - mae: 0.5857 Epoch 3/3 32/32 [==============================] - 0s 2ms/step - loss: 0.2316 - mae: 0.3894 <tensorflow.python.keras.callbacks.History at 0x7f2f540400b8>
低レベルにする
当然ながら、compile()
で損失関数を渡すことを省略し、その代わりにtrain_step
で全てを手動で実行することは可能です。これはメトリクスの場合でも同様です。オプティマイザの構成にcompile()
のみを使用した、低レベルの例を次に示します。
mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
loss_tracker = keras.metrics.Mean(name="loss")
class CustomModel(keras.Model):
def train_step(self, data):
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute our own loss
loss = keras.losses.mean_squared_error(y, y_pred)
# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Compute our own metrics
loss_tracker.update_state(loss)
mae_metric.update_state(y, y_pred)
return {"loss": loss_tracker.result(), "mae": mae_metric.result()}
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
# We don't passs a loss or metrics here.
model.compile(optimizer="adam")
# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=1)
32/32 [==============================] - 0s 2ms/step - loss: 0.2481 - mae: 0.3957 <tensorflow.python.keras.callbacks.History at 0x7f2f481db6a0>
このセットアップでは、各エポックの後に、またはトレーニングと評価の間に、メトリクス上で手動でreset_states()
を呼び出す必要があることに注意してください。
sample_weight
とclass_weight
をサポートする
最初の基本的な例には、サンプルの重み付けに関する言及が一切なかったことにお気づきでしょうか。fit()
引数のsample_weight
とclass_weight
をサポートする場合には、以下のようにします。
data
引数からsample_weight
をアンパックする- それを
compiled_loss
とcompiled_metrics
に渡す(もちろん、 損失とメトリクスがcompile()
に依存しない場合は手動での適用が可能) - これだけで完了です。これがリストです。
class CustomModel(keras.Model):
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
if len(data) == 3:
x, y, sample_weight = data
else:
x, y = data
with tf.GradientTape() as tape:
y_pred = self(x, training=True) # Forward pass
# Compute the loss value.
# The loss function is configured in `compile()`.
loss = self.compiled_loss(
y,
y_pred,
sample_weight=sample_weight,
regularization_losses=self.losses,
)
# Compute gradients
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update the metrics.
# Metrics are configured in `compile()`.
self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
return {m.name: m.result() for m in self.metrics}
# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])
# You can now use sample_weight argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
Epoch 1/3 32/32 [==============================] - 0s 2ms/step - loss: 0.0974 - mae: 0.3536 Epoch 2/3 32/32 [==============================] - 0s 2ms/step - loss: 0.0907 - mae: 0.3428 Epoch 3/3 32/32 [==============================] - 0s 2ms/step - loss: 0.0881 - mae: 0.3375 <tensorflow.python.keras.callbacks.History at 0x7f2f48167da0>
独自の評価ステップを提供する
model.evaluate()
の呼び出しに同じことをする場合はどうしたらよいでしょう?その場合は、全く同じ方法でtest_step
をオーバーライドします。これは次のようになります。
class CustomModel(keras.Model):
def test_step(self, data):
# Unpack the data
x, y = data
# Compute predictions
y_pred = self(x, training=False)
# Updates the metrics tracking the loss
self.compiled_loss(y, y_pred, regularization_losses=self.losses)
# Update the metrics.
self.compiled_metrics.update_state(y, y_pred)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
return {m.name: m.result() for m in self.metrics}
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])
# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
32/32 [==============================] - 0s 1ms/step - loss: 0.8946 - mae: 0.8234 [0.9146538972854614, 0.8362475633621216]
まとめ: エンドツーエンド GAN の例
ここで学んだ全てを活用した、エンドツーエンドの例を見てみましょう。
以下を検討してみましょう。
- 28x28x1 の画像を生成するジェネレーターネットワーク。
- 28x28x1 の画像を 2 つのクラス(「偽物」と「本物」)に分類するディスクリミネーターネットワーク。
- それぞれに 1 つのオプティマイザ。
- ディスクリミネーターをトレーニングする損失関数。
from tensorflow.keras import layers
# Create the discriminator
discriminator = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.GlobalMaxPooling2D(),
layers.Dense(1),
],
name="discriminator",
)
# Create the generator
latent_dim = 128
generator = keras.Sequential(
[
keras.Input(shape=(latent_dim,)),
# We want to generate 128 coefficients to reshape into a 7x7x128 map
layers.Dense(7 * 7 * 128),
layers.LeakyReLU(alpha=0.2),
layers.Reshape((7, 7, 128)),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
],
name="generator",
)
ここにフィーチャーコンプリートの GAN クラスがあります。compile()
をオーバーライドして独自のシグネチャを使用することにより、GAN アルゴリズム全体をtrain_step
の 17 行で実装しています。
class GAN(keras.Model):
def __init__(self, discriminator, generator, latent_dim):
super(GAN, self).__init__()
self.discriminator = discriminator
self.generator = generator
self.latent_dim = latent_dim
def compile(self, d_optimizer, g_optimizer, loss_fn):
super(GAN, self).compile()
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer
self.loss_fn = loss_fn
def train_step(self, real_images):
if isinstance(real_images, tuple):
real_images = real_images[0]
# Sample random points in the latent space
batch_size = tf.shape(real_images)[0]
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# Decode them to fake images
generated_images = self.generator(random_latent_vectors)
# Combine them with real images
combined_images = tf.concat([generated_images, real_images], axis=0)
# Assemble labels discriminating real from fake images
labels = tf.concat(
[tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
)
# Add random noise to the labels - important trick!
labels += 0.05 * tf.random.uniform(tf.shape(labels))
# Train the discriminator
with tf.GradientTape() as tape:
predictions = self.discriminator(combined_images)
d_loss = self.loss_fn(labels, predictions)
grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
self.d_optimizer.apply_gradients(
zip(grads, self.discriminator.trainable_weights)
)
# Sample random points in the latent space
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# Assemble labels that say "all real images"
misleading_labels = tf.zeros((batch_size, 1))
# Train the generator (note that we should *not* update the weights
# of the discriminator)!
with tf.GradientTape() as tape:
predictions = self.discriminator(self.generator(random_latent_vectors))
g_loss = self.loss_fn(misleading_labels, predictions)
grads = tape.gradient(g_loss, self.generator.trainable_weights)
self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
return {"d_loss": d_loss, "g_loss": g_loss}
テストドライブしてみましょう。
# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)
# To limit execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan.fit(dataset.take(100), epochs=1)
100/100 [==============================] - 4s 12ms/step - d_loss: 0.4427 - g_loss: 0.9436 <tensorflow.python.keras.callbacks.History at 0x7f2f6df00518>
ディープラーニングの背景にある考え方は単純です。実装もそうあるべきだと思います。