
Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset and compiles it with XLA.

Load and normalize the dataset using the Keras API:

import tensorflow as tf

# Check that a GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.config.list_physical_devices('GPU'), 'No GPU found.'

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  # Scale uint8 pixel values to float32 values in [0, 1).
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert integer class labels to one-hot vectors.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 11s 0us/step
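You can check the preprocessing in `load_data()` on a few values. The NumPy sketch below uses made-up toy arrays with the same layout that `tf.keras.datasets.cifar10.load_data()` returns: uint8 pixels, and integer labels of shape (N, 1). It uses `np.eye(10)[labels]` as a stand-in that should produce the same one-hot matrix as `tf.keras.utils.to_categorical`. This is only an illustration, not how Keras implements the conversion.

```python
import numpy as np

# Hypothetical toy data: three pixel values and three labels shaped (N, 1),
# mimicking the layout of the real CIFAR-10 arrays.
pixels = np.array([[0, 128, 255]], dtype=np.uint8)
labels = np.array([[3], [0], [9]])

# Same normalization as load_data(): cast to float32, then divide by 256.
scaled = pixels.astype('float32') / 256

# Stand-in for to_categorical(labels, num_classes=10):
# row i of the 10x10 identity matrix is the one-hot vector for class i.
one_hot = np.eye(10, dtype='float32')[labels.ravel()]

print(scaled)                  # all values in [0, 1)
print(one_hot.shape)           # (3, 10)
print(one_hot.argmax(axis=1))  # [3 0 9]
```

Because the divisor is 256 rather than 255, the brightest pixel maps to 255/256 = 0.99609375, so inputs stay just below 1.0.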

Define the model, based on the Keras CIFAR-10 example.

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
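As a rough sanity check on this architecture, you can work out the spatial sizes and parameter counts by hand. The pure-Python sketch below makes these assumptions:

- all 3x3 convolutions use stride 1;
- max pooling is 2x2 with Keras's default 'valid' padding;
- only layers with weights are counted, since activation, pooling, dropout and flatten layers have none.

Under those assumptions the total should agree with what `model.summary()` reports, but compare it with your own output. The helper names (`conv_out`, `pool_out`, `conv_params`, `dense_params`) are hypothetical.

```python
def conv_out(size, kernel=3, padding='valid'):
    """Spatial output size of a stride-1 convolution."""
    return size if padding == 'same' else size - kernel + 1

def pool_out(size, pool=2):
    """Output size of a 2x2 max pool with stride 2 and 'valid' padding."""
    return size // pool

def conv_params(in_ch, out_ch, kernel=3):
    # Kernel weights plus one bias per output channel.
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(in_units, out_units):
    # Weight matrix plus one bias per output unit.
    return in_units * out_units + out_units

size = 32                               # CIFAR-10 images are 32x32x3
size = conv_out(size, padding='same')   # 32
size = conv_out(size)                   # 30
size = pool_out(size)                   # 15
size = conv_out(size, padding='same')   # 15
size = conv_out(size)                   # 13
size = pool_out(size)                   # 6
flat = size * size * 64                 # 2304 features into Dense(512)

params = (conv_params(3, 32) + conv_params(32, 32)
          + conv_params(32, 64) + conv_params(64, 64)
          + dense_params(flat, 512) + dense_params(512, 10))
print(flat, params)  # 2304 1250858
```

Almost all of the weights (1,180,160) sit in the first `Dense(512)` layer, because it connects to every one of the 2,304 flattened features.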

Train the model using the RMSprop optimizer.

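def compile_model(model):
  # `lr` and `decay` are legacy RMSprop arguments that newer TensorFlow/Keras
  # releases deprecate or reject. InverseTimeDecay with decay_steps=1 gives the
  # same schedule as decay=1e-6: lr / (1 + 1e-6 * step).
  lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
      initial_learning_rate=0.0001, decay_steps=1, decay_rate=1e-6)
  opt = tf.keras.optimizers.RMSprop(learning_rate=lr_schedule)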
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT with one epoch so that compilation time is not included in
  # the measurement, then restore the initial weights.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
196/196 [==============================] - 5s 12ms/step - loss: 2.1496 - accuracy: 0.1983 - val_loss: 1.8149 - val_accuracy: 0.3625
Epoch 1/25
196/196 [==============================] - 2s 11ms/step - loss: 2.1004 - accuracy: 0.2203 - val_loss: 1.8906 - val_accuracy: 0.3238
Epoch 2/25
196/196 [==============================] - 2s 9ms/step - loss: 1.8013 - accuracy: 0.3475 - val_loss: 1.6799 - val_accuracy: 0.3996
Epoch 3/25
196/196 [==============================] - 2s 9ms/step - loss: 1.6819 - accuracy: 0.3920 - val_loss: 1.5903 - val_accuracy: 0.4248
Epoch 4/25
196/196 [==============================] - 2s 9ms/step - loss: 1.6081 - accuracy: 0.4157 - val_loss: 1.5107 - val_accuracy: 0.4530
Epoch 5/25
196/196 [==============================] - 2s 9ms/step - loss: 1.5503 - accuracy: 0.4380 - val_loss: 1.4521 - val_accuracy: 0.4804
Epoch 6/25
196/196 [==============================] - 2s 9ms/step - loss: 1.4977 - accuracy: 0.4594 - val_loss: 1.4120 - val_accuracy: 0.4896
Epoch 7/25
196/196 [==============================] - 2s 9ms/step - loss: 1.4496 - accuracy: 0.4799 - val_loss: 1.3468 - val_accuracy: 0.5175
Epoch 8/25
196/196 [==============================] - 2s 9ms/step - loss: 1.4125 - accuracy: 0.4945 - val_loss: 1.3819 - val_accuracy: 0.5060
Epoch 9/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3695 - accuracy: 0.5115 - val_loss: 1.3210 - val_accuracy: 0.5385
Epoch 10/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3332 - accuracy: 0.5241 - val_loss: 1.2489 - val_accuracy: 0.5537
Epoch 11/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3050 - accuracy: 0.5356 - val_loss: 1.2619 - val_accuracy: 0.5451
Epoch 12/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2763 - accuracy: 0.5474 - val_loss: 1.2221 - val_accuracy: 0.5702
Epoch 13/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2433 - accuracy: 0.5601 - val_loss: 1.1858 - val_accuracy: 0.5812
Epoch 14/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2230 - accuracy: 0.5686 - val_loss: 1.1337 - val_accuracy: 0.6005
Epoch 15/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1973 - accuracy: 0.5775 - val_loss: 1.1198 - val_accuracy: 0.6048
Epoch 16/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1727 - accuracy: 0.5896 - val_loss: 1.0952 - val_accuracy: 0.6142
Epoch 17/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1454 - accuracy: 0.5988 - val_loss: 1.0855 - val_accuracy: 0.6163
Epoch 18/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1315 - accuracy: 0.6043 - val_loss: 1.0694 - val_accuracy: 0.6237
Epoch 19/25
196/196 [==============================] - 2s 9ms/step - loss: 1.1075 - accuracy: 0.6126 - val_loss: 1.0425 - val_accuracy: 0.6340
Epoch 20/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0865 - accuracy: 0.6180 - val_loss: 1.0308 - val_accuracy: 0.6390
Epoch 21/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0657 - accuracy: 0.6275 - val_loss: 1.0263 - val_accuracy: 0.6370
Epoch 22/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0526 - accuracy: 0.6320 - val_loss: 1.0057 - val_accuracy: 0.6511
Epoch 23/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0388 - accuracy: 0.6359 - val_loss: 1.0015 - val_accuracy: 0.6513
Epoch 24/25
196/196 [==============================] - 2s 10ms/step - loss: 1.0209 - accuracy: 0.6443 - val_loss: 0.9913 - val_accuracy: 0.6491
Epoch 25/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0075 - accuracy: 0.6481 - val_loss: 0.9850 - val_accuracy: 0.6558
CPU times: user 58.9 s, sys: 7.5 s, total: 1min 6s
Wall time: 44.5 s
313/313 [==============================] - 1s 3ms/step - loss: 0.9850 - accuracy: 0.6558
Test loss: 0.9849947094917297
Test accuracy: 0.6557999849319458
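A little arithmetic explains the optimizer settings. The sketch below is a hypothetical, framework-free illustration:

- `legacy_decayed_lr` implements the time-based decay that the legacy Keras optimizers apply when `decay` is set.
- `rmsprop_step` is the textbook non-centered RMSprop update without momentum. TensorFlow's own kernel may place the epsilon term differently.

The step count uses the 50,000 CIFAR-10 training images and `batch_size=256` from `train_model`.

```python
import numpy as np

def legacy_decayed_lr(lr, decay, iterations):
    # Time-based decay: lr_t = lr / (1 + decay * iterations)
    return lr / (1.0 + decay * iterations)

def rmsprop_step(w, grad, mean_sq, lr, rho=0.9, eps=1e-7):
    """One textbook RMSprop update (non-centered, no momentum)."""
    mean_sq = rho * mean_sq + (1.0 - rho) * grad ** 2   # running mean of squared gradients
    w = w - lr * grad / (np.sqrt(mean_sq) + eps)        # step scaled by the gradient's RMS
    return w, mean_sq

# 50,000 images / batch_size 256 -> 196 steps per epoch (the last batch is partial),
# which matches the "196/196" progress bars above.
steps_per_epoch = int(np.ceil(50000 / 256))
total_steps = 25 * steps_per_epoch
print(steps_per_epoch, total_steps)                 # 196 4900
print(legacy_decayed_lr(1e-4, 1e-6, total_steps))   # ~9.95e-05

# Toy usage: minimize f(w) = w**2 (gradient 2w), starting from w = 1.0.
w, ms = np.float64(1.0), np.float64(0.0)
for t in range(100):
    w, ms = rmsprop_step(w, 2 * w, ms, lr=legacy_decayed_lr(1e-2, 1e-6, t))
print(w)  # much closer to 0 than the starting value
```

With `decay=1e-6`, the learning rate falls by only about 0.5% over the 4,900 timed steps, so the decay has little effect in this tutorial.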

Now train the model again, this time with the XLA compiler. To enable the compiler in the middle of the application, you need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 6s 14ms/step - loss: 2.1409 - accuracy: 0.1995 - val_loss: 1.8821 - val_accuracy: 0.3151
Epoch 1/25
196/196 [==============================] - 4s 20ms/step - loss: 2.0976 - accuracy: 0.2184 - val_loss: 1.8789 - val_accuracy: 0.3493
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.8061 - accuracy: 0.3482 - val_loss: 1.6674 - val_accuracy: 0.4048
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6806 - accuracy: 0.3910 - val_loss: 1.6368 - val_accuracy: 0.4117
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6100 - accuracy: 0.4194 - val_loss: 1.5094 - val_accuracy: 0.4560
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5538 - accuracy: 0.4379 - val_loss: 1.5283 - val_accuracy: 0.4387
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5004 - accuracy: 0.4594 - val_loss: 1.4120 - val_accuracy: 0.4909
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4554 - accuracy: 0.4767 - val_loss: 1.3769 - val_accuracy: 0.4973
Epoch 8/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4165 - accuracy: 0.4896 - val_loss: 1.4065 - val_accuracy: 0.4951
Epoch 9/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3810 - accuracy: 0.5045 - val_loss: 1.3334 - val_accuracy: 0.5228
Epoch 10/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3466 - accuracy: 0.5188 - val_loss: 1.2982 - val_accuracy: 0.5375
Epoch 11/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3145 - accuracy: 0.5296 - val_loss: 1.2729 - val_accuracy: 0.5512
Epoch 12/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2859 - accuracy: 0.5451 - val_loss: 1.2059 - val_accuracy: 0.5742
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2639 - accuracy: 0.5504 - val_loss: 1.1961 - val_accuracy: 0.5799
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2353 - accuracy: 0.5621 - val_loss: 1.1558 - val_accuracy: 0.5908
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2114 - accuracy: 0.5694 - val_loss: 1.2587 - val_accuracy: 0.5546
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1874 - accuracy: 0.5799 - val_loss: 1.1096 - val_accuracy: 0.6105
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1668 - accuracy: 0.5879 - val_loss: 1.1292 - val_accuracy: 0.6048
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1434 - accuracy: 0.5977 - val_loss: 1.0774 - val_accuracy: 0.6197
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1258 - accuracy: 0.6040 - val_loss: 1.0642 - val_accuracy: 0.6274
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1079 - accuracy: 0.6107 - val_loss: 1.0274 - val_accuracy: 0.6383
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0856 - accuracy: 0.6206 - val_loss: 1.0485 - val_accuracy: 0.6329
Epoch 22/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0711 - accuracy: 0.6238 - val_loss: 1.0369 - val_accuracy: 0.6358
Epoch 23/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0565 - accuracy: 0.6293 - val_loss: 1.0801 - val_accuracy: 0.6224
Epoch 24/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0353 - accuracy: 0.6358 - val_loss: 1.1216 - val_accuracy: 0.6056
Epoch 25/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0229 - accuracy: 0.6412 - val_loss: 1.0463 - val_accuracy: 0.6339
CPU times: user 49 s, sys: 6.5 s, total: 55.4 s
Wall time: 43.2 s

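On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is about 1.17x. In the run recorded above, wall time went from 44.5 s without XLA to 43.2 s with XLA, roughly 1.03x, so the measured gain varies between machines.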
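Outside a notebook the `%time` magic is not available. The sketch below reproduces the same warm-up-then-measure pattern with `time.perf_counter`, then computes the speedup from the two wall times recorded above. `timed` and `speedup` are hypothetical helper names and are not part of TensorFlow.

```python
import time

def timed(fn, warm_up_fn=None):
    """Run warm_up_fn (untimed) if given, then return (fn's result, fn's wall time in seconds)."""
    if warm_up_fn is not None:
        warm_up_fn()  # absorbs one-off costs such as JIT compilation
    start = time.perf_counter()
    result = fn()
    elapsed = time.perf_counter() - start
    return result, elapsed

def speedup(baseline_seconds, candidate_seconds):
    """Ratio of wall times; a value above 1 means the candidate ran faster."""
    return baseline_seconds / candidate_seconds

# Wall times from the two %time measurements above (without and with XLA).
print(round(speedup(44.5, 43.2), 2))  # 1.03
```

In the notebook you could call it as `_, seconds = timed(lambda: train_model(model, x_train, y_train, x_test, y_test), warm_up_fn=lambda: warmup(model, x_train, y_train, x_test, y_test))`. Unlike `%time`, this measures only wall time, not CPU time.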