O Dia da Comunidade de ML é dia 9 de novembro! Junte-nos para atualização de TensorFlow, JAX, e mais Saiba mais

Treinando uma rede neural em MNIST com Keras

Este exemplo simples demonstra como conectar TensorFlow Datasets (TFDS) em um modelo Keras.

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno
import tensorflow as tf
import tensorflow_datasets as tfds

Etapa 1: crie seu pipeline de entrada

Comece construindo um pipeline de entrada eficiente usando conselhos de:

Carregar um conjunto de dados

Carregue o conjunto de dados MNIST com os seguintes argumentos:

  • shuffle_files=True : Os dados MNIST só é armazenado em um único arquivo, mas para conjuntos de dados maiores com vários arquivos no disco, é uma boa prática para baralhar-los quando o treinamento.
  • as_supervised=True : Retorna uma tupla (img, label) em vez de um dicionário {'image': img, 'label': label} .
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
2021-10-19 11:14:00.483247: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Construir um pipeline de treinamento

Aplique as seguintes transformações:

  • tf.data.Dataset.map : TFDS fornecer imagens do tipo tf.uint8 , enquanto o modelo de espera tf.float32 . Portanto, você precisa normalizar as imagens.
  • tf.data.Dataset.cache Como você se encaixa o conjunto de dados na memória, cache de antes baralhar para um melhor desempenho.
    Nota: transformações aleatórios devem ser aplicados após cache.
  • tf.data.Dataset.shuffle : Para a verdadeira aleatoriedade, defina o tampão shuffle para o tamanho total do conjunto de dados.
    Nota: Para grandes conjuntos de dados que não podem caber na memória, uso buffer_size=1000 se o seu sistema permitir.
  • tf.data.Dataset.batch : Elementos do lote do conjunto de dados Depois de embaralhar para obter lotes originais em cada época.
  • tf.data.Dataset.prefetch : É uma boa prática para acabar com o gasoduto por pré-busca para o desempenho .
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

Construir um pipeline de avaliação

Seu pipeline de teste é semelhante ao pipeline de treinamento, com pequenas diferenças:

  • Você não precisa chamar tf.data.Dataset.shuffle .
  • O armazenamento em cache é feito após o envio em lote porque os lotes podem ser os mesmos entre as épocas.
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)

Etapa 2: criar e treinar o modelo

Conecte o pipeline de entrada TFDS em um modelo Keras simples, compile o modelo e treine-o.

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)
Epoch 1/6
469/469 [==============================] - 4s 4ms/step - loss: 0.3589 - sparse_categorical_accuracy: 0.9012 - val_loss: 0.1962 - val_sparse_categorical_accuracy: 0.9437
Epoch 2/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1701 - sparse_categorical_accuracy: 0.9511 - val_loss: 0.1487 - val_sparse_categorical_accuracy: 0.9548
Epoch 3/6
469/469 [==============================] - 1s 2ms/step - loss: 0.1232 - sparse_categorical_accuracy: 0.9648 - val_loss: 0.1158 - val_sparse_categorical_accuracy: 0.9649
Epoch 4/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0949 - sparse_categorical_accuracy: 0.9732 - val_loss: 0.0964 - val_sparse_categorical_accuracy: 0.9706
Epoch 5/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0770 - sparse_categorical_accuracy: 0.9775 - val_loss: 0.0856 - val_sparse_categorical_accuracy: 0.9730
Epoch 6/6
469/469 [==============================] - 1s 2ms/step - loss: 0.0620 - sparse_categorical_accuracy: 0.9825 - val_loss: 0.0834 - val_sparse_categorical_accuracy: 0.9745
<keras.callbacks.History at 0x7f2df74b4090>