Classifying CIFAR-10 with XLA

This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset, then trains it again with XLA compilation enabled.

Load and normalize the dataset using the Keras API:

import tensorflow as tf

# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert(tf.test.gpu_device_name())

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 12s 0us/step
170508288/170498071 [==============================] - 12s 0us/step
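
To make the two preprocessing steps concrete, here is a minimal sketch in plain Python (no TensorFlow required). The helper names `scale_pixel` and `one_hot` are hypothetical; they only mimic, for a single value, what the division by 256 and `to_categorical` do in `load_data` above.

```python
def scale_pixel(value, divisor=256):
    """Map an 8-bit pixel value (0..255) to a float in [0, 1)."""
    return value / divisor


def one_hot(label, num_classes=10):
    """Return a list with 1.0 at index `label` and 0.0 everywhere else."""
    if not 0 <= label < num_classes:
        raise ValueError("label must be in range(num_classes)")
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


print(scale_pixel(0), scale_pixel(255))  # smallest and largest scaled pixel
print(one_hot(3))                        # class 3 as a 10-element vector
```

Because the divisor is 256 rather than 255, the brightest pixel maps to 255/256, slightly below 1.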

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
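
As a rough check on the architecture, the sketch below traces the spatial size of a feature map through the two convolutional blocks of `generate_model()`. It assumes 32x32 RGB inputs (the CIFAR-10 image size), stride-1 convolutions, and pooling with stride equal to the pool size, which are the Keras defaults for the layers as written. The helper names are hypothetical.

```python
def conv_out(size, kernel, padding):
    """Spatial output size of a stride-1 convolution."""
    if padding == "same":
        return size
    return size - kernel + 1  # 'valid' padding, the Keras default


def pool_out(size, pool):
    """Spatial output size of max pooling with stride equal to the pool size."""
    return size // pool


def flattened_features(size=32):
    """Follow one block per filter count: conv(same), conv(valid), max-pool."""
    channels = 3
    for filters in (32, 64):
        size = conv_out(size, 3, "same")
        size = conv_out(size, 3, "valid")
        size = pool_out(size, 2)
        channels = filters
    return size, channels, size * size * channels


print(flattened_features())  # (6, 64, 2304)
```

Under these assumptions the `Flatten` layer hands 2304 features to the `Dense(512)` layer.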

We train the model using the RMSprop optimizer:

def compile_model(model):
  # Note: `lr` is a deprecated alias of `learning_rate` (see the UserWarning in the output below).
  opt = tf.keras.optimizers.RMSprop(lr=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT, we do not wish to measure the compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
196/196 [==============================] - 12s 11ms/step - loss: 2.0710 - accuracy: 0.2378 - val_loss: 1.8439 - val_accuracy: 0.3554
Epoch 1/25
196/196 [==============================] - 2s 10ms/step - loss: 2.1360 - accuracy: 0.2055 - val_loss: 1.9299 - val_accuracy: 0.3314
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.8248 - accuracy: 0.3405 - val_loss: 1.6969 - val_accuracy: 0.3973
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6864 - accuracy: 0.3944 - val_loss: 1.5874 - val_accuracy: 0.4326
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5985 - accuracy: 0.4238 - val_loss: 1.5332 - val_accuracy: 0.4401
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5332 - accuracy: 0.4453 - val_loss: 1.5122 - val_accuracy: 0.4598
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4860 - accuracy: 0.4627 - val_loss: 1.4261 - val_accuracy: 0.4880
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4434 - accuracy: 0.4790 - val_loss: 1.3658 - val_accuracy: 0.5058
Epoch 8/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4049 - accuracy: 0.4976 - val_loss: 1.3883 - val_accuracy: 0.5022
Epoch 9/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3727 - accuracy: 0.5110 - val_loss: 1.3145 - val_accuracy: 0.5329
Epoch 10/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3462 - accuracy: 0.5208 - val_loss: 1.2622 - val_accuracy: 0.5503
Epoch 11/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3132 - accuracy: 0.5323 - val_loss: 1.2740 - val_accuracy: 0.5528
Epoch 12/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2875 - accuracy: 0.5431 - val_loss: 1.2296 - val_accuracy: 0.5677
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2601 - accuracy: 0.5525 - val_loss: 1.3068 - val_accuracy: 0.5353
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2376 - accuracy: 0.5606 - val_loss: 1.1662 - val_accuracy: 0.5904
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2112 - accuracy: 0.5708 - val_loss: 1.1504 - val_accuracy: 0.5939
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1923 - accuracy: 0.5789 - val_loss: 1.1133 - val_accuracy: 0.6125
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1693 - accuracy: 0.5886 - val_loss: 1.1189 - val_accuracy: 0.6088
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1498 - accuracy: 0.5938 - val_loss: 1.1080 - val_accuracy: 0.6142
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1291 - accuracy: 0.6031 - val_loss: 1.0749 - val_accuracy: 0.6290
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1097 - accuracy: 0.6111 - val_loss: 1.0363 - val_accuracy: 0.6447
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0928 - accuracy: 0.6161 - val_loss: 1.0340 - val_accuracy: 0.6387
Epoch 22/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0738 - accuracy: 0.6238 - val_loss: 1.0650 - val_accuracy: 0.6244
Epoch 23/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0571 - accuracy: 0.6286 - val_loss: 0.9993 - val_accuracy: 0.6535
Epoch 24/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0470 - accuracy: 0.6325 - val_loss: 0.9925 - val_accuracy: 0.6511
Epoch 25/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0274 - accuracy: 0.6405 - val_loss: 1.0276 - val_accuracy: 0.6387
CPU times: user 49.6 s, sys: 7.41 s, total: 57.1 s
Wall time: 40.9 s
313/313 [==============================] - 1s 2ms/step - loss: 1.0276 - accuracy: 0.6387
Test loss: 1.0276497602462769
Test accuracy: 0.638700008392334
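
The `196/196` and `313/313` counters in the logs above are batch counts. The following sketch reproduces them, assuming the standard CIFAR-10 split of 50,000 training and 10,000 test images and the Keras default of `batch_size=32` for `model.evaluate`. The helper name `steps_per_epoch` is hypothetical.

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Number of batches needed to cover the data; the last one may be partial."""
    return math.ceil(num_samples / batch_size)


print(steps_per_epoch(50_000, 256))  # training batches per epoch: 196
print(steps_per_epoch(10_000, 32))   # evaluation batches: 313
```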

Now let's train the model again, this time using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 6s 14ms/step - loss: 2.0028 - accuracy: 0.2637 - val_loss: 1.7774 - val_accuracy: 0.3784
Epoch 1/25
196/196 [==============================] - 4s 19ms/step - loss: 2.1142 - accuracy: 0.2136 - val_loss: 1.8591 - val_accuracy: 0.3303
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.7614 - accuracy: 0.3622 - val_loss: 1.6321 - val_accuracy: 0.4149
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6471 - accuracy: 0.3991 - val_loss: 1.5480 - val_accuracy: 0.4379
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5706 - accuracy: 0.4303 - val_loss: 1.4671 - val_accuracy: 0.4691
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5098 - accuracy: 0.4514 - val_loss: 1.4270 - val_accuracy: 0.4858
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4561 - accuracy: 0.4734 - val_loss: 1.4270 - val_accuracy: 0.4919
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4199 - accuracy: 0.4906 - val_loss: 1.3290 - val_accuracy: 0.5272
Epoch 8/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3831 - accuracy: 0.5049 - val_loss: 1.3550 - val_accuracy: 0.5205
Epoch 9/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3487 - accuracy: 0.5176 - val_loss: 1.3339 - val_accuracy: 0.5240
Epoch 10/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3190 - accuracy: 0.5275 - val_loss: 1.2579 - val_accuracy: 0.5528
Epoch 11/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2930 - accuracy: 0.5415 - val_loss: 1.2364 - val_accuracy: 0.5654
Epoch 12/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2656 - accuracy: 0.5499 - val_loss: 1.2646 - val_accuracy: 0.5561
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2416 - accuracy: 0.5614 - val_loss: 1.2615 - val_accuracy: 0.5540
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2196 - accuracy: 0.5681 - val_loss: 1.1970 - val_accuracy: 0.5797
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1996 - accuracy: 0.5751 - val_loss: 1.1142 - val_accuracy: 0.6099
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1777 - accuracy: 0.5849 - val_loss: 1.2225 - val_accuracy: 0.5653
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1542 - accuracy: 0.5912 - val_loss: 1.0860 - val_accuracy: 0.6187
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1358 - accuracy: 0.6009 - val_loss: 1.0767 - val_accuracy: 0.6180
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1197 - accuracy: 0.6062 - val_loss: 1.0517 - val_accuracy: 0.6318
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0980 - accuracy: 0.6139 - val_loss: 1.0362 - val_accuracy: 0.6390
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0738 - accuracy: 0.6233 - val_loss: 1.0777 - val_accuracy: 0.6256
Epoch 22/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0555 - accuracy: 0.6285 - val_loss: 1.0615 - val_accuracy: 0.6353
Epoch 23/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0434 - accuracy: 0.6331 - val_loss: 1.0025 - val_accuracy: 0.6498
Epoch 24/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0285 - accuracy: 0.6386 - val_loss: 0.9670 - val_accuracy: 0.6614
Epoch 25/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0116 - accuracy: 0.6443 - val_loss: 0.9806 - val_accuracy: 0.6576
CPU times: user 43.8 s, sys: 6.15 s, total: 49.9 s
Wall time: 42.6 s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is about 1.17x.
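
A speedup figure of this kind is the baseline wall time divided by the XLA wall time. The sketch below applies that formula to the wall times recorded in the logs above (40.9 s without XLA, 42.6 s with XLA). The helper name `speedup` is hypothetical.

```python
def speedup(baseline_seconds, optimized_seconds):
    """Return baseline / optimized; a value above 1 means the optimized run was faster."""
    if optimized_seconds <= 0:
        raise ValueError("optimized_seconds must be positive")
    return baseline_seconds / optimized_seconds


# Wall times taken from the two `%time` outputs shown in this tutorial.
print(round(speedup(40.9, 42.6), 2))  # 0.96
```

For that particular recorded run the ratio is about 0.96, so it does not show the gain quoted for the Titan V machine. Results depend on the hardware and environment.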