Data augmentation


Overview

This tutorial demonstrates manual image manipulations and augmentation using tf.image.

Data augmentation is a common technique for improving results and avoiding overfitting; see the Overfit and underfit tutorial for other techniques.

Setup

pip install -q git+https://github.com/tensorflow/docs
import urllib

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras import layers
AUTOTUNE = tf.data.experimental.AUTOTUNE

import tensorflow_docs as tfdocs
import tensorflow_docs.plots

import tensorflow_datasets as tfds

import PIL.Image

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (12, 5)

import numpy as np

Let's try the augmentation operations on a single image first, then augment a whole dataset to train a model.

Download this image, by Von.grzanka, for augmentation.

image_path = tf.keras.utils.get_file("cat.jpg", "https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg")
PIL.Image.open(image_path)


Read and decode the image into a tensor.

image_string = tf.io.read_file(image_path)
image = tf.image.decode_jpeg(image_string, channels=3)
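
To confirm what you are working with, you can inspect the decoded tensor's shape and dtype (a quick sanity check, not part of the original code):

print(image.shape, image.dtype)  # something like (height, width, 3) with dtype uint8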

Define a function to visualize the original and augmented images side by side.

def visualize(original, augmented):
  fig = plt.figure()
  plt.subplot(1,2,1)
  plt.title('Original image')
  plt.imshow(original)

  plt.subplot(1,2,2)
  plt.title('Augmented image')
  plt.imshow(augmented)

Augment a single image

Flipping the image

Flip the image horizontally with tf.image.flip_left_right (tf.image.flip_up_down flips it vertically).

flipped = tf.image.flip_left_right(image)
visualize(image, flipped)
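
In a training pipeline you usually want the flip applied at random rather than always. A minimal sketch using tf.image.random_flip_left_right (not part of the original code):

# Flips the image left-right with a 50% chance on each call.
maybe_flipped = tf.image.random_flip_left_right(image)
visualize(image, maybe_flipped)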


Grayscale the image

Convert the image to grayscale.

grayscaled = tf.image.rgb_to_grayscale(image)
visualize(image, tf.squeeze(grayscaled))
plt.colorbar()
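
Note that tf.image.rgb_to_grayscale keeps a channels dimension of size 1, which is why tf.squeeze is needed before plotting. If a downstream model expects 3 channels, tf.image.grayscale_to_rgb tiles the gray channel back out (a sketch, not in the original code):

gray_rgb = tf.image.grayscale_to_rgb(grayscaled)  # shape [..., 3] again
print(grayscaled.shape, gray_rgb.shape)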


Saturate the image

Saturate an image by providing a saturation factor.

saturated = tf.image.adjust_saturation(image, 3)
visualize(image, saturated)
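
To sample the saturation factor at random during training, tf.image.random_saturation draws it from a range. A sketch (the bounds here are arbitrary):

# Saturation factor drawn uniformly from [0.5, 3.0] on each call.
rand_saturated = tf.image.random_saturation(image, lower=0.5, upper=3.0)
visualize(image, rand_saturated)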


Change image brightness

Change the brightness of the image by providing a brightness delta.

bright = tf.image.adjust_brightness(image, 0.4)
visualize(image, bright)
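
The random counterpart, tf.image.random_brightness, picks a delta uniformly from [-max_delta, max_delta]; the MNIST pipeline below uses it. A quick sketch:

# A negative delta darkens the image, a positive one brightens it.
rand_bright = tf.image.random_brightness(image, max_delta=0.4)
visualize(image, rand_bright)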


Rotate the image

Rotate an image by 90 degrees.

rotated = tf.image.rot90(image)
visualize(image, rotated)
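
tf.image.rot90 also takes a k argument for the number of quarter turns, so a 180-degree rotation is simply:

rotated_180 = tf.image.rot90(image, k=2)  # two quarter turns = 180 degrees
visualize(image, rotated_180)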


Center crop the image

Crop the image from the center, keeping the fraction of the image you specify.

cropped = tf.image.central_crop(image, central_fraction=0.5)
visualize(image, cropped)
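
If you need an exact output size rather than a fraction, tf.image.resize_with_crop_or_pad crops or pads around the center as needed; the MNIST pipeline below uses it to add padding. A sketch (the target size here is arbitrary):

# Crops (or pads) the image to exactly 150x200 pixels around the center.
fixed_size = tf.image.resize_with_crop_or_pad(image, target_height=150, target_width=200)
visualize(image, fixed_size)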


See the tf.image reference for details about available augmentation options.

Augment a dataset and train a model with it

Train a model on an augmented dataset.

dataset, info = tfds.load('mnist', as_supervised=True, with_info=True)
train_dataset, test_dataset = dataset['train'], dataset['test']

num_train_examples = info.splits['train'].num_examples

Write a function to augment the images, then map it over the dataset. This returns a dataset that augments the data on the fly.

def convert(image, label):
  image = tf.image.convert_image_dtype(image, tf.float32) # Cast and normalize the image to [0,1]
  return image, label

def augment(image, label):
  image, label = convert(image, label)  # Cast and normalize the image to [0,1]
  image = tf.image.resize_with_crop_or_pad(image, 34, 34) # Add 6 pixels of padding (3 per side)
  image = tf.image.random_crop(image, size=[28, 28, 1]) # Random crop back to 28x28
  image = tf.image.random_brightness(image, max_delta=0.5) # Random brightness

  return image, label

BATCH_SIZE = 64
# Only use a subset of the data, so it's easier to overfit, for this tutorial.
NUM_EXAMPLES = 2048
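
Before wiring augment into a pipeline, it helps to eyeball its effect on a single example (a sketch, reusing the visualize helper from earlier):

# Pull one MNIST example and compare it with its augmented version.
sample_image, sample_label = next(iter(train_dataset))
aug_image, _ = augment(sample_image, sample_label)
visualize(tf.squeeze(sample_image), tf.clip_by_value(tf.squeeze(aug_image), 0, 1))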

Create the augmented dataset.

augmented_train_batches = (
    train_dataset
    # Only train on a subset, so you can quickly see the effect.
    .take(NUM_EXAMPLES)
    .cache()
    .shuffle(num_train_examples//4)
    # The augmentation is added here.
    .map(augment, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
) 
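
Note the ordering: cache comes before map, so the raw examples are cached but the random augmentation re-runs on every pass, producing fresh variants each epoch. A quick shape check (sketch):

for images, labels in augmented_train_batches.take(1):
  print(images.shape, labels.shape)  # expect (64, 28, 28, 1) and (64,)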

And a non-augmented one for comparison.

non_augmented_train_batches = (
    train_dataset
    # Only train on a subset, so you can quickly see the effect.
    .take(NUM_EXAMPLES)
    .cache()
    .shuffle(num_train_examples//4)
    # No augmentation.
    .map(convert, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
) 

Set up the validation dataset. This doesn't change whether or not you're using augmentation.

validation_batches = (
    test_dataset
    .map(convert, num_parallel_calls=AUTOTUNE)
    .batch(2*BATCH_SIZE)
)

Create and compile the model. The model is a simple fully-connected network with two hidden layers and no convolutions.

def make_model():
  model = tf.keras.Sequential([
      layers.Flatten(input_shape=(28, 28, 1)),
      layers.Dense(4096, activation='relu'),
      layers.Dense(4096, activation='relu'),
      layers.Dense(10)
  ])
  model.compile(optimizer = 'adam',
                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
  return model
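
Optionally, print the layer shapes and parameter count; the model is deliberately large so that it can overfit the 2048-example subset:

make_model().summary()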

Train the model without augmentation:

model_without_aug = make_model()

no_aug_history = model_without_aug.fit(non_augmented_train_batches, epochs=50, validation_data=validation_batches)
Epoch 1/50
32/32 [==============================] - 1s 38ms/step - loss: 0.7947 - accuracy: 0.7485 - val_loss: 0.3258 - val_accuracy: 0.9008
Epoch 2/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1551 - accuracy: 0.9512 - val_loss: 0.3558 - val_accuracy: 0.9033
Epoch 3/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0937 - accuracy: 0.9692 - val_loss: 0.3041 - val_accuracy: 0.9170
Epoch 4/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0413 - accuracy: 0.9839 - val_loss: 0.3262 - val_accuracy: 0.9207
Epoch 5/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0206 - accuracy: 0.9927 - val_loss: 0.3151 - val_accuracy: 0.9272
Epoch 6/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0179 - accuracy: 0.9932 - val_loss: 0.3467 - val_accuracy: 0.9246
Epoch 7/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0351 - accuracy: 0.9907 - val_loss: 0.3731 - val_accuracy: 0.9227
Epoch 8/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0772 - accuracy: 0.9751 - val_loss: 0.3367 - val_accuracy: 0.9200
Epoch 9/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0579 - accuracy: 0.9805 - val_loss: 0.4197 - val_accuracy: 0.9141
Epoch 10/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0341 - accuracy: 0.9878 - val_loss: 0.4159 - val_accuracy: 0.9221
Epoch 11/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0444 - accuracy: 0.9897 - val_loss: 0.4530 - val_accuracy: 0.9180
Epoch 12/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0440 - accuracy: 0.9888 - val_loss: 0.3899 - val_accuracy: 0.9208
Epoch 13/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0102 - accuracy: 0.9961 - val_loss: 0.4007 - val_accuracy: 0.9254
Epoch 14/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0085 - accuracy: 0.9976 - val_loss: 0.3633 - val_accuracy: 0.9352
Epoch 15/50
32/32 [==============================] - 1s 22ms/step - loss: 0.0060 - accuracy: 0.9985 - val_loss: 0.4455 - val_accuracy: 0.9220
Epoch 16/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0384 - accuracy: 0.9922 - val_loss: 0.3923 - val_accuracy: 0.9291
Epoch 17/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0142 - accuracy: 0.9956 - val_loss: 0.3724 - val_accuracy: 0.9293
Epoch 18/50
32/32 [==============================] - 1s 23ms/step - loss: 0.0019 - accuracy: 1.0000 - val_loss: 0.3417 - val_accuracy: 0.9360
Epoch 19/50
32/32 [==============================] - 1s 22ms/step - loss: 3.0518e-04 - accuracy: 1.0000 - val_loss: 0.3471 - val_accuracy: 0.9406
Epoch 20/50
32/32 [==============================] - 1s 23ms/step - loss: 2.2373e-04 - accuracy: 1.0000 - val_loss: 0.3493 - val_accuracy: 0.9405
Epoch 21/50
32/32 [==============================] - 1s 22ms/step - loss: 1.0900e-04 - accuracy: 1.0000 - val_loss: 0.3590 - val_accuracy: 0.9389
Epoch 22/50
32/32 [==============================] - 1s 22ms/step - loss: 8.8505e-05 - accuracy: 1.0000 - val_loss: 0.3616 - val_accuracy: 0.9393
Epoch 23/50
32/32 [==============================] - 1s 25ms/step - loss: 6.8111e-05 - accuracy: 1.0000 - val_loss: 0.3655 - val_accuracy: 0.9391
Epoch 24/50
32/32 [==============================] - 1s 23ms/step - loss: 5.1576e-05 - accuracy: 1.0000 - val_loss: 0.3731 - val_accuracy: 0.9392
Epoch 25/50
32/32 [==============================] - 1s 22ms/step - loss: 3.7954e-05 - accuracy: 1.0000 - val_loss: 0.3832 - val_accuracy: 0.9394
Epoch 26/50
32/32 [==============================] - 1s 22ms/step - loss: 2.6865e-05 - accuracy: 1.0000 - val_loss: 0.3915 - val_accuracy: 0.9396
Epoch 27/50
32/32 [==============================] - 1s 23ms/step - loss: 2.0358e-05 - accuracy: 1.0000 - val_loss: 0.4018 - val_accuracy: 0.9394
Epoch 28/50
32/32 [==============================] - 1s 22ms/step - loss: 1.5839e-05 - accuracy: 1.0000 - val_loss: 0.4086 - val_accuracy: 0.9394
Epoch 29/50
32/32 [==============================] - 1s 22ms/step - loss: 1.2895e-05 - accuracy: 1.0000 - val_loss: 0.4148 - val_accuracy: 0.9393
Epoch 30/50
32/32 [==============================] - 1s 23ms/step - loss: 1.0775e-05 - accuracy: 1.0000 - val_loss: 0.4208 - val_accuracy: 0.9394
Epoch 31/50
32/32 [==============================] - 1s 22ms/step - loss: 9.2570e-06 - accuracy: 1.0000 - val_loss: 0.4253 - val_accuracy: 0.9394
Epoch 32/50
32/32 [==============================] - 1s 23ms/step - loss: 8.1208e-06 - accuracy: 1.0000 - val_loss: 0.4302 - val_accuracy: 0.9396
Epoch 33/50
32/32 [==============================] - 1s 25ms/step - loss: 7.1906e-06 - accuracy: 1.0000 - val_loss: 0.4337 - val_accuracy: 0.9398
Epoch 34/50
32/32 [==============================] - 1s 23ms/step - loss: 6.5134e-06 - accuracy: 1.0000 - val_loss: 0.4373 - val_accuracy: 0.9398
Epoch 35/50
32/32 [==============================] - 1s 22ms/step - loss: 5.8839e-06 - accuracy: 1.0000 - val_loss: 0.4412 - val_accuracy: 0.9398
Epoch 36/50
32/32 [==============================] - 1s 22ms/step - loss: 5.3458e-06 - accuracy: 1.0000 - val_loss: 0.4449 - val_accuracy: 0.9395
Epoch 37/50
32/32 [==============================] - 1s 22ms/step - loss: 4.9340e-06 - accuracy: 1.0000 - val_loss: 0.4474 - val_accuracy: 0.9395
Epoch 38/50
32/32 [==============================] - 1s 22ms/step - loss: 4.5733e-06 - accuracy: 1.0000 - val_loss: 0.4503 - val_accuracy: 0.9394
Epoch 39/50
32/32 [==============================] - 1s 22ms/step - loss: 4.2514e-06 - accuracy: 1.0000 - val_loss: 0.4525 - val_accuracy: 0.9394
Epoch 40/50
32/32 [==============================] - 1s 23ms/step - loss: 3.9590e-06 - accuracy: 1.0000 - val_loss: 0.4553 - val_accuracy: 0.9394
Epoch 41/50
32/32 [==============================] - 1s 23ms/step - loss: 3.7258e-06 - accuracy: 1.0000 - val_loss: 0.4568 - val_accuracy: 0.9393
Epoch 42/50
32/32 [==============================] - 1s 22ms/step - loss: 3.4758e-06 - accuracy: 1.0000 - val_loss: 0.4596 - val_accuracy: 0.9391
Epoch 43/50
32/32 [==============================] - 1s 23ms/step - loss: 3.3092e-06 - accuracy: 1.0000 - val_loss: 0.4624 - val_accuracy: 0.9391
Epoch 44/50
32/32 [==============================] - 1s 23ms/step - loss: 3.1065e-06 - accuracy: 1.0000 - val_loss: 0.4645 - val_accuracy: 0.9389
Epoch 45/50
32/32 [==============================] - 1s 23ms/step - loss: 2.9453e-06 - accuracy: 1.0000 - val_loss: 0.4658 - val_accuracy: 0.9391
Epoch 46/50
32/32 [==============================] - 1s 23ms/step - loss: 2.7940e-06 - accuracy: 1.0000 - val_loss: 0.4682 - val_accuracy: 0.9392
Epoch 47/50
32/32 [==============================] - 1s 23ms/step - loss: 2.6627e-06 - accuracy: 1.0000 - val_loss: 0.4696 - val_accuracy: 0.9392
Epoch 48/50
32/32 [==============================] - 1s 23ms/step - loss: 2.5305e-06 - accuracy: 1.0000 - val_loss: 0.4716 - val_accuracy: 0.9392
Epoch 49/50
32/32 [==============================] - 1s 23ms/step - loss: 2.4067e-06 - accuracy: 1.0000 - val_loss: 0.4733 - val_accuracy: 0.9393
Epoch 50/50
32/32 [==============================] - 1s 22ms/step - loss: 2.3043e-06 - accuracy: 1.0000 - val_loss: 0.4741 - val_accuracy: 0.9394

Train a second model with augmentation:

model_with_aug = make_model()

aug_history = model_with_aug.fit(augmented_train_batches, epochs=50, validation_data=validation_batches)
Epoch 1/50
32/32 [==============================] - 1s 35ms/step - loss: 2.3106 - accuracy: 0.3115 - val_loss: 1.2095 - val_accuracy: 0.6655
Epoch 2/50
32/32 [==============================] - 1s 23ms/step - loss: 1.3821 - accuracy: 0.5371 - val_loss: 0.7040 - val_accuracy: 0.7943
Epoch 3/50
32/32 [==============================] - 1s 22ms/step - loss: 0.9761 - accuracy: 0.6709 - val_loss: 0.5867 - val_accuracy: 0.8518
Epoch 4/50
32/32 [==============================] - 1s 23ms/step - loss: 0.7830 - accuracy: 0.7437 - val_loss: 0.4247 - val_accuracy: 0.8816
Epoch 5/50
32/32 [==============================] - 1s 23ms/step - loss: 0.6546 - accuracy: 0.7715 - val_loss: 0.3521 - val_accuracy: 0.8943
Epoch 6/50
32/32 [==============================] - 1s 23ms/step - loss: 0.5795 - accuracy: 0.8057 - val_loss: 0.3036 - val_accuracy: 0.9034
Epoch 7/50
32/32 [==============================] - 1s 22ms/step - loss: 0.4934 - accuracy: 0.8350 - val_loss: 0.2950 - val_accuracy: 0.9088
Epoch 8/50
32/32 [==============================] - 1s 23ms/step - loss: 0.4900 - accuracy: 0.8447 - val_loss: 0.2640 - val_accuracy: 0.9179
Epoch 9/50
32/32 [==============================] - 1s 23ms/step - loss: 0.4573 - accuracy: 0.8408 - val_loss: 0.2528 - val_accuracy: 0.9191
Epoch 10/50
32/32 [==============================] - 1s 23ms/step - loss: 0.4360 - accuracy: 0.8623 - val_loss: 0.2463 - val_accuracy: 0.9224
Epoch 11/50
32/32 [==============================] - 1s 23ms/step - loss: 0.3768 - accuracy: 0.8765 - val_loss: 0.2337 - val_accuracy: 0.9238
Epoch 12/50
32/32 [==============================] - 1s 24ms/step - loss: 0.3681 - accuracy: 0.8823 - val_loss: 0.2051 - val_accuracy: 0.9351
Epoch 13/50
32/32 [==============================] - 1s 23ms/step - loss: 0.3783 - accuracy: 0.8804 - val_loss: 0.2184 - val_accuracy: 0.9315
Epoch 14/50
32/32 [==============================] - 1s 24ms/step - loss: 0.3429 - accuracy: 0.8887 - val_loss: 0.1834 - val_accuracy: 0.9430
Epoch 15/50
32/32 [==============================] - 1s 23ms/step - loss: 0.3271 - accuracy: 0.8945 - val_loss: 0.2174 - val_accuracy: 0.9310
Epoch 16/50
32/32 [==============================] - 1s 23ms/step - loss: 0.3005 - accuracy: 0.8975 - val_loss: 0.2344 - val_accuracy: 0.9302
Epoch 17/50
32/32 [==============================] - 1s 23ms/step - loss: 0.3343 - accuracy: 0.8926 - val_loss: 0.1898 - val_accuracy: 0.9398
Epoch 18/50
32/32 [==============================] - 1s 23ms/step - loss: 0.3340 - accuracy: 0.8867 - val_loss: 0.1854 - val_accuracy: 0.9413
Epoch 19/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2615 - accuracy: 0.9116 - val_loss: 0.2155 - val_accuracy: 0.9310
Epoch 20/50
32/32 [==============================] - 1s 23ms/step - loss: 0.3074 - accuracy: 0.8950 - val_loss: 0.1897 - val_accuracy: 0.9412
Epoch 21/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2912 - accuracy: 0.9087 - val_loss: 0.1575 - val_accuracy: 0.9523
Epoch 22/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2821 - accuracy: 0.9067 - val_loss: 0.1591 - val_accuracy: 0.9515
Epoch 23/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2324 - accuracy: 0.9233 - val_loss: 0.1755 - val_accuracy: 0.9436
Epoch 24/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2290 - accuracy: 0.9224 - val_loss: 0.1562 - val_accuracy: 0.9515
Epoch 25/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2217 - accuracy: 0.9307 - val_loss: 0.1724 - val_accuracy: 0.9487
Epoch 26/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2533 - accuracy: 0.9160 - val_loss: 0.1754 - val_accuracy: 0.9452
Epoch 27/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2780 - accuracy: 0.9111 - val_loss: 0.1715 - val_accuracy: 0.9484
Epoch 28/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2000 - accuracy: 0.9341 - val_loss: 0.1744 - val_accuracy: 0.9459
Epoch 29/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2195 - accuracy: 0.9341 - val_loss: 0.1528 - val_accuracy: 0.9534
Epoch 30/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2188 - accuracy: 0.9277 - val_loss: 0.1650 - val_accuracy: 0.9480
Epoch 31/50
32/32 [==============================] - 1s 24ms/step - loss: 0.2226 - accuracy: 0.9287 - val_loss: 0.1505 - val_accuracy: 0.9543
Epoch 32/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2362 - accuracy: 0.9248 - val_loss: 0.1791 - val_accuracy: 0.9472
Epoch 33/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2175 - accuracy: 0.9277 - val_loss: 0.1589 - val_accuracy: 0.9508
Epoch 34/50
32/32 [==============================] - 1s 24ms/step - loss: 0.2393 - accuracy: 0.9248 - val_loss: 0.1736 - val_accuracy: 0.9455
Epoch 35/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1797 - accuracy: 0.9370 - val_loss: 0.1602 - val_accuracy: 0.9515
Epoch 36/50
32/32 [==============================] - 1s 24ms/step - loss: 0.2059 - accuracy: 0.9326 - val_loss: 0.1544 - val_accuracy: 0.9525
Epoch 37/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1548 - accuracy: 0.9492 - val_loss: 0.1524 - val_accuracy: 0.9537
Epoch 38/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1691 - accuracy: 0.9473 - val_loss: 0.1896 - val_accuracy: 0.9385
Epoch 39/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1962 - accuracy: 0.9360 - val_loss: 0.1591 - val_accuracy: 0.9516
Epoch 40/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1725 - accuracy: 0.9429 - val_loss: 0.1522 - val_accuracy: 0.9533
Epoch 41/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1502 - accuracy: 0.9468 - val_loss: 0.1502 - val_accuracy: 0.9566
Epoch 42/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1620 - accuracy: 0.9409 - val_loss: 0.1744 - val_accuracy: 0.9504
Epoch 43/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1784 - accuracy: 0.9443 - val_loss: 0.1562 - val_accuracy: 0.9549
Epoch 44/50
32/32 [==============================] - 1s 25ms/step - loss: 0.1676 - accuracy: 0.9399 - val_loss: 0.1541 - val_accuracy: 0.9542
Epoch 45/50
32/32 [==============================] - 1s 23ms/step - loss: 0.2119 - accuracy: 0.9351 - val_loss: 0.2007 - val_accuracy: 0.9441
Epoch 46/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1873 - accuracy: 0.9395 - val_loss: 0.1683 - val_accuracy: 0.9515
Epoch 47/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1826 - accuracy: 0.9409 - val_loss: 0.1621 - val_accuracy: 0.9519
Epoch 48/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1718 - accuracy: 0.9458 - val_loss: 0.1673 - val_accuracy: 0.9515
Epoch 49/50
32/32 [==============================] - 1s 23ms/step - loss: 0.1484 - accuracy: 0.9521 - val_loss: 0.1644 - val_accuracy: 0.9520
Epoch 50/50
32/32 [==============================] - 1s 24ms/step - loss: 0.1538 - accuracy: 0.9473 - val_loss: 0.1562 - val_accuracy: 0.9534

Conclusion

In this example, the augmented model converges to an accuracy of ~95% on the validation set. This is slightly higher (about +1%) than the model trained without data augmentation.

plotter = tfdocs.plots.HistoryPlotter()
plotter.plot({"Augmented": aug_history, "Non-Augmented": no_aug_history}, metric="accuracy")
plt.title("Accuracy")
plt.ylim([0.75, 1])


In terms of loss, the non-augmented model is clearly in the overfitting regime: its training loss approaches zero while its validation loss keeps rising. The augmented model, while a few epochs slower to converge, is still training correctly and clearly not overfitting.

plotter = tfdocs.plots.HistoryPlotter()
plotter.plot({"Augmented": aug_history, "Non-Augmented": no_aug_history}, metric="loss")
plt.title("Loss")
plt.ylim([0, 1])
