Treinando uma rede neural no MNIST com Keras

Este exemplo simples demonstra como conectar conjuntos de dados do TensorFlow (TFDS) em um modelo Keras.

Veja no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno
import tensorflow as tf
import tensorflow_datasets as tfds

Etapa 1: criar seu pipeline de entrada

Comece criando um pipeline de entrada eficiente usando conselhos de:

Carregar um conjunto de dados

Carregue o conjunto de dados MNIST com os seguintes argumentos:

  • shuffle_files=True : Os dados MNIST são armazenados apenas em um único arquivo, mas para conjuntos de dados maiores com vários arquivos em disco, é uma boa prática embaralhá-los durante o treinamento.
  • as_supervised=True : Retorna uma tupla (img, label) em vez de um dicionário {'image': img, 'label': label} .
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2022-02-07 04:05:46.671689: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Crie um pipeline de treinamento

Aplique as seguintes transformações:

  • tf.data.Dataset.map : TFDS fornece imagens do tipo tf.uint8 , enquanto o modelo espera tf.float32 . Portanto, você precisa normalizar as imagens.
  • tf.data.Dataset.cache Conforme você ajusta o conjunto de dados na memória, armazene-o em cache antes de embaralhá-lo para um melhor desempenho.
    Nota: As transformações aleatórias devem ser aplicadas após o armazenamento em cache.
  • tf.data.Dataset.shuffle : Para aleatoriedade verdadeira, defina o buffer aleatório para o tamanho total do conjunto de dados.
    Observação: para grandes conjuntos de dados que não cabem na memória, use buffer_size=1000 se seu sistema permitir.
  • tf.data.Dataset.batch : Elementos de lote do conjunto de dados após embaralhar para obter lotes exclusivos em cada época.
  • tf.data.Dataset.prefetch : É uma boa prática encerrar o pipeline por pré-busca para desempenho .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

Crie um pipeline de avaliação

Seu pipeline de teste é semelhante ao pipeline de treinamento com pequenas diferenças:

  • Você não precisa chamar tf.data.Dataset.shuffle .
  • O armazenamento em cache é feito após o lote porque os lotes podem ser os mesmos entre épocas.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

Etapa 2: criar e treinar o modelo

Conecte o pipeline de entrada do TFDS em um modelo Keras simples, compile o modelo e treine-o.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 5s 4ms/step - loss: 0.3503 - sparse_categorical_accuracy: 0.9053 - val_loss: 0.1979 - val_sparse_categorical_accuracy: 0.9415
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1668 - sparse_categorical_accuracy: 0.9524 - val_loss: 0.1392 - val_sparse_categorical_accuracy: 0.9595
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1216 - sparse_categorical_accuracy: 0.9657 - val_loss: 0.1120 - val_sparse_categorical_accuracy: 0.9653
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0939 - sparse_categorical_accuracy: 0.9726 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9704
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0757 - sparse_categorical_accuracy: 0.9781 - val_loss: 0.0928 - val_sparse_categorical_accuracy: 0.9717
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0625 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0851 - val_sparse_categorical_accuracy: 0.9728
<keras.callbacks.History at 0x7f77b42cd910>