
Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset, and we compile it using XLA.

Load and normalize the dataset using the Keras API:

import tensorflow as tf

# Check that a GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert(tf.config.list_physical_devices('GPU'))

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 11s 0us/step
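The `to_categorical` call in `load_data` converts the integer class labels into one-hot rows, which is what `categorical_crossentropy` expects. A minimal check of that conversion on a toy label array (not part of the tutorial itself):

```python
import numpy as np
import tensorflow as tf

# to_categorical turns integer class labels into one-hot rows:
# each row has a single 1.0 at the index of its class.
labels = np.array([0, 2, 9])
one_hot = tf.keras.utils.to_categorical(labels, num_classes=10)
print(one_hot.shape)  # (3, 10)
print(one_hot[1])     # 1.0 at index 2, zeros elsewhere
```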

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()

We train the model using the RMSprop optimizer:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT, we do not wish to measure the compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
196/196 [==============================] - 2s 10ms/step - loss: 2.0271 - accuracy: 0.2556 - val_loss: 1.8257 - val_accuracy: 0.3530
Epoch 1/25
196/196 [==============================] - 2s 9ms/step - loss: 2.1027 - accuracy: 0.2207 - val_loss: 1.8864 - val_accuracy: 0.3285
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.8008 - accuracy: 0.3525 - val_loss: 1.6833 - val_accuracy: 0.4039
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6928 - accuracy: 0.3898 - val_loss: 1.6162 - val_accuracy: 0.4227
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6204 - accuracy: 0.4132 - val_loss: 1.5381 - val_accuracy: 0.4375
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5642 - accuracy: 0.4350 - val_loss: 1.4725 - val_accuracy: 0.4622
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5066 - accuracy: 0.4535 - val_loss: 1.4279 - val_accuracy: 0.4832
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4529 - accuracy: 0.4767 - val_loss: 1.3504 - val_accuracy: 0.5118
Epoch 8/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4064 - accuracy: 0.4942 - val_loss: 1.4257 - val_accuracy: 0.5027
Epoch 9/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3686 - accuracy: 0.5101 - val_loss: 1.3362 - val_accuracy: 0.5258
Epoch 10/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3322 - accuracy: 0.5230 - val_loss: 1.3372 - val_accuracy: 0.5292
Epoch 11/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3061 - accuracy: 0.5333 - val_loss: 1.2191 - val_accuracy: 0.5691
Epoch 12/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2757 - accuracy: 0.5425 - val_loss: 1.2260 - val_accuracy: 0.5709
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2515 - accuracy: 0.5543 - val_loss: 1.1758 - val_accuracy: 0.5888
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2234 - accuracy: 0.5648 - val_loss: 1.1960 - val_accuracy: 0.5795
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2087 - accuracy: 0.5715 - val_loss: 1.1267 - val_accuracy: 0.6048
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1799 - accuracy: 0.5812 - val_loss: 1.1146 - val_accuracy: 0.6117
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1560 - accuracy: 0.5892 - val_loss: 1.1543 - val_accuracy: 0.5996
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1375 - accuracy: 0.5971 - val_loss: 1.0818 - val_accuracy: 0.6250
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1155 - accuracy: 0.6073 - val_loss: 1.0693 - val_accuracy: 0.6275
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1018 - accuracy: 0.6135 - val_loss: 1.0532 - val_accuracy: 0.6338
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0821 - accuracy: 0.6203 - val_loss: 1.0163 - val_accuracy: 0.6450
Epoch 22/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0656 - accuracy: 0.6237 - val_loss: 1.0094 - val_accuracy: 0.6496
Epoch 23/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0470 - accuracy: 0.6311 - val_loss: 0.9986 - val_accuracy: 0.6573
Epoch 24/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0308 - accuracy: 0.6372 - val_loss: 0.9791 - val_accuracy: 0.6579
Epoch 25/25
196/196 [==============================] - 2s 9ms/step - loss: 1.0156 - accuracy: 0.6449 - val_loss: 0.9698 - val_accuracy: 0.6635
313/313 [==============================] - 1s 3ms/step - loss: 0.9698 - accuracy: 0.6635
Test loss: 0.969793438911438
Test accuracy: 0.6635000109672546

Now let's train the model again, this time using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 3s 13ms/step - loss: 2.0589 - accuracy: 0.2391 - val_loss: 1.8392 - val_accuracy: 0.3439
Epoch 1/25
196/196 [==============================] - 3s 17ms/step - loss: 2.1530 - accuracy: 0.1975 - val_loss: 1.8981 - val_accuracy: 0.3395
Epoch 2/25
196/196 [==============================] - 1s 7ms/step - loss: 1.8210 - accuracy: 0.3430 - val_loss: 1.7153 - val_accuracy: 0.3904
Epoch 3/25
196/196 [==============================] - 1s 7ms/step - loss: 1.6835 - accuracy: 0.3884 - val_loss: 1.5915 - val_accuracy: 0.4194
Epoch 4/25
196/196 [==============================] - 1s 7ms/step - loss: 1.5976 - accuracy: 0.4196 - val_loss: 1.4913 - val_accuracy: 0.4648
Epoch 5/25
196/196 [==============================] - 1s 7ms/step - loss: 1.5304 - accuracy: 0.4448 - val_loss: 1.4355 - val_accuracy: 0.4801
Epoch 6/25
196/196 [==============================] - 1s 7ms/step - loss: 1.4768 - accuracy: 0.4661 - val_loss: 1.4313 - val_accuracy: 0.4836
Epoch 7/25
196/196 [==============================] - 1s 7ms/step - loss: 1.4299 - accuracy: 0.4828 - val_loss: 1.3828 - val_accuracy: 0.5031
Epoch 8/25
196/196 [==============================] - 1s 7ms/step - loss: 1.3932 - accuracy: 0.4985 - val_loss: 1.3157 - val_accuracy: 0.5274
Epoch 9/25
196/196 [==============================] - 1s 7ms/step - loss: 1.3571 - accuracy: 0.5150 - val_loss: 1.3079 - val_accuracy: 0.5299
Epoch 10/25
196/196 [==============================] - 1s 7ms/step - loss: 1.3325 - accuracy: 0.5222 - val_loss: 1.2482 - val_accuracy: 0.5544
Epoch 11/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2994 - accuracy: 0.5359 - val_loss: 1.2617 - val_accuracy: 0.5510
Epoch 12/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2733 - accuracy: 0.5451 - val_loss: 1.2264 - val_accuracy: 0.5624
Epoch 13/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2478 - accuracy: 0.5568 - val_loss: 1.2450 - val_accuracy: 0.5644
Epoch 14/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2223 - accuracy: 0.5649 - val_loss: 1.1971 - val_accuracy: 0.5785
Epoch 15/25
196/196 [==============================] - 1s 7ms/step - loss: 1.2038 - accuracy: 0.5730 - val_loss: 1.1459 - val_accuracy: 0.6017
Epoch 16/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1826 - accuracy: 0.5801 - val_loss: 1.1202 - val_accuracy: 0.6047
Epoch 17/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1602 - accuracy: 0.5886 - val_loss: 1.1146 - val_accuracy: 0.6131
Epoch 18/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1402 - accuracy: 0.5976 - val_loss: 1.0740 - val_accuracy: 0.6285
Epoch 19/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1202 - accuracy: 0.6036 - val_loss: 1.0712 - val_accuracy: 0.6196
Epoch 20/25
196/196 [==============================] - 1s 7ms/step - loss: 1.1067 - accuracy: 0.6093 - val_loss: 1.0343 - val_accuracy: 0.6418
Epoch 21/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0867 - accuracy: 0.6178 - val_loss: 1.0376 - val_accuracy: 0.6425
Epoch 22/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0667 - accuracy: 0.6257 - val_loss: 0.9948 - val_accuracy: 0.6555
Epoch 23/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0488 - accuracy: 0.6300 - val_loss: 1.0192 - val_accuracy: 0.6433
Epoch 24/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0300 - accuracy: 0.6375 - val_loss: 0.9870 - val_accuracy: 0.6576
Epoch 25/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0175 - accuracy: 0.6421 - val_loss: 0.9753 - val_accuracy: 0.6580
CPU times: user 44.2 s, sys: 6.08 s, total: 50.3 s
Wall time: 38.9 s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x.
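Note that `tf.config.optimizer.set_jit(True)` enables XLA auto-clustering globally for the session. In TF 2.4+, XLA compilation can also be requested for an individual function with `tf.function(jit_compile=True)` (older releases used `experimental_compile=True`). A minimal sketch, where `dense_relu` is a hypothetical helper, not part of this tutorial:

```python
import tensorflow as tf

# Request XLA compilation for just this function, rather than
# enabling auto-clustering for the whole program.
@tf.function(jit_compile=True)
def dense_relu(x, w, b):
  # A small computation (matmul + bias + ReLU) that XLA can fuse
  # into a single compiled kernel.
  return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.ones((2, 3))
w = tf.ones((3, 4))
b = tf.zeros((4,))
print(dense_relu(x, w, b).shape)  # (2, 4)
```

The per-function form is useful when only a hot inner computation benefits from XLA, since compilation itself has a cost (which is why the tutorial runs a warmup epoch before timing).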