Classifying CIFAR-10 with XLA

This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset and compiles it using XLA.

Load and normalize the dataset using the Keras API:

```python
import tensorflow as tf

# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert tf.test.gpu_device_name(), 'No GPU found'

tf.keras.backend.clear_session()

def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  # Scale the uint8 pixel values to floats in [0, 1).
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))

(x_train, y_train), (x_test, y_test) = load_data()
```
```
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 12s 0us/step
170508288/170498071 [==============================] - 12s 0us/step
```
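
To make the preprocessing concrete, here is a minimal NumPy sketch of the two transformations applied when the data is loaded: pixel scaling and one-hot label encoding. The helper names `normalize_pixels` and `one_hot` and the tiny stand-in arrays are assumptions made for illustration; the tutorial itself relies on `tf.keras.utils.to_categorical` for the encoding.

```python
import numpy as np

def normalize_pixels(images):
    """Scale uint8 pixel values to float32 in [0, 1), dividing by 256 as the tutorial does."""
    return images.astype("float32") / 256

def one_hot(labels, num_classes=10):
    """Turn integer labels of shape (N, 1) or (N,) into an (N, num_classes) float32 matrix."""
    labels = np.asarray(labels).reshape(-1)
    return np.eye(num_classes, dtype="float32")[labels]

# Tiny stand-in for CIFAR-10: two 2x2 RGB "images" and their integer labels.
fake_images = np.array([[[[0, 128, 255]] * 2] * 2] * 2, dtype="uint8")
fake_labels = np.array([[3], [9]])

scaled = normalize_pixels(fake_images)
encoded = one_hot(fake_labels)
print(scaled.shape, scaled.dtype, float(scaled.max()))
print(encoded)
```

Because the divisor is 256 rather than 255, the largest scaled value is 255/256, slightly below 1.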

We define the model, adapted from the Keras CIFAR-10 example:

```python
def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])

model = generate_model()
```
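
As a rough check on the layer stack as listed above, the following sketch traces the tensor shape of one 32x32 RGB image through the convolution and pooling layers. It assumes the Keras defaults of `'valid'` padding and stride 1 for `Conv2D`, and a pooling stride equal to the pool size; the helpers `conv2d_valid` and `max_pool` are hypothetical names used only for this illustration.

```python
def conv2d_valid(shape, filters, kernel=3):
    """Output shape of a stride-1 convolution with 'valid' padding."""
    height, width, _ = shape
    return (height - kernel + 1, width - kernel + 1, filters)

def max_pool(shape, pool=2):
    """Output shape of non-overlapping max pooling; partial windows are dropped."""
    height, width, channels = shape
    return (height // pool, width // pool, channels)

shape = (32, 32, 3)              # one CIFAR-10 image: 32x32 pixels, RGB
shape = conv2d_valid(shape, 32)  # Conv2D(32, (3, 3))
shape = max_pool(shape)          # MaxPooling2D(pool_size=(2, 2))
shape = conv2d_valid(shape, 64)  # Conv2D(64, (3, 3))
shape = max_pool(shape)          # MaxPooling2D(pool_size=(2, 2))
flat_features = shape[0] * shape[1] * shape[2]  # Flatten()
print(shape, flat_features)
```

Activation and dropout layers leave the shape unchanged, so they are omitted from the trace. Under these assumptions, the `Dense(512)` layer receives 2304 input features.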

We train the model using the RMSprop optimizer:

```python
def compile_model(model):
  # Note: `lr` is deprecated in newer Keras releases in favor of `learning_rate`
  # (see the warning in the output below).
  opt = tf.keras.optimizers.RMSprop(lr=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model

model = compile_model(model)

def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs,
            validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT, we do not wish to measure the compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
```
```
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py:356: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
196/196 [==============================] - 12s 11ms/step - loss: 2.0710 - accuracy: 0.2378 - val_loss: 1.8439 - val_accuracy: 0.3554
Epoch 1/25
196/196 [==============================] - 2s 10ms/step - loss: 2.1360 - accuracy: 0.2055 - val_loss: 1.9299 - val_accuracy: 0.3314
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.8248 - accuracy: 0.3405 - val_loss: 1.6969 - val_accuracy: 0.3973
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6864 - accuracy: 0.3944 - val_loss: 1.5874 - val_accuracy: 0.4326
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5985 - accuracy: 0.4238 - val_loss: 1.5332 - val_accuracy: 0.4401
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5332 - accuracy: 0.4453 - val_loss: 1.5122 - val_accuracy: 0.4598
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4860 - accuracy: 0.4627 - val_loss: 1.4261 - val_accuracy: 0.4880
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4434 - accuracy: 0.4790 - val_loss: 1.3658 - val_accuracy: 0.5058
Epoch 8/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4049 - accuracy: 0.4976 - val_loss: 1.3883 - val_accuracy: 0.5022
Epoch 9/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3727 - accuracy: 0.5110 - val_loss: 1.3145 - val_accuracy: 0.5329
Epoch 10/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3462 - accuracy: 0.5208 - val_loss: 1.2622 - val_accuracy: 0.5503
Epoch 11/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3132 - accuracy: 0.5323 - val_loss: 1.2740 - val_accuracy: 0.5528
Epoch 12/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2875 - accuracy: 0.5431 - val_loss: 1.2296 - val_accuracy: 0.5677
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2601 - accuracy: 0.5525 - val_loss: 1.3068 - val_accuracy: 0.5353
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2376 - accuracy: 0.5606 - val_loss: 1.1662 - val_accuracy: 0.5904
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2112 - accuracy: 0.5708 - val_loss: 1.1504 - val_accuracy: 0.5939
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1923 - accuracy: 0.5789 - val_loss: 1.1133 - val_accuracy: 0.6125
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1693 - accuracy: 0.5886 - val_loss: 1.1189 - val_accuracy: 0.6088
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1498 - accuracy: 0.5938 - val_loss: 1.1080 - val_accuracy: 0.6142
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1291 - accuracy: 0.6031 - val_loss: 1.0749 - val_accuracy: 0.6290
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1097 - accuracy: 0.6111 - val_loss: 1.0363 - val_accuracy: 0.6447
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0928 - accuracy: 0.6161 - val_loss: 1.0340 - val_accuracy: 0.6387
Epoch 22/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0738 - accuracy: 0.6238 - val_loss: 1.0650 - val_accuracy: 0.6244
Epoch 23/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0571 - accuracy: 0.6286 - val_loss: 0.9993 - val_accuracy: 0.6535
Epoch 24/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0470 - accuracy: 0.6325 - val_loss: 0.9925 - val_accuracy: 0.6511
Epoch 25/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0274 - accuracy: 0.6405 - val_loss: 1.0276 - val_accuracy: 0.6387
CPU times: user 49.6 s, sys: 7.41 s, total: 57.1 s
Wall time: 40.9 s
313/313 [==============================] - 1s 2ms/step - loss: 1.0276 - accuracy: 0.6387
Test loss: 1.0276497602462769
Test accuracy: 0.638700008392334
```
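
The `196/196` and `313/313` counters in the log are batch counts. A short sketch of the arithmetic, assuming the standard CIFAR-10 split of 50,000 training and 10,000 test images and the default Keras evaluation batch size of 32 (the helper name `steps_per_epoch` is only illustrative):

```python
import math

def steps_per_epoch(num_samples, batch_size):
    """Number of batches needed to cover the data once; the last batch may be smaller."""
    return math.ceil(num_samples / batch_size)

train_steps = steps_per_epoch(50_000, 256)  # batch_size=256 is passed to model.fit
eval_steps = steps_per_epoch(10_000, 32)    # model.evaluate uses batch_size=32 by default
print(train_steps, eval_steps)
```

At roughly 8 ms per step and 196 steps, each epoch takes on the order of 2 seconds, which is consistent with the per-epoch times in the log.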

Now let's train the model again, this time using the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.

```python
# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()

warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
```
```
196/196 [==============================] - 6s 14ms/step - loss: 2.0028 - accuracy: 0.2637 - val_loss: 1.7774 - val_accuracy: 0.3784
Epoch 1/25
196/196 [==============================] - 4s 19ms/step - loss: 2.1142 - accuracy: 0.2136 - val_loss: 1.8591 - val_accuracy: 0.3303
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.7614 - accuracy: 0.3622 - val_loss: 1.6321 - val_accuracy: 0.4149
Epoch 3/25
196/196 [==============================] - 2s 8ms/step - loss: 1.6471 - accuracy: 0.3991 - val_loss: 1.5480 - val_accuracy: 0.4379
Epoch 4/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5706 - accuracy: 0.4303 - val_loss: 1.4671 - val_accuracy: 0.4691
Epoch 5/25
196/196 [==============================] - 2s 8ms/step - loss: 1.5098 - accuracy: 0.4514 - val_loss: 1.4270 - val_accuracy: 0.4858
Epoch 6/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4561 - accuracy: 0.4734 - val_loss: 1.4270 - val_accuracy: 0.4919
Epoch 7/25
196/196 [==============================] - 2s 8ms/step - loss: 1.4199 - accuracy: 0.4906 - val_loss: 1.3290 - val_accuracy: 0.5272
Epoch 8/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3831 - accuracy: 0.5049 - val_loss: 1.3550 - val_accuracy: 0.5205
Epoch 9/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3487 - accuracy: 0.5176 - val_loss: 1.3339 - val_accuracy: 0.5240
Epoch 10/25
196/196 [==============================] - 2s 8ms/step - loss: 1.3190 - accuracy: 0.5275 - val_loss: 1.2579 - val_accuracy: 0.5528
Epoch 11/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2930 - accuracy: 0.5415 - val_loss: 1.2364 - val_accuracy: 0.5654
Epoch 12/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2656 - accuracy: 0.5499 - val_loss: 1.2646 - val_accuracy: 0.5561
Epoch 13/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2416 - accuracy: 0.5614 - val_loss: 1.2615 - val_accuracy: 0.5540
Epoch 14/25
196/196 [==============================] - 2s 8ms/step - loss: 1.2196 - accuracy: 0.5681 - val_loss: 1.1970 - val_accuracy: 0.5797
Epoch 15/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1996 - accuracy: 0.5751 - val_loss: 1.1142 - val_accuracy: 0.6099
Epoch 16/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1777 - accuracy: 0.5849 - val_loss: 1.2225 - val_accuracy: 0.5653
Epoch 17/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1542 - accuracy: 0.5912 - val_loss: 1.0860 - val_accuracy: 0.6187
Epoch 18/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1358 - accuracy: 0.6009 - val_loss: 1.0767 - val_accuracy: 0.6180
Epoch 19/25
196/196 [==============================] - 2s 8ms/step - loss: 1.1197 - accuracy: 0.6062 - val_loss: 1.0517 - val_accuracy: 0.6318
Epoch 20/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0980 - accuracy: 0.6139 - val_loss: 1.0362 - val_accuracy: 0.6390
Epoch 21/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0738 - accuracy: 0.6233 - val_loss: 1.0777 - val_accuracy: 0.6256
Epoch 22/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0555 - accuracy: 0.6285 - val_loss: 1.0615 - val_accuracy: 0.6353
Epoch 23/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0434 - accuracy: 0.6331 - val_loss: 1.0025 - val_accuracy: 0.6498
Epoch 24/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0285 - accuracy: 0.6386 - val_loss: 0.9670 - val_accuracy: 0.6614
Epoch 25/25
196/196 [==============================] - 2s 8ms/step - loss: 1.0116 - accuracy: 0.6443 - val_loss: 0.9806 - val_accuracy: 0.6576
CPU times: user 43.8 s, sys: 6.15 s, total: 49.9 s
Wall time: 42.6 s
```
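
The `warmup` call before each timed run keeps one-time compilation cost out of the measurement. The pattern can be sketched in plain Python with a toy stand-in for a JIT-compiled function. The `LazilyCompiled` class and `timed` helper below are hypothetical and only mimic the "first call does extra work" behavior; they are not how XLA is implemented.

```python
import time

class LazilyCompiled:
    """Toy stand-in for a JIT-compiled function: the first call pays a one-time setup cost."""

    def __init__(self, fn):
        self.fn = fn
        self.compiled = False
        self.compile_count = 0

    def __call__(self, *args):
        if not self.compiled:
            # Simulated compilation; a real JIT would build optimized kernels here.
            self.compile_count += 1
            self.compiled = True
        return self.fn(*args)

def timed(fn, *args):
    """Run fn once and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start

step = LazilyCompiled(lambda x: x * x)
step(1)                            # warm-up call: triggers the one-time setup
result, elapsed = timed(step, 12)  # timed call: setup is already done
print(result, step.compile_count)
```

The tutorial's `warmup` also restores the initial weights afterwards, so the timed run starts from the same state as an unwarmed model would.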

On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup from XLA is ~1.17x.
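
Speedup here means the baseline wall time divided by the XLA wall time. As a worked example, the sketch below applies that ratio to the two `Wall time` values from the logs above (40.9 s without XLA, 42.6 s with XLA). The function name `speedup` is an assumption for illustration, and whether those logs come from the machine described above is not stated, so the result may differ from the quoted figure.

```python
def speedup(baseline_seconds, optimized_seconds):
    """Ratio above 1 means the optimized run was faster than the baseline."""
    return baseline_seconds / optimized_seconds

wall_without_xla = 40.9  # seconds, from the first training log
wall_with_xla = 42.6     # seconds, from the XLA training log

ratio = speedup(wall_without_xla, wall_with_xla)
# XLA wall time that would correspond to a 1.17x speedup over the same baseline.
target_seconds = wall_without_xla / 1.17
print(f"{ratio:.2f}x", f"{target_seconds:.1f} s")
```

For the logged run the ratio is slightly below 1, which suggests the gain depends on the hardware and software versions used.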
