![]() | ![]() | ![]() |
บทแนะนำนี้ฝึกโมเดล TensorFlow เพื่อจำแนกชุดข้อมูล CIFAR-10 และเราคอมไพล์โมเดลโดยใช้ XLA
โหลดและนอร์มัลไลซ์ (normalize) ชุดข้อมูลโดยใช้ Keras API:
import tensorflow as tf
# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
# Raise explicitly instead of using `assert`: an assert carries no message here
# and is removed entirely when Python runs with -O, silently skipping the check.
if not tf.test.gpu_device_name():
    raise RuntimeError('GPU device not found')
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False)  # Start with XLA disabled.
def load_data():
    """Load CIFAR-10 through Keras and prepare it for training.

    Images are cast to float32 and divided by 256; labels are converted to
    one-hot rows with 10 classes.

    Returns:
      A pair of pairs: ((x_train, y_train), (x_test, y_test)).
    """
    train_split, test_split = tf.keras.datasets.cifar10.load_data()

    def _prepare(images, labels):
        # Scale the pixel values and turn the class vector into a binary
        # class matrix.
        scaled = images.astype('float32') / 256
        one_hot = tf.keras.utils.to_categorical(labels, num_classes=10)
        return (scaled, one_hot)

    return (_prepare(*train_split), _prepare(*test_split))
# Load the prepared CIFAR-10 splits once at module level; the model definition
# and the training calls below read these names.
(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170500096/170498071 [==============================] - 12s 0us/step 170508288/170498071 [==============================] - 12s 0us/step
เรากำหนดโมเดลโดยดัดแปลงมาจากตัวอย่าง CIFAR-10 ของ Keras:
def generate_model(input_shape=None):
    """Build the uncompiled convolutional classifier used in this tutorial.

    The network consists of two Conv-ReLU-Conv-ReLU-MaxPool-Dropout stages
    (32 filters, then 64 filters, all with 3x3 kernels), followed by Flatten,
    Dense(512)-ReLU-Dropout(0.5) and a 10-unit softmax output.

    Args:
      input_shape: Shape of a single input sample, without the batch
        dimension. When omitted, `x_train.shape[1:]` from the module-level
        training data is used, which is what this function always did before
        the parameter existed.

    Returns:
      A `tf.keras.models.Sequential` model; the caller is responsible for
      compiling it.
    """
    if input_shape is None:
        # Fall back to the module-level training data so existing
        # zero-argument callers keep working unchanged.
        input_shape = x_train.shape[1:]
    return tf.keras.models.Sequential([
        # Stage 1: 32 filters.
        tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=input_shape),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Conv2D(32, (3, 3)),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),
        # Stage 2: 64 filters.
        tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Conv2D(64, (3, 3)),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),
        # Classifier head.
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(10),
        tf.keras.layers.Activation('softmax'),
    ])
# Build the model for the first (XLA-disabled) run.
model = generate_model()
เราฝึกโมเดลโดยใช้ออปติไมเซอร์ RMSprop:
def compile_model(model):
    """Configure `model` for training with RMSprop and categorical cross-entropy.

    Args:
      model: A Keras model exposing `compile`.

    Returns:
      The same `model` object that was passed in, after `compile` was called
      on it.
    """
    # Use `learning_rate`: the `lr` alias is deprecated and emits a
    # UserWarning on every optimizer construction.
    # NOTE(review): `decay` is kept as in the original; newer Keras optimizer
    # implementations may reject it -- confirm against the TF version in use.
    opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
    model.compile(
        loss='categorical_crossentropy',
        optimizer=opt,
        metrics=['accuracy'],
    )
    return model
