![]() | ![]() | ![]() |
このチュートリアルでは、TensorFlowモデルをトレーニングしてCIFAR-10データセットを分類し、XLAを使用してコンパイルします。
Keras APIを使用して、データセットをロードして正規化します。
import tensorflow as tf

# Check that a GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
# Use an explicit raise instead of `assert`: asserts are stripped under
# `python -O`, which would silently skip this environment check.
if not tf.test.gpu_device_name():
    raise RuntimeError("No GPU found; this tutorial requires a GPU runtime.")

# Reset any prior Keras state, then start with XLA JIT compilation disabled
# so the first training run gives a non-XLA baseline timing.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False)  # Start with XLA disabled.
def load_data():
    """Load CIFAR-10, scale images to [0, 1), and one-hot encode the labels.

    Returns ((x_train, y_train), (x_test, y_test)).
    """
    train_split, test_split = tf.keras.datasets.cifar10.load_data()

    def prepare(split):
        # Scale uint8 pixels into [0, 1) and expand integer class labels
        # into 10-way one-hot rows.
        images, labels = split
        images = images.astype('float32') / 256
        labels = tf.keras.utils.to_categorical(labels, num_classes=10)
        return (images, labels)

    return (prepare(train_split), prepare(test_split))

(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170500096/170498071 [==============================] - 4s 0us/step
KerasのCIFAR-10の例を応用したモデルを定義します。
def generate_model():
    """Build the convnet from the Keras CIFAR-10 example (uncompiled)."""
    model = tf.keras.models.Sequential()

    # First convolutional stage: two 3x3 convs with 32 filters,
    # 2x2 max-pooling, then 25% dropout.
    model.add(tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]))
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.Conv2D(32, (3, 3)))
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(tf.keras.layers.Dropout(0.25))

    # Second convolutional stage: same shape, 64 filters.
    model.add(tf.keras.layers.Conv2D(64, (3, 3), padding='same'))
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.Conv2D(64, (3, 3)))
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(tf.keras.layers.Dropout(0.25))

    # Classifier head: dense 512 -> 50% dropout -> softmax over 10 classes.
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(512))
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.Dropout(0.5))
    model.add(tf.keras.layers.Dense(10))
    model.add(tf.keras.layers.Activation('softmax'))

    return model

model = generate_model()
RMSpropオプティマイザーを使用してモデルをトレーニングします。
def compile_model(model):
    """Compile *model* for 10-way classification with RMSprop.

    Returns the same model instance, compiled in place.
    """
    # `learning_rate` is the supported keyword; `lr` is a deprecated alias
    # that newer Keras releases warn on (and eventually reject).
    opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
    model.compile(loss='categorical_crossentropy',
                  optimizer=opt,
                  metrics=['accuracy'])
    return model

model = compile_model(model)
def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
    """Fit *model* on the training split, validating on the test split.

    Returns the `History` object from `model.fit` (previously discarded) so
    callers can inspect the learning curves; existing callers that ignore
    the return value are unaffected.
    """
    return model.fit(x_train, y_train,
                     batch_size=256,
                     epochs=epochs,
                     validation_data=(x_test, y_test),
                     shuffle=True)
def warmup(model, x_train, y_train, x_test, y_test):
    """Run one throwaway epoch so JIT compilation time is excluded from timing."""
    # Snapshot the weights, train a single epoch to trigger compilation,
    # then restore the snapshot so the measured run starts from scratch.
    saved_weights = model.get_weights()
    train_model(model, x_train, y_train, x_test, y_test, epochs=1)
    model.set_weights(saved_weights)

