Ayuda a proteger la Gran Barrera de Coral con TensorFlow en Kaggle Únete Challenge

Entrenando una red neuronal en MNIST con Keras

Este simple ejemplo demuestra cómo conectar TensorFlow Datasets (TFDS) en un modelo de Keras.

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno
import tensorflow as tf
import tensorflow_datasets as tfds

Paso 1: crea tu canalización de entrada

Empiece por crear una canalización de entrada eficiente con los consejos de:

Cargar un conjunto de datos

Cargue el conjunto de datos MNIST con los siguientes argumentos:

  • shuffle_files=True : Los datos MNIST solamente se almacena en un solo archivo, pero para grandes conjuntos de datos con varios archivos en el disco, es una buena práctica para mezclarlas en el entrenamiento.
  • as_supervised=True : Devuelve una tupla (img, label) en lugar de un diccionario {'image': img, 'label': label} .
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2021-10-19 11:14:00.483247: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Construya una canalización de capacitación

Aplicar las siguientes transformaciones:

  • tf.data.Dataset.map : TFDS proporcionar imágenes de tipo tf.uint8 , mientras que el modelo de espera tf.float32 . Por lo tanto, necesita normalizar las imágenes.
  • tf.data.Dataset.cache A medida que encaja en el conjunto de datos en la memoria caché antes de arrastrar los pies para un mejor rendimiento.
    Nota: Las transformaciones azar deben ser aplicadas después del almacenamiento en caché.
  • tf.data.Dataset.shuffle : Por cierto aleatoriedad, establece el búfer de reproducción aleatoria para el conjunto de datos de tamaño completo.
    Nota: Para grandes conjuntos de datos que no caben en la memoria, el uso buffer_size=1000 si el sistema lo permite.
  • tf.data.Dataset.batch : Elementos lote del conjunto de datos después de arrastrar los pies para obtener lotes únicos en cada época.
  • tf.data.Dataset.prefetch : Es una buena práctica para poner fin a la tubería mediante la obtención previa para el rendimiento .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

Construya una canalización de evaluación

Su canal de pruebas es similar al canal de formación con pequeñas diferencias:

  • No es necesario llamar tf.data.Dataset.shuffle .
  • El almacenamiento en caché se realiza después del procesamiento por lotes porque los lotes pueden ser los mismos entre épocas.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

Paso 2: crea y entrena el modelo

Conecte la canalización de entrada TFDS en un modelo Keras simple, compile el modelo y entrénelo.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 4s 4ms/step - loss: 0.3589 - sparse_categorical_accuracy: 0.9012 - val_loss: 0.1962 - val_sparse_categorical_accuracy: 0.9437
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1701 - sparse_categorical_accuracy: 0.9511 - val_loss: 0.1487 - val_sparse_categorical_accuracy: 0.9548
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1232 - sparse_categorical_accuracy: 0.9648 - val_loss: 0.1158 - val_sparse_categorical_accuracy: 0.9649
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0949 - sparse_categorical_accuracy: 0.9732 - val_loss: 0.0964 - val_sparse_categorical_accuracy: 0.9706
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0770 - sparse_categorical_accuracy: 0.9775 - val_loss: 0.0856 - val_sparse_categorical_accuracy: 0.9730
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0620 - sparse_categorical_accuracy: 0.9825 - val_loss: 0.0834 - val_sparse_categorical_accuracy: 0.9745
<keras.callbacks.History at 0x7f2df74b4090>