Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset and compiles it using XLA.

Load and normalize the dataset using the Keras API:

import tensorflow as tf

# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.test.gpu_device_name(), "No GPU found"

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 2s 0us/step
170508288/170498071 [==============================] - 2s 0us/step
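`load_data` scales the uint8 pixels by 1/256 and turns each integer label into a 10-element one-hot vector. As an illustration, the NumPy sketch below applies the same two steps to a tiny synthetic batch, so the resulting shapes can be checked without downloading CIFAR-10. The `preprocess` helper and the fake arrays are hypothetical, and indexing an identity matrix with `np.eye` is used here as a stand-in for `tf.keras.utils.to_categorical`, not as its actual implementation.

```python
import numpy as np

def preprocess(x_uint8, labels, num_classes=10):
    """Hypothetical helper mirroring load_data(): scale pixels, one-hot labels."""
    # Divide by 256 as the tutorial does, so pixel values land in [0, 255/256].
    x = x_uint8.astype("float32") / 256
    # CIFAR-10 labels come with shape (N, 1); flatten them before indexing.
    idx = np.asarray(labels).reshape(-1)
    # Row i of the identity matrix is the one-hot vector for class i.
    y = np.eye(num_classes, dtype="float32")[idx]
    return x, y

# Tiny fake batch: two all-white 32x32 RGB images with labels 3 and 9.
fake_x = np.full((2, 32, 32, 3), 255, dtype=np.uint8)
fake_y = np.array([[3], [9]])
x, y = preprocess(fake_x, fake_y)
print(x.shape, x.max())  # (2, 32, 32, 3) 0.99609375
print(y)
```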

Define the model, based on the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
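The sizes flowing through this network follow from simple arithmetic. The plain-Python sketch below tracks the spatial size after each convolution and pooling layer and counts the trainable parameters, assuming stride-1 convolutions and the Keras default pooling stride (equal to `pool_size`). The helper names are illustrative; `model.summary()` remains the authoritative source for these figures.

```python
def conv_out(size, kernel=3, padding="valid"):
    """Spatial size after a stride-1 Conv2D."""
    # 'same' keeps the size; 'valid' shrinks it by kernel - 1.
    return size if padding == "same" else size - kernel + 1

def pool_out(size, pool=2):
    """Spatial size after MaxPooling2D with strides=pool_size, padding='valid'."""
    return (size - pool) // pool + 1

def conv_params(kernel, c_in, c_out):
    """kernel*kernel*c_in weights per filter plus one bias per filter."""
    return kernel * kernel * c_in * c_out + c_out

def dense_params(n_in, n_out):
    """Weight matrix plus bias vector."""
    return n_in * n_out + n_out

size = 32                                # CIFAR-10 images are 32x32x3
size = conv_out(size, padding="same")    # Conv2D(32, 'same') -> 32
size = conv_out(size)                    # Conv2D(32)         -> 30
size = pool_out(size)                    # MaxPooling2D       -> 15
size = conv_out(size, padding="same")    # Conv2D(64, 'same') -> 15
size = conv_out(size)                    # Conv2D(64)         -> 13
size = pool_out(size)                    # MaxPooling2D       -> 6
flat = size * size * 64                  # Flatten            -> 2304

# Activation, Dropout, pooling and Flatten layers have no trainable weights.
params = (conv_params(3, 3, 32) + conv_params(3, 32, 32)
          + conv_params(3, 32, 64) + conv_params(3, 64, 64)
          + dense_params(flat, 512) + dense_params(512, 10))
print(size, flat, params)  # 6 2304 1250858
```

For this model, `Flatten` produces 2,304 values, and the 2,304-to-512 `Dense` layer alone holds about 94% of the roughly 1.25 million parameters.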

Train the model using the RMSprop optimizer:

def compile_model(model):
  # Inverse-time decay, lr / (1 + 1e-6 * step): the same behavior as the
  # deprecated `lr=0.0001, decay=1e-6` optimizer arguments.
  lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
      initial_learning_rate=0.0001, decay_steps=1, decay_rate=1e-6)
  opt = tf.keras.optimizers.RMSprop(learning_rate=lr_schedule)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT so that compilation time is not measured,
  # then restore the initial weights.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
196/196 [==============================] - 4s 11ms/step - loss: 2.0395 - accuracy: 0.2455 - val_loss: 1.8326 - val_accuracy: 0.3586
Epoch 1/25
196/196 [==============================] - 2s 9ms/step - loss: 2.1007 - accuracy: 0.2191 - val_loss: 1.8853 - val_accuracy: 0.3260
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.7937 - accuracy: 0.3516 - val_loss: 1.6568 - val_accuracy: 0.4097
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6670 - accuracy: 0.3938 - val_loss: 1.5464 - val_accuracy: 0.4434
Epoch 4/25
196/196 [==============================] - 1s 8ms/step - loss: 1.5847 - accuracy: 0.4243 - val_loss: 1.5004 - val_accuracy: 0.4614
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5202 - accuracy: 0.4511 - val_loss: 1.4879 - val_accuracy: 0.4710
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4666 - accuracy: 0.4723 - val_loss: 1.3830 - val_accuracy: 0.5013
Epoch 7/25
196/196 [==============================] - 1s 8ms/step - loss: 1.4273 - accuracy: 0.4845 - val_loss: 1.3415 - val_accuracy: 0.5174
Epoch 8/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3935 - accuracy: 0.4995 - val_loss: 1.3505 - val_accuracy: 0.5193
Epoch 9/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3599 - accuracy: 0.5139 - val_loss: 1.2670 - val_accuracy: 0.5488
Epoch 10/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3300 - accuracy: 0.5268 - val_loss: 1.2622 - val_accuracy: 0.5519
Epoch 11/25
196/196 [==============================] - 1s 8ms/step - loss: 1.3020 - accuracy: 0.5390 - val_loss: 1.2196 - val_accuracy: 0.5689
Epoch 12/25
196/196 [==============================] - 1s 8ms/step - loss: 1.2768 - accuracy: 0.5469 - val_loss: 1.1969 - val_accuracy: 0.5762
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2519 - accuracy: 0.5558 - val_loss: 1.2510 - val_accuracy: 0.5621
Epoch 14/25
196/196 [==============================] - 1s 8ms/step - loss: 1.2337 - accuracy: 0.5644 - val_loss: 1.1758 - val_accuracy: 0.5872
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2085 - accuracy: 0.5738 - val_loss: 1.1580 - val_accuracy: 0.5941
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1831 - accuracy: 0.5841 - val_loss: 1.1352 - val_accuracy: 0.6045
Epoch 17/25
196/196 [==============================] - 1s 8ms/step - loss: 1.1627 - accuracy: 0.5897 - val_loss: 1.1194 - val_accuracy: 0.6086
Epoch 18/25
196/196 [==============================] - 1s 8ms/step - loss: 1.1435 - accuracy: 0.5948 - val_loss: 1.1733 - val_accuracy: 0.5908
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1218 - accuracy: 0.6072 - val_loss: 1.0623 - val_accuracy: 0.6298
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1053 - accuracy: 0.6113 - val_loss: 1.0589 - val_accuracy: 0.6335
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0872 - accuracy: 0.6198 - val_loss: 1.0317 - val_accuracy: 0.6380
Epoch 22/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0719 - accuracy: 0.6252 - val_loss: 1.0427 - val_accuracy: 0.6369
Epoch 23/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0527 - accuracy: 0.6308 - val_loss: 0.9874 - val_accuracy: 0.6531
Epoch 24/25
196/196 [==============================] - 1s 8ms/step - loss: 1.0351 - accuracy: 0.6361 - val_loss: 1.0075 - val_accuracy: 0.6492
Epoch 25/25
196/196 [==============================] - 1s 8ms/step - loss: 1.0236 - accuracy: 0.6436 - val_loss: 0.9797 - val_accuracy: 0.6587
CPU times: user 47 s, sys: 5.72 s, total: 52.7 s
Wall time: 39.4 s
313/313 [==============================] - 1s 2ms/step - loss: 0.9797 - accuracy: 0.6587
Test loss: 0.9796982407569885
Test accuracy: 0.6586999893188477
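The progress bars are easier to read with a little arithmetic. With 50,000 training images and `batch_size=256`, each epoch runs ceil(50000/256) = 196 steps (the last batch holds only 80 images), and `model.evaluate` with its default batch size of 32 runs 313 steps over the 10,000 test images. The same step count shows how little the `1e-6` decay moves the learning rate over the 25 timed epochs, assuming the legacy Keras per-step formula lr / (1 + decay * step); the warmup epoch is ignored in this sketch.

```python
import math

TRAIN_IMAGES, TEST_IMAGES = 50_000, 10_000  # CIFAR-10 split sizes
TRAIN_BATCH = 256  # batch_size passed to model.fit in train_model()
EVAL_BATCH = 32    # Keras default batch_size for model.evaluate

# A partial final batch still counts as one step, hence ceil.
train_steps = math.ceil(TRAIN_IMAGES / TRAIN_BATCH)
eval_steps = math.ceil(TEST_IMAGES / EVAL_BATCH)

def inverse_time_decay(lr0, decay, step):
    """Assumed legacy `decay` semantics: lr0 / (1 + decay * step)."""
    return lr0 / (1.0 + decay * step)

total_steps = 25 * train_steps
print(train_steps, eval_steps, total_steps)  # 196 313 4900
print(inverse_time_decay(1e-4, 1e-6, total_steps))
```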

Now let's train the model again, this time with the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 5s 13ms/step - loss: 2.0439 - accuracy: 0.2498 - val_loss: 1.8283 - val_accuracy: 0.3566
Epoch 1/25
196/196 [==============================] - 4s 18ms/step - loss: 2.1271 - accuracy: 0.2144 - val_loss: 1.8623 - val_accuracy: 0.3491
Epoch 2/25
196/196 [==============================] - 1s 7ms/step - loss: 1.8081 - accuracy: 0.3496 - val_loss: 1.6823 - val_accuracy: 0.4058
Epoch 3/25
196/196 [==============================] - 1s 7ms/step - loss: 1.6905 - accuracy: 0.3908 - val_loss: 1.5872 - val_accuracy: 0.4324
Epoch 4/25
196/196 [==============================] - 1s 7ms/step - loss: 1.6168 - accuracy: 0.4183 - val_loss: 1.5310 - val_accuracy: 0.4419
Epoch 5/25
196/196 [==============================] - 1s 7ms/step - loss: 1.5570 - accuracy: 0.4401 - val_loss: 1.4528 - val_accuracy: 0.4819
Epoch 6/25
196/196 [==============================] - 1s 7ms/step - loss: 1.5004 - accuracy: 0.4583 - val_loss: 1.4114 - val_accuracy: 0.4932
Epoch 7/25
196/196 [==============================] - 1s 7ms/step - loss: 1.4591 - accuracy: 0.4765 - val_loss: 1.3647 - val_accuracy: 0.5160
Epoch 8/25
196/196 [==============================] - 1s 7ms/step - loss: 1.4189 - accuracy: 0.4897 - val_loss: 1.3653 - val_accuracy: 0.5151
Epoch 9/25
196/196 [==============================] - 1s 7ms/step - loss: 1.3828 - accuracy: 0.5049 - val_loss: 1.3127 - val_accuracy: 0.5288
Epoch 10/25
196/196 [==============================] - 1s 7ms/step - loss: 1.3481 - accuracy: 0.5168 - val_loss: 1.3534 - val_accuracy: 0.5285
Epoch 11/25
196/196 [==============================] - 1s 7ms/step - loss: 1.3209 - accuracy: 0.5288 - val_loss: 1.2366 - val_accuracy: 0.5606
Epoch 12/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2934 - accuracy: 0.5397 - val_loss: 1.2379 - val_accuracy: 0.5622
Epoch 13/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2630 - accuracy: 0.5498 - val_loss: 1.2640 - val_accuracy: 0.5523
Epoch 14/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2403 - accuracy: 0.5584 - val_loss: 1.2333 - val_accuracy: 0.5618
Epoch 15/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2169 - accuracy: 0.5699 - val_loss: 1.1787 - val_accuracy: 0.5851
Epoch 16/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1944 - accuracy: 0.5809 - val_loss: 1.1339 - val_accuracy: 0.5962
Epoch 17/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1746 - accuracy: 0.5850 - val_loss: 1.1283 - val_accuracy: 0.6029
Epoch 18/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1515 - accuracy: 0.5963 - val_loss: 1.1050 - val_accuracy: 0.6090
Epoch 19/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1324 - accuracy: 0.6014 - val_loss: 1.0778 - val_accuracy: 0.6210
Epoch 20/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1116 - accuracy: 0.6091 - val_loss: 1.1027 - val_accuracy: 0.6124
Epoch 21/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0944 - accuracy: 0.6158 - val_loss: 1.0454 - val_accuracy: 0.6356
Epoch 22/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0744 - accuracy: 0.6225 - val_loss: 1.0302 - val_accuracy: 0.6387
Epoch 23/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0569 - accuracy: 0.6280 - val_loss: 1.0352 - val_accuracy: 0.6383
Epoch 24/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0431 - accuracy: 0.6328 - val_loss: 0.9780 - val_accuracy: 0.6603
Epoch 25/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0278 - accuracy: 0.6406 - val_loss: 0.9785 - val_accuracy: 0.6616
CPU times: user 39.4 s, sys: 5.3 s, total: 44.7 s
Wall time: 39.9 s

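On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, XLA speeds up training by about 1.17x. The gain depends on the hardware, so the timings in the run above will not necessarily match that figure.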
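The speedup can be quantified directly from the `%time` output of the two training runs above. The sketch below simply divides the baseline time by the XLA time, for both total CPU time and wall time; the numbers are copied from the single run shown on this page and do not constitute a benchmark.

```python
# Times copied from the two `%time` outputs above (one run each).
baseline = {"cpu_total_s": 52.7, "wall_s": 39.4}  # XLA disabled
with_xla = {"cpu_total_s": 44.7, "wall_s": 39.9}  # XLA enabled

def speedup(before, after):
    """Ratio of old to new time; values above 1 mean the XLA run was faster."""
    return before / after

ratios = {key: speedup(baseline[key], with_xla[key]) for key in baseline}
for key, ratio in ratios.items():
    print(f"{key}: {ratio:.2f}x")
```

For this run, total CPU time improves by roughly 1.18x while wall time is essentially unchanged, which is why the 1.17x figure above should be read as specific to that hardware.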