
Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset, and compiles the model using XLA.

Load and normalize the dataset using the Keras API:

import tensorflow as tf

# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert(tf.test.gpu_device_name())

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 4s 0us/step

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()

We train the model using the RMSprop optimizer:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT, we do not wish to measure the compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
196/196 [==============================] - 14s 12ms/step - loss: 2.1887 - accuracy: 0.1773 - val_loss: 1.8642 - val_accuracy: 0.3363
Epoch 1/25
196/196 [==============================] - 2s 10ms/step - loss: 2.1349 - accuracy: 0.2016 - val_loss: 1.8978 - val_accuracy: 0.3353
Epoch 2/25
196/196 [==============================] - 2s 9ms/step - loss: 1.8088 - accuracy: 0.3462 - val_loss: 1.7053 - val_accuracy: 0.3935
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6749 - accuracy: 0.3941 - val_loss: 1.5618 - val_accuracy: 0.4401
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5902 - accuracy: 0.4240 - val_loss: 1.5003 - val_accuracy: 0.4683
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5168 - accuracy: 0.4486 - val_loss: 1.4156 - val_accuracy: 0.4967
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4654 - accuracy: 0.4703 - val_loss: 1.4081 - val_accuracy: 0.4961
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4231 - accuracy: 0.4869 - val_loss: 1.3556 - val_accuracy: 0.5162
Epoch 8/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3901 - accuracy: 0.5019 - val_loss: 1.3041 - val_accuracy: 0.5368
Epoch 9/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3559 - accuracy: 0.5162 - val_loss: 1.2992 - val_accuracy: 0.5475
Epoch 10/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3259 - accuracy: 0.5255 - val_loss: 1.2536 - val_accuracy: 0.5587
Epoch 11/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2971 - accuracy: 0.5375 - val_loss: 1.2550 - val_accuracy: 0.5607
Epoch 12/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2713 - accuracy: 0.5505 - val_loss: 1.1769 - val_accuracy: 0.5906
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2476 - accuracy: 0.5582 - val_loss: 1.1955 - val_accuracy: 0.5770
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2216 - accuracy: 0.5679 - val_loss: 1.1839 - val_accuracy: 0.5813
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2023 - accuracy: 0.5745 - val_loss: 1.1746 - val_accuracy: 0.5912
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1755 - accuracy: 0.5842 - val_loss: 1.1104 - val_accuracy: 0.6097
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1513 - accuracy: 0.5954 - val_loss: 1.0757 - val_accuracy: 0.6233
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1392 - accuracy: 0.5998 - val_loss: 1.0859 - val_accuracy: 0.6209
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1167 - accuracy: 0.6059 - val_loss: 1.0935 - val_accuracy: 0.6183
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0941 - accuracy: 0.6143 - val_loss: 1.0590 - val_accuracy: 0.6329
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0814 - accuracy: 0.6208 - val_loss: 1.0499 - val_accuracy: 0.6334
Epoch 22/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0638 - accuracy: 0.6275 - val_loss: 0.9962 - val_accuracy: 0.6580
Epoch 23/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0448 - accuracy: 0.6342 - val_loss: 1.0240 - val_accuracy: 0.6419
Epoch 24/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0301 - accuracy: 0.6402 - val_loss: 0.9885 - val_accuracy: 0.6512
Epoch 25/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0165 - accuracy: 0.6423 - val_loss: 0.9609 - val_accuracy: 0.6659
CPU times: user 55.2 s, sys: 8.16 s, total: 1min 3s
Wall time: 42.7 s
313/313 [==============================] - 1s 3ms/step - loss: 0.9609 - accuracy: 0.6659
Test loss: 0.9608545899391174
Test accuracy: 0.6658999919891357
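As an aside, %time is an IPython magic and only works inside a notebook. In a plain Python script, a rough sketch of the same measurement (reusing the train_model helper defined above, and assuming nothing beyond the standard library) could look like this:

import time

# Hypothetical plain-script alternative to the %time magic:
# measure the wall-clock time of one training run explicitly.
start = time.perf_counter()
train_model(model, x_train, y_train, x_test, y_test)
print('Wall time: %.1f s' % (time.perf_counter() - start))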

Now let's train the model again, this time with the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 6s 14ms/step - loss: 2.1602 - accuracy: 0.1890 - val_loss: 1.8198 - val_accuracy: 0.3543
Epoch 1/25
196/196 [==============================] - 4s 18ms/step - loss: 2.1084 - accuracy: 0.2171 - val_loss: 1.8668 - val_accuracy: 0.3387
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.8050 - accuracy: 0.3503 - val_loss: 1.7050 - val_accuracy: 0.3965
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6757 - accuracy: 0.3943 - val_loss: 1.5696 - val_accuracy: 0.4384
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5956 - accuracy: 0.4228 - val_loss: 1.5098 - val_accuracy: 0.4567
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5310 - accuracy: 0.4456 - val_loss: 1.4722 - val_accuracy: 0.4707
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4825 - accuracy: 0.4631 - val_loss: 1.5245 - val_accuracy: 0.4628
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4374 - accuracy: 0.4837 - val_loss: 1.4239 - val_accuracy: 0.4915
Epoch 8/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3900 - accuracy: 0.5026 - val_loss: 1.3184 - val_accuracy: 0.5260
Epoch 9/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3600 - accuracy: 0.5143 - val_loss: 1.2731 - val_accuracy: 0.5496
Epoch 10/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3260 - accuracy: 0.5281 - val_loss: 1.2552 - val_accuracy: 0.5542
Epoch 11/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2938 - accuracy: 0.5387 - val_loss: 1.2242 - val_accuracy: 0.5738
Epoch 12/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2642 - accuracy: 0.5510 - val_loss: 1.2240 - val_accuracy: 0.5596
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2389 - accuracy: 0.5622 - val_loss: 1.1663 - val_accuracy: 0.5868
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2110 - accuracy: 0.5711 - val_loss: 1.1312 - val_accuracy: 0.5983
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1856 - accuracy: 0.5821 - val_loss: 1.1978 - val_accuracy: 0.5730
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1619 - accuracy: 0.5890 - val_loss: 1.2709 - val_accuracy: 0.5568
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1430 - accuracy: 0.5975 - val_loss: 1.0918 - val_accuracy: 0.6181
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1190 - accuracy: 0.6074 - val_loss: 1.0924 - val_accuracy: 0.6148
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0970 - accuracy: 0.6130 - val_loss: 1.0485 - val_accuracy: 0.6277
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0830 - accuracy: 0.6191 - val_loss: 1.0675 - val_accuracy: 0.6241
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0616 - accuracy: 0.6289 - val_loss: 1.0053 - val_accuracy: 0.6540
Epoch 22/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0483 - accuracy: 0.6331 - val_loss: 0.9849 - val_accuracy: 0.6615
Epoch 23/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0322 - accuracy: 0.6364 - val_loss: 0.9753 - val_accuracy: 0.6617
Epoch 24/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0190 - accuracy: 0.6421 - val_loss: 0.9657 - val_accuracy: 0.6639
Epoch 25/25
196/196 [==============================] - 1s 8ms/step - loss: 0.9991 - accuracy: 0.6517 - val_loss: 0.9733 - val_accuracy: 0.6630
CPU times: user 45.4 s, sys: 6.48 s, total: 51.9 s
Wall time: 41.2 s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x.
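Note that recent TensorFlow releases also allow enabling XLA per function rather than globally. The sketch below is our own illustration, not part of this notebook: it compiles a single forward pass of the model defined above with tf.function(jit_compile=True).

# A minimal sketch (assumes a recent TensorFlow release): request XLA
# for one function only, instead of the global set_jit flag.
@tf.function(jit_compile=True)
def xla_predict(x):
  # Forward pass of the model above, compiled as a single XLA cluster.
  return model(x, training=False)

xla_predict(x_test[:256])  # the first call triggers XLA compilation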