¡Reserva! Google I / O regresa del 18 al 20 de mayo Regístrese ahora
Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Entrenando una red neuronal en MNIST con Keras

Este sencillo ejemplo demuestra cómo conectar TFDS en un modelo Keras.

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno
import tensorflow as tf
import tensorflow_datasets as tfds

Paso 1: crea tu canalización de entrada

Cree una canalización de entrada eficiente con los consejos de:

Cargar MNIST

Cargue con los siguientes argumentos:

  • shuffle_files : Los datos MNIST solo se almacenan en un solo archivo, pero para conjuntos de datos más grandes con varios archivos en el disco, es una buena práctica mezclarlos durante el entrenamiento.
  • as_supervised : Devuelve tupla (img, label) lugar de dict {'image': img, 'label': label}
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

Desarrollar una canalización de capacitación

Aplicar las siguientes transformaciones:

  • ds.map : TFDS proporciona las imágenes como tf.uint8, mientras que el modelo espera tf.float32, así que normalice las imágenes
  • ds.cache A medida que el conjunto de datos encaja en la memoria, caché antes de barajar para un mejor rendimiento.
    Nota: las transformaciones aleatorias deben aplicarse después del almacenamiento en caché
  • ds.shuffle : para una verdadera aleatoriedad, configure el búfer aleatorio al tamaño completo del conjunto de datos.
    Nota: Para conjuntos de datos más grandes que no caben en la memoria, un valor estándar es 1000 si su sistema lo permite.
  • ds.batch : Lote tras barajado para obtener lotes únicos en cada época.
  • ds.prefetch : buena práctica para finalizar la canalización mediante la ds.prefetch de actuaciones .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

Construir una canalización de evaluación

La canalización de pruebas es similar a la canalización de formación, con pequeñas diferencias:

  • Sin llamada a ds.shuffle()
  • El almacenamiento en caché se realiza después del procesamiento por lotes (ya que los lotes pueden ser los mismos entre épocas)
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

Paso 2: crea y entrena el modelo

Conecte la tubería de entrada a Keras.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128,activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 4s 4ms/step - loss: 0.6240 - sparse_categorical_accuracy: 0.8288 - val_loss: 0.2043 - val_sparse_categorical_accuracy: 0.9424
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1796 - sparse_categorical_accuracy: 0.9499 - val_loss: 0.1395 - val_sparse_categorical_accuracy: 0.9598
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1215 - sparse_categorical_accuracy: 0.9642 - val_loss: 0.1137 - val_sparse_categorical_accuracy: 0.9678
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0968 - sparse_categorical_accuracy: 0.9724 - val_loss: 0.0974 - val_sparse_categorical_accuracy: 0.9707
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0774 - sparse_categorical_accuracy: 0.9775 - val_loss: 0.0852 - val_sparse_categorical_accuracy: 0.9766
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0631 - sparse_categorical_accuracy: 0.9811 - val_loss: 0.0868 - val_sparse_categorical_accuracy: 0.9735
<tensorflow.python.keras.callbacks.History at 0x7f70782baa58>