ML Community Day è il 9 novembre! Unisciti a noi per gli aggiornamenti da tensorflow, JAX, e più Per saperne di più

Addestrare una rete neurale su MNIST con Keras

Questo semplice esempio mostra come collegare TensorFlow Datasets (TFDS) in un modello Keras.

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica taccuino
import tensorflow as tf
import tensorflow_datasets as tfds

Passaggio 1: crea la tua pipeline di input

Inizia costruendo una pipeline di input efficiente utilizzando i consigli di:

Carica un set di dati

Carica il set di dati MNIST con i seguenti argomenti:

  • shuffle_files=True : I dati MNIST vengono memorizzati solo in un unico file, ma per i set di dati di grandi dimensioni con più file sul disco, è buona pratica per la riproduzione casuale quando la formazione.
  • as_supervised=True : restituisce una tupla (img, label) invece di un dizionario {'image': img, 'label': label} .
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2021-10-19 11:14:00.483247: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Costruisci una pipeline di formazione

Applicare le seguenti trasformazioni:

  • tf.data.Dataset.map : TFDS fornire immagini di tipo tf.uint8 , mentre il modello si aspetta tf.float32 . Pertanto, è necessario normalizzare le immagini.
  • tf.data.Dataset.cache Come si misura il set di dati in memoria, la cache prima di mischiare per una migliore performance.
    Nota: le trasformazioni casuali devono essere applicate dopo la memorizzazione nella cache.
  • tf.data.Dataset.shuffle : Per la vera casualità, impostare il buffer di riordino per il full size set di dati.
    Nota: Per le grandi serie di dati che non possono andare bene in memoria, uso buffer_size=1000 se il sistema lo permette.
  • tf.data.Dataset.batch : elementi lotto del set di dati Dopo aver mescolato per ottenere lotti unici ad ogni epoca.
  • tf.data.Dataset.prefetch : E 'buona pratica di porre fine alla condotta da prefetching per le prestazioni .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

Costruisci una pipeline di valutazione

La tua pipeline di test è simile alla pipeline di formazione con piccole differenze:

  • Non è necessario chiamare tf.data.Dataset.shuffle .
  • La memorizzazione nella cache viene eseguita dopo il batch perché i batch possono essere gli stessi tra le epoche.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

Passaggio 2: creare e addestrare il modello

Collega la pipeline di input TFDS a un semplice modello Keras, compila il modello e addestralo.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 4s 4ms/step - loss: 0.3589 - sparse_categorical_accuracy: 0.9012 - val_loss: 0.1962 - val_sparse_categorical_accuracy: 0.9437
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1701 - sparse_categorical_accuracy: 0.9511 - val_loss: 0.1487 - val_sparse_categorical_accuracy: 0.9548
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1232 - sparse_categorical_accuracy: 0.9648 - val_loss: 0.1158 - val_sparse_categorical_accuracy: 0.9649
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0949 - sparse_categorical_accuracy: 0.9732 - val_loss: 0.0964 - val_sparse_categorical_accuracy: 0.9706
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0770 - sparse_categorical_accuracy: 0.9775 - val_loss: 0.0856 - val_sparse_categorical_accuracy: 0.9730
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0620 - sparse_categorical_accuracy: 0.9825 - val_loss: 0.0834 - val_sparse_categorical_accuracy: 0.9745
<keras.callbacks.History at 0x7f2df74b4090>