# Attach optimizer, loss and the accuracy metric; compile_model returns the
# same model object it was given.
model = compile_model(model)
def train_model(model, x_train, y_train, x_test, y_test, epochs=25, batch_size=256):
    """Fit `model` on the training split, validating on the test split.

    Args:
      model: Object exposing a Keras-style `fit` method.
      x_train: Training inputs, passed to `fit` positionally.
      y_train: Training targets, passed to `fit` positionally.
      x_test: Validation inputs; passed as part of `validation_data`.
      y_test: Validation targets; passed as part of `validation_data`.
      epochs: Number of epochs forwarded to `fit`.
      batch_size: Batch size forwarded to `fit`. Defaults to 256, the value
        that was previously hard-coded.

    Returns:
      None. The value returned by `fit` is intentionally not propagated, so
      existing callers see no change.
    """
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=(x_test, y_test),
        shuffle=True,
    )
def warmup(model, x_train, y_train, x_test, y_test):
    """Run one throwaway training epoch, then put the model's weights back.

    The purpose is to trigger JIT compilation ahead of the timed run so that
    compilation time is not part of the measurement.
    """
    # Snapshot the weights so the single warm-up epoch leaves them unchanged.
    # NOTE(review): only the weights returned by get_weights() are restored;
    # whether optimizer state is also reset is not visible here -- confirm.
    snapshot = model.get_weights()
    train_model(
        model,
        x_train,
        y_train,
        x_test,
        y_test,
        epochs=1,
    )
    model.set_weights(snapshot)