warmup(model, x_train, y_train, x_test, y_test)
# Time the full 25-epoch baseline run (XLA disabled) with the IPython %time magic.
%time train_model(model, x_train, y_train, x_test, y_test)
# Evaluate on the held-out test split; `scores` is [loss, accuracy].
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
196/196 [==============================] - 14s 12ms/step - loss: 2.1887 - accuracy: 0.1773 - val_loss: 1.8642 - val_accuracy: 0.3363 Epoch 1/25 196/196 [==============================] - 2s 10ms/step - loss: 2.1349 - accuracy: 0.2016 - val_loss: 1.8978 - val_accuracy: 0.3353 Epoch 2/25 196/196 [==============================] - 2s 9ms/step - loss: 1.8088 - accuracy: 0.3462 - val_loss: 1.7053 - val_accuracy: 0.3935 Epoch 3/25 196/196 [==============================] - 2s 8ms/step - loss: 1.6749 - accuracy: 0.3941 - val_loss: 1.5618 - val_accuracy: 0.4401 Epoch 4/25 196/196 [==============================] - 2s 8ms/step - loss: 1.5902 - accuracy: 0.4240 - val_loss: 1.5003 - val_accuracy: 0.4683 Epoch 5/25 196/196 [==============================] - 2s 8ms/step - loss: 1.5168 - accuracy: 0.4486 - val_loss: 1.4156 - val_accuracy: 0.4967 Epoch 6/25 196/196 [==============================] - 2s 8ms/step - loss: 1.4654 - accuracy: 0.4703 - val_loss: 1.4081 - val_accuracy: 0.4961 Epoch 7/25 196/196 [==============================] - 2s 8ms/step - loss: 1.4231 - accuracy: 0.4869 - val_loss: 1.3556 - val_accuracy: 0.5162 Epoch 8/25 196/196 [==============================] - 2s 9ms/step - loss: 1.3901 - accuracy: 0.5019 - val_loss: 1.3041 - val_accuracy: 0.5368 Epoch 9/25 196/196 [==============================] - 2s 9ms/step - loss: 1.3559 - accuracy: 0.5162 - val_loss: 1.2992 - val_accuracy: 0.5475 Epoch 10/25 196/196 [==============================] - 2s 9ms/step - loss: 1.3259 - accuracy: 0.5255 - val_loss: 1.2536 - val_accuracy: 0.5587 Epoch 11/25 196/196 [==============================] - 2s 9ms/step - loss: 1.2971 - accuracy: 0.5375 - val_loss: 1.2550 - val_accuracy: 0.5607 Epoch 12/25 196/196 [==============================] - 2s 8ms/step - loss: 1.2713 - accuracy: 0.5505 - val_loss: 1.1769 - val_accuracy: 0.5906 Epoch 13/25 196/196 [==============================] - 2s 8ms/step - loss: 1.2476 - accuracy: 0.5582 - val_loss: 1.1955 - val_accuracy: 0.5770 Epoch 14/25 
196/196 [==============================] - 2s 8ms/step - loss: 1.2216 - accuracy: 0.5679 - val_loss: 1.1839 - val_accuracy: 0.5813 Epoch 15/25 196/196 [==============================] - 2s 8ms/step - loss: 1.2023 - accuracy: 0.5745 - val_loss: 1.1746 - val_accuracy: 0.5912 Epoch 16/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1755 - accuracy: 0.5842 - val_loss: 1.1104 - val_accuracy: 0.6097 Epoch 17/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1513 - accuracy: 0.5954 - val_loss: 1.0757 - val_accuracy: 0.6233 Epoch 18/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1392 - accuracy: 0.5998 - val_loss: 1.0859 - val_accuracy: 0.6209 Epoch 19/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1167 - accuracy: 0.6059 - val_loss: 1.0935 - val_accuracy: 0.6183 Epoch 20/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0941 - accuracy: 0.6143 - val_loss: 1.0590 - val_accuracy: 0.6329 Epoch 21/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0814 - accuracy: 0.6208 - val_loss: 1.0499 - val_accuracy: 0.6334 Epoch 22/25 196/196 [==============================] - 2s 9ms/step - loss: 1.0638 - accuracy: 0.6275 - val_loss: 0.9962 - val_accuracy: 0.6580 Epoch 23/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0448 - accuracy: 0.6342 - val_loss: 1.0240 - val_accuracy: 0.6419 Epoch 24/25 196/196 [==============================] - 2s 9ms/step - loss: 1.0301 - accuracy: 0.6402 - val_loss: 0.9885 - val_accuracy: 0.6512 Epoch 25/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0165 - accuracy: 0.6423 - val_loss: 0.9609 - val_accuracy: 0.6659 CPU times: user 55.2 s, sys: 8.16 s, total: 1min 3s Wall time: 42.7 s 313/313 [==============================] - 1s 3ms/step - loss: 0.9609 - accuracy: 0.6659 Test loss: 0.9608545899391174 Test accuracy: 0.6658999919891357
次に、XLAコンパイラを使用してモデルを再度トレーニングしましょう。アプリケーションの途中でコンパイラを有効にするには、Kerasセッションをリセットする必要があります。
# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
# Rebuild, recompile, and reload everything from scratch so the XLA run is
# directly comparable to the baseline above.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()
# One throwaway epoch so XLA compilation time is excluded from the timing below.
warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 6s 14ms/step - loss: 2.1602 - accuracy: 0.1890 - val_loss: 1.8198 - val_accuracy: 0.3543 Epoch 1/25 196/196 [==============================] - 4s 18ms/step - loss: 2.1084 - accuracy: 0.2171 - val_loss: 1.8668 - val_accuracy: 0.3387 Epoch 2/25 196/196 [==============================] - 2s 8ms/step - loss: 1.8050 - accuracy: 0.3503 - val_loss: 1.7050 - val_accuracy: 0.3965 Epoch 3/25 196/196 [==============================] - 2s 8ms/step - loss: 1.6757 - accuracy: 0.3943 - val_loss: 1.5696 - val_accuracy: 0.4384 Epoch 4/25 196/196 [==============================] - 2s 8ms/step - loss: 1.5956 - accuracy: 0.4228 - val_loss: 1.5098 - val_accuracy: 0.4567 Epoch 5/25 196/196 [==============================] - 2s 8ms/step - loss: 1.5310 - accuracy: 0.4456 - val_loss: 1.4722 - val_accuracy: 0.4707 Epoch 6/25 196/196 [==============================] - 2s 8ms/step - loss: 1.4825 - accuracy: 0.4631 - val_loss: 1.5245 - val_accuracy: 0.4628 Epoch 7/25 196/196 [==============================] - 2s 8ms/step - loss: 1.4374 - accuracy: 0.4837 - val_loss: 1.4239 - val_accuracy: 0.4915 Epoch 8/25 196/196 [==============================] - 2s 8ms/step - loss: 1.3900 - accuracy: 0.5026 - val_loss: 1.3184 - val_accuracy: 0.5260 Epoch 9/25 196/196 [==============================] - 2s 8ms/step - loss: 1.3600 - accuracy: 0.5143 - val_loss: 1.2731 - val_accuracy: 0.5496 Epoch 10/25 196/196 [==============================] - 2s 8ms/step - loss: 1.3260 - accuracy: 0.5281 - val_loss: 1.2552 - val_accuracy: 0.5542 Epoch 11/25 196/196 [==============================] - 2s 8ms/step - loss: 1.2938 - accuracy: 0.5387 - val_loss: 1.2242 - val_accuracy: 0.5738 Epoch 12/25 196/196 [==============================] - 2s 9ms/step - loss: 1.2642 - accuracy: 0.5510 - val_loss: 1.2240 - val_accuracy: 0.5596 Epoch 13/25 196/196 [==============================] - 2s 8ms/step - loss: 1.2389 - accuracy: 0.5622 - val_loss: 1.1663 - val_accuracy: 0.5868 Epoch 14/25 
196/196 [==============================] - 2s 8ms/step - loss: 1.2110 - accuracy: 0.5711 - val_loss: 1.1312 - val_accuracy: 0.5983 Epoch 15/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1856 - accuracy: 0.5821 - val_loss: 1.1978 - val_accuracy: 0.5730 Epoch 16/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1619 - accuracy: 0.5890 - val_loss: 1.2709 - val_accuracy: 0.5568 Epoch 17/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1430 - accuracy: 0.5975 - val_loss: 1.0918 - val_accuracy: 0.6181 Epoch 18/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1190 - accuracy: 0.6074 - val_loss: 1.0924 - val_accuracy: 0.6148 Epoch 19/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0970 - accuracy: 0.6130 - val_loss: 1.0485 - val_accuracy: 0.6277 Epoch 20/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0830 - accuracy: 0.6191 - val_loss: 1.0675 - val_accuracy: 0.6241 Epoch 21/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0616 - accuracy: 0.6289 - val_loss: 1.0053 - val_accuracy: 0.6540 Epoch 22/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0483 - accuracy: 0.6331 - val_loss: 0.9849 - val_accuracy: 0.6615 Epoch 23/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0322 - accuracy: 0.6364 - val_loss: 0.9753 - val_accuracy: 0.6617 Epoch 24/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0190 - accuracy: 0.6421 - val_loss: 0.9657 - val_accuracy: 0.6639 Epoch 25/25 196/196 [==============================] - 1s 8ms/step - loss: 0.9991 - accuracy: 0.6517 - val_loss: 0.9733 - val_accuracy: 0.6630 CPU times: user 45.4 s, sys: 6.48 s, total: 51.9 s Wall time: 41.2 s
Titan V GPUとIntel Xeon E5-2690 CPUを搭載したマシンでは、スピードアップは約1.17倍です。