
Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset, and compiles the model using XLA.

Load and normalize the dataset using the Keras API:

import tensorflow as tf

# Check that a GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert len(tf.config.list_physical_devices('GPU')) > 0

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 11s 0us/step

Define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()

Train the model, using the RMSprop optimizer:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT; we do not want to measure compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
196/196 [==============================] - 2s 10ms/step - loss: 2.0271 - accuracy: 0.2556 - val_loss: 1.8257 - val_accuracy: 0.3530
Epoch 1/25
196/196 [==============================] - 2s 9ms/step - loss: 2.1027 - accuracy: 0.2207 - val_loss: 1.8864 - val_accuracy: 0.3285
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.8008 - accuracy: 0.3525 - val_loss: 1.6833 - val_accuracy: 0.4039
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6928 - accuracy: 0.3898 - val_loss: 1.6162 - val_accuracy: 0.4227
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6204 - accuracy: 0.4132 - val_loss: 1.5381 - val_accuracy: 0.4375
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5642 - accuracy: 0.4350 - val_loss: 1.4725 - val_accuracy: 0.4622
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5066 - accuracy: 0.4535 - val_loss: 1.4279 - val_accuracy: 0.4832
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4529 - accuracy: 0.4767 - val_loss: 1.3504 - val_accuracy: 0.5118
Epoch 8/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4064 - accuracy: 0.4942 - val_loss: 1.4257 - val_accuracy: 0.5027
Epoch 9/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3686 - accuracy: 0.5101 - val_loss: 1.3362 - val_accuracy: 0.5258
Epoch 10/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3322 - accuracy: 0.5230 - val_loss: 1.3372 - val_accuracy: 0.5292
Epoch 11/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3061 - accuracy: 0.5333 - val_loss: 1.2191 - val_accuracy: 0.5691
Epoch 12/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2757 - accuracy: 0.5425 - val_loss: 1.2260 - val_accuracy: 0.5709
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2515 - accuracy: 0.5543 - val_loss: 1.1758 - val_accuracy: 0.5888
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2234 - accuracy: 0.5648 - val_loss: 1.1960 - val_accuracy: 0.5795
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2087 - accuracy: 0.5715 - val_loss: 1.1267 - val_accuracy: 0.6048
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1799 - accuracy: 0.5812 - val_loss: 1.1146 - val_accuracy: 0.6117
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1560 - accuracy: 0.5892 - val_loss: 1.1543 - val_accuracy: 0.5996
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1375 - accuracy: 0.5971 - val_loss: 1.0818 - val_accuracy: 0.6250
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1155 - accuracy: 0.6073 - val_loss: 1.0693 - val_accuracy: 0.6275
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1018 - accuracy: 0.6135 - val_loss: 1.0532 - val_accuracy: 0.6338
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0821 - accuracy: 0.6203 - val_loss: 1.0163 - val_accuracy: 0.6450
Epoch 22/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0656 - accuracy: 0.6237 - val_loss: 1.0094 - val_accuracy: 0.6496
Epoch 23/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0470 - accuracy: 0.6311 - val_loss: 0.9986 - val_accuracy: 0.6573
Epoch 24/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0308 - accuracy: 0.6372 - val_loss: 0.9791 - val_accuracy: 0.6579
Epoch 25/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0156 - accuracy: 0.6449 - val_loss: 0.9698 - val_accuracy: 0.6635
313/313 [==============================] - 1s 3ms/step - loss: 0.9698 - accuracy: 0.6635
Test loss: 0.969793438911438
Test accuracy: 0.6635000109672546

Now let's train the model again, using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 3s 13ms/step - loss: 2.0589 - accuracy: 0.2391 - val_loss: 1.8392 - val_accuracy: 0.3439
Epoch 1/25
196/196 [==============================] - 3s 17ms/step - loss: 2.1530 - accuracy: 0.1975 - val_loss: 1.8981 - val_accuracy: 0.3395
Epoch 2/25
196/196 [==============================] - 1s 7ms/step - loss: 1.8210 - accuracy: 0.3430 - val_loss: 1.7153 - val_accuracy: 0.3904
Epoch 3/25
196/196 [==============================] - 1s 7ms/step - loss: 1.6835 - accuracy: 0.3884 - val_loss: 1.5915 - val_accuracy: 0.4194
Epoch 4/25
196/196 [==============================] - 1s 7ms/step - loss: 1.5976 - accuracy: 0.4196 - val_loss: 1.4913 - val_accuracy: 0.4648
Epoch 5/25
196/196 [==============================] - 1s 7ms/step - loss: 1.5304 - accuracy: 0.4448 - val_loss: 1.4355 - val_accuracy: 0.4801
Epoch 6/25
196/196 [==============================] - 1s 7ms/step - loss: 1.4768 - accuracy: 0.4661 - val_loss: 1.4313 - val_accuracy: 0.4836
Epoch 7/25
196/196 [==============================] - 1s 7ms/step - loss: 1.4299 - accuracy: 0.4828 - val_loss: 1.3828 - val_accuracy: 0.5031
Epoch 8/25
196/196 [==============================] - 1s 7ms/step - loss: 1.3932 - accuracy: 0.4985 - val_loss: 1.3157 - val_accuracy: 0.5274
Epoch 9/25
196/196 [==============================] - 1s 7ms/step - loss: 1.3571 - accuracy: 0.5150 - val_loss: 1.3079 - val_accuracy: 0.5299
Epoch 10/25
196/196 [==============================] - 1s 7ms/step - loss: 1.3325 - accuracy: 0.5222 - val_loss: 1.2482 - val_accuracy: 0.5544
Epoch 11/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2994 - accuracy: 0.5359 - val_loss: 1.2617 - val_accuracy: 0.5510
Epoch 12/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2733 - accuracy: 0.5451 - val_loss: 1.2264 - val_accuracy: 0.5624
Epoch 13/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2478 - accuracy: 0.5568 - val_loss: 1.2450 - val_accuracy: 0.5644
Epoch 14/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2223 - accuracy: 0.5649 - val_loss: 1.1971 - val_accuracy: 0.5785
Epoch 15/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2038 - accuracy: 0.5730 - val_loss: 1.1459 - val_accuracy: 0.6017
Epoch 16/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1826 - accuracy: 0.5801 - val_loss: 1.1202 - val_accuracy: 0.6047
Epoch 17/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1602 - accuracy: 0.5886 - val_loss: 1.1146 - val_accuracy: 0.6131
Epoch 18/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1402 - accuracy: 0.5976 - val_loss: 1.0740 - val_accuracy: 0.6285
Epoch 19/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1202 - accuracy: 0.6036 - val_loss: 1.0712 - val_accuracy: 0.6196
Epoch 20/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1067 - accuracy: 0.6093 - val_loss: 1.0343 - val_accuracy: 0.6418
Epoch 21/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0867 - accuracy: 0.6178 - val_loss: 1.0376 - val_accuracy: 0.6425
Epoch 22/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0667 - accuracy: 0.6257 - val_loss: 0.9948 - val_accuracy: 0.6555
Epoch 23/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0488 - accuracy: 0.6300 - val_loss: 1.0192 - val_accuracy: 0.6433
Epoch 24/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0300 - accuracy: 0.6375 - val_loss: 0.9870 - val_accuracy: 0.6576
Epoch 25/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0175 - accuracy: 0.6421 - val_loss: 0.9753 - val_accuracy: 0.6580
CPU times: user 44.2 s, sys: 6.08 s, total: 50.3 s
Wall time: 38.9 s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x.
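
Setting tf.config.optimizer.set_jit(True) enables XLA auto-clustering for the entire session. If you only want to compile selected computations, XLA can also be requested per function by passing jit_compile=True to tf.function (spelled experimental_compile in releases before TF 2.5). Below is a minimal, self-contained sketch of this pattern; the dense_relu function is a hypothetical example, not part of the model above.

import tensorflow as tf

# Hypothetical example: compile a single computation with XLA instead of
# enabling JIT globally via tf.config.optimizer.set_jit(True).
@tf.function(jit_compile=True)
def dense_relu(x, w, b):
  # The whole function body is lowered to one XLA computation.
  return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.random.normal((256, 512))
w = tf.random.normal((512, 10))
b = tf.zeros((10,))
print(dense_relu(x, w, b).shape)  # (256, 10)

Recent Keras releases also accept jit_compile=True directly in model.compile(...), which compiles the generated training step without touching the global flag.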