
Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset, and compiles the model using XLA.
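
This notebook enables XLA globally via tf.config.optimizer.set_jit, as shown below. As an aside not in the original notebook: in recent TensorFlow versions (2.5+), XLA can also be applied to a single function through the jit_compile argument of tf.function. A minimal sketch:

import tensorflow as tf

# A minimal sketch (assumption: TF 2.5+, where tf.function accepts jit_compile).
# XLA compiles just this function, leaving the rest of the program unchanged.
@tf.function(jit_compile=True)
def scaled_sum(x):
  return tf.reduce_sum(x * 2.0)

print(scaled_sum(tf.ones((4, 4))))  # tf.Tensor(32.0, shape=(), dtype=float32)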

Load and normalize the dataset using the Keras API:

import tensorflow as tf

# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.test.gpu_device_name()

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 4s 0us/step
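
As a quick sanity check (not part of the original notebook), the loaded arrays should have the standard CIFAR-10 shapes: 50,000 training and 10,000 test images of 32×32×3, with one-hot labels of width 10:

print(x_train.shape, y_train.shape)  # Expected: (50000, 32, 32, 3) (50000, 10)
print(x_test.shape, y_test.shape)    # Expected: (10000, 32, 32, 3) (10000, 10)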

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
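
Optionally (this step is not in the original notebook), inspect the architecture before training:

model.summary()  # Prints each layer's output shape and trainable parameter count.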

We train the model using the RMSprop optimizer:

def compile_model(model):
  # Note: `learning_rate` replaces the deprecated `lr` argument.
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT; we do not want to measure the compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
196/196 [==============================] - 14s 12ms/step - loss: 2.1887 - accuracy: 0.1773 - val_loss: 1.8642 - val_accuracy: 0.3363
Epoch 1/25
196/196 [==============================] - 2s 10ms/step - loss: 2.1349 - accuracy: 0.2016 - val_loss: 1.8978 - val_accuracy: 0.3353
Epoch 2/25
196/196 [==============================] - 2s 9ms/step - loss: 1.8088 - accuracy: 0.3462 - val_loss: 1.7053 - val_accuracy: 0.3935
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6749 - accuracy: 0.3941 - val_loss: 1.5618 - val_accuracy: 0.4401
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5902 - accuracy: 0.4240 - val_loss: 1.5003 - val_accuracy: 0.4683
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5168 - accuracy: 0.4486 - val_loss: 1.4156 - val_accuracy: 0.4967
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4654 - accuracy: 0.4703 - val_loss: 1.4081 - val_accuracy: 0.4961
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4231 - accuracy: 0.4869 - val_loss: 1.3556 - val_accuracy: 0.5162
Epoch 8/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3901 - accuracy: 0.5019 - val_loss: 1.3041 - val_accuracy: 0.5368
Epoch 9/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3559 - accuracy: 0.5162 - val_loss: 1.2992 - val_accuracy: 0.5475
Epoch 10/25
196/196 [==============================] - 2s 9ms/step - loss: 1.3259 - accuracy: 0.5255 - val_loss: 1.2536 - val_accuracy: 0.5587
Epoch 11/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2971 - accuracy: 0.5375 - val_loss: 1.2550 - val_accuracy: 0.5607
Epoch 12/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2713 - accuracy: 0.5505 - val_loss: 1.1769 - val_accuracy: 0.5906
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2476 - accuracy: 0.5582 - val_loss: 1.1955 - val_accuracy: 0.5770
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2216 - accuracy: 0.5679 - val_loss: 1.1839 - val_accuracy: 0.5813
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2023 - accuracy: 0.5745 - val_loss: 1.1746 - val_accuracy: 0.5912
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1755 - accuracy: 0.5842 - val_loss: 1.1104 - val_accuracy: 0.6097
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1513 - accuracy: 0.5954 - val_loss: 1.0757 - val_accuracy: 0.6233
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1392 - accuracy: 0.5998 - val_loss: 1.0859 - val_accuracy: 0.6209
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1167 - accuracy: 0.6059 - val_loss: 1.0935 - val_accuracy: 0.6183
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0941 - accuracy: 0.6143 - val_loss: 1.0590 - val_accuracy: 0.6329
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0814 - accuracy: 0.6208 - val_loss: 1.0499 - val_accuracy: 0.6334
Epoch 22/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0638 - accuracy: 0.6275 - val_loss: 0.9962 - val_accuracy: 0.6580
Epoch 23/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0448 - accuracy: 0.6342 - val_loss: 1.0240 - val_accuracy: 0.6419
Epoch 24/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0301 - accuracy: 0.6402 - val_loss: 0.9885 - val_accuracy: 0.6512
Epoch 25/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0165 - accuracy: 0.6423 - val_loss: 0.9609 - val_accuracy: 0.6659
CPU times: user 55.2 s, sys: 8.16 s, total: 1min 3s
Wall time: 42.7 s
313/313 [==============================] - 1s 3ms/step - loss: 0.9609 - accuracy: 0.6659
Test loss: 0.9608545899391174
Test accuracy: 0.6658999919891357
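
Note the wall time reported above (42.7 s without XLA); we will compare it against the XLA-enabled run below. Outside a notebook, where the %time magic is unavailable, wall time can be measured directly; a minimal sketch:

import time

start = time.perf_counter()
train_model(model, x_train, y_train, x_test, y_test)
elapsed = time.perf_counter() - start  # Wall-clock seconds for the 25 epochs.
print(f'Training took {elapsed:.1f} s')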

Now let's train the model again, using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 6s 14ms/step - loss: 2.1602 - accuracy: 0.1890 - val_loss: 1.8198 - val_accuracy: 0.3543
Epoch 1/25
196/196 [==============================] - 4s 18ms/step - loss: 2.1084 - accuracy: 0.2171 - val_loss: 1.8668 - val_accuracy: 0.3387
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.8050 - accuracy: 0.3503 - val_loss: 1.7050 - val_accuracy: 0.3965
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6757 - accuracy: 0.3943 - val_loss: 1.5696 - val_accuracy: 0.4384
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5956 - accuracy: 0.4228 - val_loss: 1.5098 - val_accuracy: 0.4567
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5310 - accuracy: 0.4456 - val_loss: 1.4722 - val_accuracy: 0.4707
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4825 - accuracy: 0.4631 - val_loss: 1.5245 - val_accuracy: 0.4628
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4374 - accuracy: 0.4837 - val_loss: 1.4239 - val_accuracy: 0.4915
Epoch 8/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3900 - accuracy: 0.5026 - val_loss: 1.3184 - val_accuracy: 0.5260
Epoch 9/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3600 - accuracy: 0.5143 - val_loss: 1.2731 - val_accuracy: 0.5496
Epoch 10/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3260 - accuracy: 0.5281 - val_loss: 1.2552 - val_accuracy: 0.5542
Epoch 11/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2938 - accuracy: 0.5387 - val_loss: 1.2242 - val_accuracy: 0.5738
Epoch 12/25
196/196 [==============================] - 2s 9ms/step - loss: 1.2642 - accuracy: 0.5510 - val_loss: 1.2240 - val_accuracy: 0.5596
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2389 - accuracy: 0.5622 - val_loss: 1.1663 - val_accuracy: 0.5868
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2110 - accuracy: 0.5711 - val_loss: 1.1312 - val_accuracy: 0.5983
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1856 - accuracy: 0.5821 - val_loss: 1.1978 - val_accuracy: 0.5730
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1619 - accuracy: 0.5890 - val_loss: 1.2709 - val_accuracy: 0.5568
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1430 - accuracy: 0.5975 - val_loss: 1.0918 - val_accuracy: 0.6181
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1190 - accuracy: 0.6074 - val_loss: 1.0924 - val_accuracy: 0.6148
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0970 - accuracy: 0.6130 - val_loss: 1.0485 - val_accuracy: 0.6277
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0830 - accuracy: 0.6191 - val_loss: 1.0675 - val_accuracy: 0.6241
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0616 - accuracy: 0.6289 - val_loss: 1.0053 - val_accuracy: 0.6540
Epoch 22/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0483 - accuracy: 0.6331 - val_loss: 0.9849 - val_accuracy: 0.6615
Epoch 23/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0322 - accuracy: 0.6364 - val_loss: 0.9753 - val_accuracy: 0.6617
Epoch 24/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0190 - accuracy: 0.6421 - val_loss: 0.9657 - val_accuracy: 0.6639
Epoch 25/25
196/196 [==============================] - 1s 8ms/step - loss: 0.9991 - accuracy: 0.6517 - val_loss: 0.9733 - val_accuracy: 0.6630
CPU times: user 45.4 s, sys: 6.48 s, total: 51.9 s
Wall time: 41.2 s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x.
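
As a closing aside (an addition to the original notebook), XLA auto-clustering can also be enabled without touching the model code by setting an environment variable before TensorFlow initializes:

import os

# Must be set before TensorFlow is imported/initialized for it to take effect.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2'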