# Untimed warm-up epoch so compilation cost is excluded from the measurement.
warmup(model, x_train, y_train, x_test, y_test)
# `%time` is an IPython magic; this line only works in a notebook/IPython session.
%time train_model(model, x_train, y_train, x_test, y_test)
# Evaluate on the test split: scores[0] is the loss, scores[1] the accuracy metric.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead. "The `lr` argument is deprecated, use `learning_rate` instead.") 196/196 [==============================] - 12s 11ms/step - loss: 2.0710 - accuracy: 0.2378 - val_loss: 1.8439 - val_accuracy: 0.3554 Epoch 1/25 196/196 [==============================] - 2s 10ms/step - loss: 2.1360 - accuracy: 0.2055 - val_loss: 1.9299 - val_accuracy: 0.3314 Epoch 2/25 196/196 [==============================] - 2s 8ms/step - loss: 1.8248 - accuracy: 0.3405 - val_loss: 1.6969 - val_accuracy: 0.3973 Epoch 3/25 196/196 [==============================] - 2s 8ms/step - loss: 1.6864 - accuracy: 0.3944 - val_loss: 1.5874 - val_accuracy: 0.4326 Epoch 4/25 196/196 [==============================] - 2s 8ms/step - loss: 1.5985 - accuracy: 0.4238 - val_loss: 1.5332 - val_accuracy: 0.4401 Epoch 5/25 196/196 [==============================] - 2s 8ms/step - loss: 1.5332 - accuracy: 0.4453 - val_loss: 1.5122 - val_accuracy: 0.4598 Epoch 6/25 196/196 [==============================] - 2s 8ms/step - loss: 1.4860 - accuracy: 0.4627 - val_loss: 1.4261 - val_accuracy: 0.4880 Epoch 7/25 196/196 [==============================] - 2s 8ms/step - loss: 1.4434 - accuracy: 0.4790 - val_loss: 1.3658 - val_accuracy: 0.5058 Epoch 8/25 196/196 [==============================] - 2s 8ms/step - loss: 1.4049 - accuracy: 0.4976 - val_loss: 1.3883 - val_accuracy: 0.5022 Epoch 9/25 196/196 [==============================] - 2s 8ms/step - loss: 1.3727 - accuracy: 0.5110 - val_loss: 1.3145 - val_accuracy: 0.5329 Epoch 10/25 196/196 [==============================] - 2s 8ms/step - loss: 1.3462 - accuracy: 0.5208 - val_loss: 1.2622 - val_accuracy: 0.5503 Epoch 11/25 196/196 [==============================] - 2s 8ms/step - loss: 1.3132 - accuracy: 0.5323 - val_loss: 1.2740 - val_accuracy: 0.5528 Epoch 12/25 196/196 [==============================] - 2s 8ms/step - 
loss: 1.2875 - accuracy: 0.5431 - val_loss: 1.2296 - val_accuracy: 0.5677 Epoch 13/25 196/196 [==============================] - 2s 8ms/step - loss: 1.2601 - accuracy: 0.5525 - val_loss: 1.3068 - val_accuracy: 0.5353 Epoch 14/25 196/196 [==============================] - 2s 8ms/step - loss: 1.2376 - accuracy: 0.5606 - val_loss: 1.1662 - val_accuracy: 0.5904 Epoch 15/25 196/196 [==============================] - 2s 8ms/step - loss: 1.2112 - accuracy: 0.5708 - val_loss: 1.1504 - val_accuracy: 0.5939 Epoch 16/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1923 - accuracy: 0.5789 - val_loss: 1.1133 - val_accuracy: 0.6125 Epoch 17/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1693 - accuracy: 0.5886 - val_loss: 1.1189 - val_accuracy: 0.6088 Epoch 18/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1498 - accuracy: 0.5938 - val_loss: 1.1080 - val_accuracy: 0.6142 Epoch 19/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1291 - accuracy: 0.6031 - val_loss: 1.0749 - val_accuracy: 0.6290 Epoch 20/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1097 - accuracy: 0.6111 - val_loss: 1.0363 - val_accuracy: 0.6447 Epoch 21/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0928 - accuracy: 0.6161 - val_loss: 1.0340 - val_accuracy: 0.6387 Epoch 22/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0738 - accuracy: 0.6238 - val_loss: 1.0650 - val_accuracy: 0.6244 Epoch 23/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0571 - accuracy: 0.6286 - val_loss: 0.9993 - val_accuracy: 0.6535 Epoch 24/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0470 - accuracy: 0.6325 - val_loss: 0.9925 - val_accuracy: 0.6511 Epoch 25/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0274 - accuracy: 0.6405 - val_loss: 1.0276 - val_accuracy: 0.6387 CPU times: user 49.6 s, sys: 7.41 s, total: 57.1 s Wall time: 40.9 
s 313/313 [==============================] - 1s 2ms/step - loss: 1.0276 - accuracy: 0.6387 Test loss: 1.0276497602462769 Test accuracy: 0.638700008392334
ตอนนี้ เรามาฝึกโมเดลกันอีกครั้งโดยใช้คอมไพเลอร์ XLA หากต้องการเปิดใช้คอมไพเลอร์ระหว่างที่โปรแกรมกำลังทำงานอยู่ เราต้องรีเซ็ตเซสชัน Keras ก่อน
# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
# Build and compile a new model after the session reset instead of reusing the first one.
model = compile_model(generate_model())
# Same load_data() call as at the start of the file.
(x_train, y_train), (x_test, y_test) = load_data()
# Untimed warm-up epoch so compilation cost is excluded from the measurement.
warmup(model, x_train, y_train, x_test, y_test)
# `%time` is an IPython magic; this line only works in a notebook/IPython session.
%time train_model(model, x_train, y_train, x_test, y_test)
196/196 [==============================] - 6s 14ms/step - loss: 2.0028 - accuracy: 0.2637 - val_loss: 1.7774 - val_accuracy: 0.3784 Epoch 1/25 196/196 [==============================] - 4s 19ms/step - loss: 2.1142 - accuracy: 0.2136 - val_loss: 1.8591 - val_accuracy: 0.3303 Epoch 2/25 196/196 [==============================] - 2s 8ms/step - loss: 1.7614 - accuracy: 0.3622 - val_loss: 1.6321 - val_accuracy: 0.4149 Epoch 3/25 196/196 [==============================] - 2s 8ms/step - loss: 1.6471 - accuracy: 0.3991 - val_loss: 1.5480 - val_accuracy: 0.4379 Epoch 4/25 196/196 [==============================] - 2s 8ms/step - loss: 1.5706 - accuracy: 0.4303 - val_loss: 1.4671 - val_accuracy: 0.4691 Epoch 5/25 196/196 [==============================] - 2s 8ms/step - loss: 1.5098 - accuracy: 0.4514 - val_loss: 1.4270 - val_accuracy: 0.4858 Epoch 6/25 196/196 [==============================] - 2s 8ms/step - loss: 1.4561 - accuracy: 0.4734 - val_loss: 1.4270 - val_accuracy: 0.4919 Epoch 7/25 196/196 [==============================] - 2s 8ms/step - loss: 1.4199 - accuracy: 0.4906 - val_loss: 1.3290 - val_accuracy: 0.5272 Epoch 8/25 196/196 [==============================] - 2s 8ms/step - loss: 1.3831 - accuracy: 0.5049 - val_loss: 1.3550 - val_accuracy: 0.5205 Epoch 9/25 196/196 [==============================] - 2s 8ms/step - loss: 1.3487 - accuracy: 0.5176 - val_loss: 1.3339 - val_accuracy: 0.5240 Epoch 10/25 196/196 [==============================] - 2s 8ms/step - loss: 1.3190 - accuracy: 0.5275 - val_loss: 1.2579 - val_accuracy: 0.5528 Epoch 11/25 196/196 [==============================] - 2s 8ms/step - loss: 1.2930 - accuracy: 0.5415 - val_loss: 1.2364 - val_accuracy: 0.5654 Epoch 12/25 196/196 [==============================] - 2s 8ms/step - loss: 1.2656 - accuracy: 0.5499 - val_loss: 1.2646 - val_accuracy: 0.5561 Epoch 13/25 196/196 [==============================] - 2s 8ms/step - loss: 1.2416 - accuracy: 0.5614 - val_loss: 1.2615 - val_accuracy: 0.5540 Epoch 14/25 
196/196 [==============================] - 2s 8ms/step - loss: 1.2196 - accuracy: 0.5681 - val_loss: 1.1970 - val_accuracy: 0.5797 Epoch 15/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1996 - accuracy: 0.5751 - val_loss: 1.1142 - val_accuracy: 0.6099 Epoch 16/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1777 - accuracy: 0.5849 - val_loss: 1.2225 - val_accuracy: 0.5653 Epoch 17/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1542 - accuracy: 0.5912 - val_loss: 1.0860 - val_accuracy: 0.6187 Epoch 18/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1358 - accuracy: 0.6009 - val_loss: 1.0767 - val_accuracy: 0.6180 Epoch 19/25 196/196 [==============================] - 2s 8ms/step - loss: 1.1197 - accuracy: 0.6062 - val_loss: 1.0517 - val_accuracy: 0.6318 Epoch 20/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0980 - accuracy: 0.6139 - val_loss: 1.0362 - val_accuracy: 0.6390 Epoch 21/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0738 - accuracy: 0.6233 - val_loss: 1.0777 - val_accuracy: 0.6256 Epoch 22/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0555 - accuracy: 0.6285 - val_loss: 1.0615 - val_accuracy: 0.6353 Epoch 23/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0434 - accuracy: 0.6331 - val_loss: 1.0025 - val_accuracy: 0.6498 Epoch 24/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0285 - accuracy: 0.6386 - val_loss: 0.9670 - val_accuracy: 0.6614 Epoch 25/25 196/196 [==============================] - 2s 8ms/step - loss: 1.0116 - accuracy: 0.6443 - val_loss: 0.9806 - val_accuracy: 0.6576 CPU times: user 43.8 s, sys: 6.15 s, total: 49.9 s Wall time: 42.6 s
บนเครื่องที่มี GPU Titan V และ CPU Intel Xeon E5-2690 ความเร็วเพิ่มขึ้นประมาณ 1.17 เท่า