Entrenamiento de una red neuronal en MNIST con Keras

Este sencillo ejemplo demuestra cómo conectar conjuntos de datos de TensorFlow (TFDS) en un modelo de Keras.

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar libreta
import tensorflow as tf
import tensorflow_datasets as tfds

Paso 1: Cree su tubería de entrada

Comience por construir una tubería de entrada eficiente utilizando los consejos de:

Cargar un conjunto de datos

Cargue el conjunto de datos MNIST con los siguientes argumentos:

  • shuffle_files=True : los datos de MNIST solo se almacenan en un solo archivo, pero para conjuntos de datos más grandes con varios archivos en el disco, es una buena práctica mezclarlos durante el entrenamiento.
  • as_supervised=True : Devuelve una tupla (img, label) en lugar de un diccionario {'image': img, 'label': label} .
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2022-02-07 04:05:46.671689: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Cree una canalización de capacitación

Aplicar las siguientes transformaciones:

  • tf.data.Dataset.map : TFDS proporciona imágenes de tipo tf.uint8 , mientras que el modelo espera tf.float32 . Por lo tanto, necesita normalizar las imágenes.
  • tf.data.Dataset.cache A medida que ajusta el conjunto de datos en la memoria, colóquelo en caché antes de barajar para un mejor rendimiento.
    Nota: Las transformaciones aleatorias deben aplicarse después del almacenamiento en caché.
  • tf.data.Dataset.shuffle : para una verdadera aleatoriedad, establezca el búfer de reproducción aleatoria en el tamaño completo del conjunto de datos.
    Nota: Para grandes conjuntos de datos que no caben en la memoria, use buffer_size=1000 si su sistema lo permite.
  • tf.data.Dataset.batch : Elementos de lote del conjunto de datos después de mezclar para obtener lotes únicos en cada época.
  • tf.data.Dataset.prefetch : es una buena práctica finalizar la canalización mediante la captación previa del rendimiento .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

Cree una canalización de evaluación

Su canalización de prueba es similar a la canalización de entrenamiento con pequeñas diferencias:

  • No necesita llamar a tf.data.Dataset.shuffle .
  • El almacenamiento en caché se realiza después del procesamiento por lotes porque los lotes pueden ser los mismos entre épocas.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

Paso 2: crear y entrenar el modelo

Conecte la canalización de entrada de TFDS en un modelo Keras simple, compile el modelo y entrénelo.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 5s 4ms/step - loss: 0.3503 - sparse_categorical_accuracy: 0.9053 - val_loss: 0.1979 - val_sparse_categorical_accuracy: 0.9415
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1668 - sparse_categorical_accuracy: 0.9524 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9595
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1216 - sparse_categorical_accuracy: 0.9657 - val_loss: 0.1120 - val_sparse_categorical_accuracy: 0.9653
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0939 - sparse_categorical_accuracy: 0.9726 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9704
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9781 - val_loss: 0.0928 - val_sparse_categorical_accuracy: 0.9717
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0625 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9728
<keras.callbacks.History at 0x7f77b42cd910>