Melatih jaringan saraf di MNIST dengan Keras

Contoh sederhana ini menunjukkan cara memasang TensorFlow Datasets (TFDS) ke dalam model Keras.

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan
import tensorflow as tf
import tensorflow_datasets as tfds

Langkah 1: Buat saluran input Anda

Mulailah dengan membangun saluran input yang efisien menggunakan saran dari:

Muat kumpulan data

Muat set data MNIST dengan argumen berikut:

  • shuffle_files=True : Data MNIST hanya disimpan dalam satu file, tetapi untuk kumpulan data yang lebih besar dengan banyak file di disk, praktik yang baik adalah mengocoknya saat pelatihan.
  • as_supervised=True : Mengembalikan tuple (img, label) alih-alih kamus {'image': img, 'label': label} .
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2022-02-07 04:05:46.671689: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Bangun saluran pelatihan

Terapkan transformasi berikut:

  • tf.data.Dataset.map : TFDS menyediakan gambar bertipe tf.uint8 , sedangkan model mengharapkan tf.float32 . Karena itu, Anda perlu menormalkan gambar.
  • tf.data.Dataset.cache Saat Anda memasukkan dataset ke dalam memori, simpan di cache sebelum mengacaknya untuk kinerja yang lebih baik.
    Catatan: Transformasi acak harus diterapkan setelah caching.
  • tf.data.Dataset.shuffle : Untuk keacakan yang sebenarnya, atur buffer shuffle ke ukuran set data penuh.
    Catatan: Untuk kumpulan data besar yang tidak dapat dimasukkan ke dalam memori, gunakan buffer_size=1000 jika sistem Anda mengizinkannya.
  • tf.data.Dataset.batch : Elemen batch dari dataset setelah mengacak untuk mendapatkan batch unik di setiap epoch.
  • tf.data.Dataset.prefetch : Ini adalah praktik yang baik untuk mengakhiri pipeline dengan melakukan prefetching untuk performa .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

Bangun saluran evaluasi

Pipeline pengujian Anda mirip dengan pipeline pelatihan dengan sedikit perbedaan:

  • Anda tidak perlu memanggil tf.data.Dataset.shuffle .
  • Caching dilakukan setelah batching karena batch bisa sama antar epoch.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

Langkah 2: Buat dan latih modelnya

Pasang pipa input TFDS ke model Keras sederhana, kompilasi model, dan latih.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 5s 4ms/step - loss: 0.3503 - sparse_categorical_accuracy: 0.9053 - val_loss: 0.1979 - val_sparse_categorical_accuracy: 0.9415
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1668 - sparse_categorical_accuracy: 0.9524 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9595
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1216 - sparse_categorical_accuracy: 0.9657 - val_loss: 0.1120 - val_sparse_categorical_accuracy: 0.9653
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0939 - sparse_categorical_accuracy: 0.9726 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9704
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9781 - val_loss: 0.0928 - val_sparse_categorical_accuracy: 0.9717
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0625 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9728
<keras.callbacks.History at 0x7f77b42cd910>