Catat tanggalnya! Google I / O mengembalikan 18-20 Mei Daftar sekarang
Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

Melatih jaringan saraf di MNIST dengan Keras

Contoh sederhana ini menunjukkan cara menyambungkan TFDS ke model Keras.

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan
import tensorflow as tf
import tensorflow_datasets as tfds

Langkah 1: Buat pipeline masukan Anda

Bangun pipeline masukan yang efisien menggunakan saran dari:

Muat MNIST

Muat dengan argumen berikut:

  • shuffle_files : Data MNIST hanya disimpan dalam satu file, tetapi untuk kumpulan data yang lebih besar dengan beberapa file di disk, praktik yang baik adalah mengacaknya saat pelatihan.
  • as_supervised : Mengembalikan tuple (img, label) bukan dict {'image': img, 'label': label}
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

Bangun pipeline pelatihan

Terapkan transormasi berikut:

  • ds.map : TFDS menyediakan gambar sebagai tf.uint8, sedangkan model mengharapkan tf.float32, jadi normalkan gambar
  • ds.cache Karena ds.cache data sesuai dengan memori, cache sebelum pengacakan untuk kinerja yang lebih baik.
    Catatan: Transformasi acak harus diterapkan setelah penyimpanan ke cache
  • ds.shuffle : Untuk keacakan yang sebenarnya, setel buffer acak ke ukuran kumpulan data penuh.
    Catatan: Untuk dataset yang lebih besar yang tidak muat dalam memori, nilai standarnya adalah 1000 jika sistem Anda mengizinkannya.
  • ds.batch : ds.batch setelah pengacakan untuk mendapatkan kelompok unik di setiap periode.
  • ds.prefetch : Praktik yang baik untuk mengakhiri pipeline dengan melakukan prefetching untuk pertunjukan .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

Bangun jalur evaluasi

Pipeline pengujian mirip dengan pipeline pelatihan, dengan perbedaan kecil:

  • Tidak ada panggilan ds.shuffle()
  • Caching dilakukan setelah batching (karena batch bisa sama antar epoch)
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

Langkah 2: Buat dan latih modelnya

Colokkan pipa input ke Keras.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128,activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 4s 4ms/step - loss: 0.6240 - sparse_categorical_accuracy: 0.8288 - val_loss: 0.2043 - val_sparse_categorical_accuracy: 0.9424
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1796 - sparse_categorical_accuracy: 0.9499 - val_loss: 0.1395 - val_sparse_categorical_accuracy: 0.9598
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1215 - sparse_categorical_accuracy: 0.9642 - val_loss: 0.1137 - val_sparse_categorical_accuracy: 0.9678
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0968 - sparse_categorical_accuracy: 0.9724 - val_loss: 0.0974 - val_sparse_categorical_accuracy: 0.9707
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0774 - sparse_categorical_accuracy: 0.9775 - val_loss: 0.0852 - val_sparse_categorical_accuracy: 0.9766
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0631 - sparse_categorical_accuracy: 0.9811 - val_loss: 0.0868 - val_sparse_categorical_accuracy: 0.9735
<tensorflow.python.keras.callbacks.History at 0x7f70782baa58>