
Classifying CIFAR-10 with XLA


This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset, and compiles it using XLA. XLA (Accelerated Linear Algebra) is a compiler for linear algebra that can accelerate TensorFlow models, often with no source-code changes.

Load and normalize the dataset using the Keras API:

import tensorflow as tf

# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.test.gpu_device_name()

tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 12s 0us/step

We define the model, adapted from the Keras CIFAR-10 example:

def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()

We train the model, using the RMSprop optimizer:

def compile_model(model):
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT; we do not want to measure the compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
196/196 [==============================] - 12s 11ms/step - loss: 2.0710 - accuracy: 0.2378 - val_loss: 1.8439 - val_accuracy: 0.3554
Epoch 1/25
196/196 [==============================] - 2s 10ms/step - loss: 2.1360 - accuracy: 0.2055 - val_loss: 1.9299 - val_accuracy: 0.3314
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.8248 - accuracy: 0.3405 - val_loss: 1.6969 - val_accuracy: 0.3973
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6864 - accuracy: 0.3944 - val_loss: 1.5874 - val_accuracy: 0.4326
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5985 - accuracy: 0.4238 - val_loss: 1.5332 - val_accuracy: 0.4401
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5332 - accuracy: 0.4453 - val_loss: 1.5122 - val_accuracy: 0.4598
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4860 - accuracy: 0.4627 - val_loss: 1.4261 - val_accuracy: 0.4880
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4434 - accuracy: 0.4790 - val_loss: 1.3658 - val_accuracy: 0.5058
Epoch 8/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4049 - accuracy: 0.4976 - val_loss: 1.3883 - val_accuracy: 0.5022
Epoch 9/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3727 - accuracy: 0.5110 - val_loss: 1.3145 - val_accuracy: 0.5329
Epoch 10/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3462 - accuracy: 0.5208 - val_loss: 1.2622 - val_accuracy: 0.5503
Epoch 11/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3132 - accuracy: 0.5323 - val_loss: 1.2740 - val_accuracy: 0.5528
Epoch 12/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2875 - accuracy: 0.5431 - val_loss: 1.2296 - val_accuracy: 0.5677
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2601 - accuracy: 0.5525 - val_loss: 1.3068 - val_accuracy: 0.5353
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2376 - accuracy: 0.5606 - val_loss: 1.1662 - val_accuracy: 0.5904
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2112 - accuracy: 0.5708 - val_loss: 1.1504 - val_accuracy: 0.5939
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1923 - accuracy: 0.5789 - val_loss: 1.1133 - val_accuracy: 0.6125
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1693 - accuracy: 0.5886 - val_loss: 1.1189 - val_accuracy: 0.6088
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1498 - accuracy: 0.5938 - val_loss: 1.1080 - val_accuracy: 0.6142
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1291 - accuracy: 0.6031 - val_loss: 1.0749 - val_accuracy: 0.6290
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1097 - accuracy: 0.6111 - val_loss: 1.0363 - val_accuracy: 0.6447
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0928 - accuracy: 0.6161 - val_loss: 1.0340 - val_accuracy: 0.6387
Epoch 22/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0738 - accuracy: 0.6238 - val_loss: 1.0650 - val_accuracy: 0.6244
Epoch 23/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0571 - accuracy: 0.6286 - val_loss: 0.9993 - val_accuracy: 0.6535
Epoch 24/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0470 - accuracy: 0.6325 - val_loss: 0.9925 - val_accuracy: 0.6511
Epoch 25/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0274 - accuracy: 0.6405 - val_loss: 1.0276 - val_accuracy: 0.6387
CPU times: user 49.6 s, sys: 7.41 s, total: 57.1 s
Wall time: 40.9 s
313/313 [==============================] - 1s 2ms/step - loss: 1.0276 - accuracy: 0.6387
Test loss: 1.0276497602462769
Test accuracy: 0.638700008392334
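The %time magic used above is only available in IPython and Colab. As a hypothetical alternative for plain Python scripts, the same measurement can be taken with time.perf_counter() (a minimal sketch, not part of the original tutorial):

import time

# Measure training wall time without the IPython %time magic.
start = time.perf_counter()
train_model(model, x_train, y_train, x_test, y_test)
print('Training wall time: %.1f s' % (time.perf_counter() - start))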

Now let's train the model again, using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 6s 14ms/step - loss: 2.0028 - accuracy: 0.2637 - val_loss: 1.7774 - val_accuracy: 0.3784
Epoch 1/25
196/196 [==============================] - 4s 19ms/step - loss: 2.1142 - accuracy: 0.2136 - val_loss: 1.8591 - val_accuracy: 0.3303
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.7614 - accuracy: 0.3622 - val_loss: 1.6321 - val_accuracy: 0.4149
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6471 - accuracy: 0.3991 - val_loss: 1.5480 - val_accuracy: 0.4379
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5706 - accuracy: 0.4303 - val_loss: 1.4671 - val_accuracy: 0.4691
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5098 - accuracy: 0.4514 - val_loss: 1.4270 - val_accuracy: 0.4858
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4561 - accuracy: 0.4734 - val_loss: 1.4270 - val_accuracy: 0.4919
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4199 - accuracy: 0.4906 - val_loss: 1.3290 - val_accuracy: 0.5272
Epoch 8/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3831 - accuracy: 0.5049 - val_loss: 1.3550 - val_accuracy: 0.5205
Epoch 9/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3487 - accuracy: 0.5176 - val_loss: 1.3339 - val_accuracy: 0.5240
Epoch 10/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3190 - accuracy: 0.5275 - val_loss: 1.2579 - val_accuracy: 0.5528
Epoch 11/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2930 - accuracy: 0.5415 - val_loss: 1.2364 - val_accuracy: 0.5654
Epoch 12/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2656 - accuracy: 0.5499 - val_loss: 1.2646 - val_accuracy: 0.5561
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2416 - accuracy: 0.5614 - val_loss: 1.2615 - val_accuracy: 0.5540
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2196 - accuracy: 0.5681 - val_loss: 1.1970 - val_accuracy: 0.5797
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1996 - accuracy: 0.5751 - val_loss: 1.1142 - val_accuracy: 0.6099
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1777 - accuracy: 0.5849 - val_loss: 1.2225 - val_accuracy: 0.5653
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1542 - accuracy: 0.5912 - val_loss: 1.0860 - val_accuracy: 0.6187
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1358 - accuracy: 0.6009 - val_loss: 1.0767 - val_accuracy: 0.6180
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1197 - accuracy: 0.6062 - val_loss: 1.0517 - val_accuracy: 0.6318
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0980 - accuracy: 0.6139 - val_loss: 1.0362 - val_accuracy: 0.6390
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0738 - accuracy: 0.6233 - val_loss: 1.0777 - val_accuracy: 0.6256
Epoch 22/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0555 - accuracy: 0.6285 - val_loss: 1.0615 - val_accuracy: 0.6353
Epoch 23/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0434 - accuracy: 0.6331 - val_loss: 1.0025 - val_accuracy: 0.6498
Epoch 24/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0285 - accuracy: 0.6386 - val_loss: 0.9670 - val_accuracy: 0.6614
Epoch 25/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0116 - accuracy: 0.6443 - val_loss: 0.9806 - val_accuracy: 0.6576
CPU times: user 43.8 s, sys: 6.15 s, total: 49.9 s
Wall time: 42.6 s

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is ~1.17x.
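Note that the global JIT enabled via tf.config.optimizer.set_jit is not the only way to use XLA. Since TensorFlow 2.5, individual functions can also be compiled with tf.function(jit_compile=True); the block below is a minimal standalone sketch of this approach, not part of the benchmark above:

import tensorflow as tf

# Compile a single function with XLA instead of enabling the JIT globally.
@tf.function(jit_compile=True)
def dense_relu(x, w, b):
  return tf.nn.relu(tf.matmul(x, w) + b)

x = tf.random.normal((8, 16))
w = tf.random.normal((16, 4))
b = tf.zeros((4,))
print(dense_relu(x, w, b).shape)  # (8, 4)

Auto-clustering can also be switched on without any source changes by setting the environment variable TF_XLA_FLAGS=--tf_xla_auto_jit=2 before launching the program.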