Aumento de datos

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Visión general

Este tutorial demuestra el aumento de datos: una técnica para aumentar la diversidad de su conjunto de entrenamiento mediante la aplicación de transformaciones aleatorias (pero realistas) como la rotación de imágenes. Aprenderá a aplicar el aumento de datos de dos formas. En primer lugar, se utilizará Capas Keras de preprocesamiento . A continuación, se utilizará tf.image .

Configuración

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras import layers
2021-07-31 01:20:29.398577: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0

Descarga un conjunto de datos

En este tutorial se utiliza el tf_flowers conjunto de datos. Para mayor comodidad, descargar el conjunto de datos utilizando TensorFlow conjuntos de datos . Si desea aprender acerca de otras formas de importación de datos, consulte la carga de imágenes tutorial.

(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)
2021-07-31 01:20:32.658409: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1
2021-07-31 01:20:33.245494: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-31 01:20:33.246357: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: 
pciBusID: 0000:00:05.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2021-07-31 01:20:33.246390: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2021-07-31 01:20:33.249453: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11
2021-07-31 01:20:33.249541: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublasLt.so.11
2021-07-31 01:20:33.250610: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcufft.so.10
2021-07-31 01:20:33.250929: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcurand.so.10
2021-07-31 01:20:33.251983: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcusolver.so.11
2021-07-31 01:20:33.252852: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcusparse.so.11
2021-07-31 01:20:33.253031: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudnn.so.8
2021-07-31 01:20:33.253121: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-31 01:20:33.253980: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-31 01:20:33.254770: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0
2021-07-31 01:20:33.255502: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-07-31 01:20:33.256126: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-31 01:20:33.256945: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: 
pciBusID: 0000:00:05.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2021-07-31 01:20:33.257035: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-31 01:20:33.257857: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-31 01:20:33.258642: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0
2021-07-31 01:20:33.258683: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2021-07-31 01:20:33.827497: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1258] Device interconnect StreamExecutor with strength 1 edge matrix:
2021-07-31 01:20:33.827531: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1264]      0 
2021-07-31 01:20:33.827538: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1277] 0:   N 
2021-07-31 01:20:33.827734: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-31 01:20:33.828649: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-31 01:20:33.829480: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-07-31 01:20:33.830357: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1418] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 14646 MB memory) -> physical GPU (device: 0, name: Tesla V100-SXM2-16GB, pci bus id: 0000:00:05.0, compute capability: 7.0)

El conjunto de datos de flores tiene cinco clases.

num_classes = metadata.features['label'].num_classes
print(num_classes)
5

Recuperemos una imagen del conjunto de datos y usémosla para demostrar el aumento de datos.

get_label_name = metadata.features['label'].int2str

image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
2021-07-31 01:20:33.935374: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
2021-07-31 01:20:33.935908: I tensorflow/core/platform/profile_utils/cpu_utils.cc:114] CPU Frequency: 2000160000 Hz
2021-07-31 01:20:34.470092: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

Utilice capas de preprocesamiento de Keras

Cambiar el tamaño y la escala

Se puede utilizar el procesamiento previo capas para cambiar el tamaño de las imágenes a una forma consistente, y para cambiar la escala de valores de los píxeles.

IMG_SIZE = 180

resize_and_rescale = tf.keras.Sequential([
  layers.experimental.preprocessing.Resizing(IMG_SIZE, IMG_SIZE),
  layers.experimental.preprocessing.Rescaling(1./255)
])

Puede ver el resultado de aplicar estas capas a una imagen.

result = resize_and_rescale(image)
_ = plt.imshow(result)

png

Puede verificar los píxeles están en [0-1] .

print("Min and max pixel values:", result.numpy().min(), result.numpy().max())
Min and max pixel values: 0.0 1.0

Aumento de datos

También puede utilizar capas de preprocesamiento para el aumento de datos.

Vamos a crear algunas capas de preprocesamiento y aplicarlas repetidamente a la misma imagen.

data_augmentation = tf.keras.Sequential([
  layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
  layers.experimental.preprocessing.RandomRotation(0.2),
])
# Add the image to a batch
image = tf.expand_dims(image, 0)
plt.figure(figsize=(10, 10))
for i in range(9):
  augmented_image = data_augmentation(image)
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(augmented_image[0])
  plt.axis("off")

png

Hay una variedad de pre-procesamiento capas que se pueden utilizar para el aumento de los datos incluidos layers.RandomContrast , layers.RandomCrop , layers.RandomZoom , y otros.

Dos opciones para usar las capas de preprocesamiento

Hay dos formas de utilizar estas capas de preprocesamiento, con importantes ventajas y desventajas.

Opción 1: hacer que las capas de preprocesamiento formen parte de su modelo

model = tf.keras.Sequential([
  resize_and_rescale,
  data_augmentation,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  # Rest of your model
])

Hay dos puntos importantes a tener en cuenta en este caso:

  • El aumento de datos se ejecutará en el dispositivo, de forma sincrónica con el resto de sus capas, y se beneficiará de la aceleración de la GPU.

  • Al exportar el modelo mediante model.save , las capas de pre-procesamiento se guardarán junto con el resto de su modelo. Si luego implementas este modelo, automáticamente estandarizará las imágenes (según la configuración de tus capas). Esto puede ahorrarle el esfuerzo de tener que volver a implementar esa lógica del lado del servidor.

Opción 2: aplique las capas de preprocesamiento a su conjunto de datos

aug_ds = train_ds.map(
  lambda x, y: (resize_and_rescale(x, training=True), y))

Con este enfoque, se utiliza Dataset.map para crear un conjunto de datos que produce lotes de imágenes aumentadas. En este caso:

  • El aumento de datos se realizará de forma asincrónica en la CPU y no se bloqueará. Se puede superponerse a la formación de su modelo en la GPU con los datos de pre-procesamiento, utilizando Dataset.prefetch , se muestra a continuación.
  • En este caso las capas prepreprocessing no se exportarán con el modelo cuando llame model.save . Deberá adjuntarlos a su modelo antes de guardarlo o volver a implementarlos en el lado del servidor. Después del entrenamiento, puede adjuntar las capas de preprocesamiento antes de exportar.

Se puede encontrar un ejemplo de la primera opción en la clasificación de la imagen tutorial. Demostremos aquí la segunda opción.

Aplicar las capas de preprocesamiento a los conjuntos de datos

Configure los conjuntos de datos de entrenamiento, validación y prueba con las capas de preprocesamiento que creó anteriormente. También configurará los conjuntos de datos para el rendimiento, utilizando lecturas paralelas y captación previa en búfer para generar lotes del disco sin que las E / S se bloqueen. Usted puede aprender más el rendimiento conjunto de datos en el rendimiento mejor con la API tf.data guía.

batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds, shuffle=False, augment=False):
  # Resize and rescale all datasets
  ds = ds.map(lambda x, y: (resize_and_rescale(x), y), 
              num_parallel_calls=AUTOTUNE)

  if shuffle:
    ds = ds.shuffle(1000)

  # Batch all datasets
  ds = ds.batch(batch_size)

  # Use data augmentation only on the training set
  if augment:
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), 
                num_parallel_calls=AUTOTUNE)

  # Use buffered prefecting on all datasets
  return ds.prefetch(buffer_size=AUTOTUNE)
train_ds = prepare(train_ds, shuffle=True, augment=True)
val_ds = prepare(val_ds)
test_ds = prepare(test_ds)

Entrena un modelo

Para completar, ahora entrenará un modelo usando estos conjuntos de datos. Este modelo no se ha ajustado para la precisión (el objetivo es mostrarle la mecánica).

model = tf.keras.Sequential([
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(num_classes)
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
epochs=5
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
Epoch 1/5
2021-07-31 01:20:39.448244: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudnn.so.8
2021-07-31 01:20:41.475212: I tensorflow/stream_executor/cuda/cuda_dnn.cc:359] Loaded cuDNN version 8100
2021-07-31 01:20:46.496035: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublas.so.11
2021-07-31 01:20:46.860481: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcublasLt.so.11
92/92 [==============================] - 17s 66ms/step - loss: 1.3434 - accuracy: 0.4271 - val_loss: 1.1534 - val_accuracy: 0.5368
Epoch 2/5
92/92 [==============================] - 3s 27ms/step - loss: 1.0960 - accuracy: 0.5565 - val_loss: 1.0718 - val_accuracy: 0.5695
Epoch 3/5
92/92 [==============================] - 3s 26ms/step - loss: 1.0115 - accuracy: 0.5988 - val_loss: 1.0322 - val_accuracy: 0.6022
Epoch 4/5
92/92 [==============================] - 3s 27ms/step - loss: 0.9503 - accuracy: 0.6202 - val_loss: 0.8811 - val_accuracy: 0.6730
Epoch 5/5
92/92 [==============================] - 3s 27ms/step - loss: 0.8758 - accuracy: 0.6570 - val_loss: 0.8760 - val_accuracy: 0.6485
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)
12/12 [==============================] - 1s 13ms/step - loss: 0.8319 - accuracy: 0.6812
Accuracy 0.6811988949775696

Aumento de datos personalizado

También puede crear capas de aumento de datos personalizadas. Este tutorial muestra dos formas de hacerlo. En primer lugar, se creará un layers.Lambda capa. Esta es una buena forma de escribir código conciso. A continuación, se escribirá una nueva capa a través de subclases , que le da más control. Ambas capas invertirán aleatoriamente los colores en una imagen, según alguna probabilidad.

def random_invert_img(x, p=0.5):
  if  tf.random.uniform([]) < p:
    x = (255-x)
  else:
    x
  return x
def random_invert(factor=0.5):
  return layers.Lambda(lambda x: random_invert_img(x, factor))

random_invert = random_invert()
plt.figure(figsize=(10, 10))
for i in range(9):
  augmented_image = random_invert(image)
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(augmented_image[0].numpy().astype("uint8"))
  plt.axis("off")

png

A continuación, aplicar una capa personalizada de subclases .

class RandomInvert(layers.Layer):
  def __init__(self, factor=0.5, **kwargs):
    super().__init__(**kwargs)
    self.factor = factor

  def call(self, x):
    return random_invert_img(x)
_ = plt.imshow(RandomInvert()(image)[0])

png

Ambas capas se pueden utilizar como se describe en las opciones 1 y 2 anteriores.

Usando tf.image

Los anteriormente layers.preprocessing utilidades son convenientes. Para un control más preciso, puede escribir sus propias tuberías de aumento de datos o capas usando tf.data y tf.image . También es posible que desee comprobar hacia fuera TensorFlow Complementos Image: Operaciones y TensorFlow de E / S: Espacio de color Conversiones

Dado que el conjunto de datos de flores se configuró previamente con el aumento de datos, volvamos a importarlo para comenzar de nuevo.

(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

Recupere una imagen para trabajar.

image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
2021-07-31 01:21:08.829596: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

Usemos la siguiente función para visualizar y comparar las imágenes originales y aumentadas una al lado de la otra.

def visualize(original, augmented):
  fig = plt.figure()
  plt.subplot(1,2,1)
  plt.title('Original image')
  plt.imshow(original)

  plt.subplot(1,2,2)
  plt.title('Augmented image')
  plt.imshow(augmented)

Aumento de datos

Voltear la imagen

Voltea la imagen vertical u horizontalmente.

flipped = tf.image.flip_left_right(image)
visualize(image, flipped)

png

Escala de grises la imagen

Escala de grises una imagen.

grayscaled = tf.image.rgb_to_grayscale(image)
visualize(image, tf.squeeze(grayscaled))
_ = plt.colorbar()

png

Saturar la imagen

Sature una imagen proporcionando un factor de saturación.

saturated = tf.image.adjust_saturation(image, 3)
visualize(image, saturated)

png

Cambiar el brillo de la imagen

Cambie el brillo de la imagen proporcionando un factor de brillo.

bright = tf.image.adjust_brightness(image, 0.4)
visualize(image, bright)

png

Centro recorta la imagen

Recorta la imagen desde el centro hasta la parte de la imagen que desees.

cropped = tf.image.central_crop(image, central_fraction=0.5)
visualize(image,cropped)

png

Rotar la imagen

Gire una imagen 90 grados.

rotated = tf.image.rot90(image)
visualize(image, rotated)

png

Transformaciones aleatorias

La aplicación de transformaciones aleatorias a las imágenes puede ayudar a generalizar y expandir el conjunto de datos. Corriente tf.image API proporciona 8 de tales operaciones de imágenes al azar (ops):

Estas operaciones de imágenes aleatorias son puramente funcionales: la salida solo depende de la entrada. Esto los hace fáciles de usar en canalizaciones de entrada deterministas de alto rendimiento. Requieren una seed valor sea de entrada de cada paso. Dada la misma seed , que devuelven los mismos resultados independientemente de cuántas veces se les llama.

En las siguientes secciones, podrá:

  1. Repase ejemplos de uso de operaciones de imagen aleatorias para transformar una imagen; y
  2. Demuestre cómo aplicar transformaciones aleatorias a un conjunto de datos de entrenamiento.

Cambiar aleatoriamente el brillo de la imagen

Al azar cambiar el brillo de image , proporcionando un factor de brillo y seed . El factor de luminancia se elige aleatoriamente en el intervalo [-max_delta, max_delta) y se asocia con el dado seed .

for i in range(3):
  seed = (i, 0)  # tuple of size (2,)
  stateless_random_brightness = tf.image.stateless_random_brightness(
      image, max_delta=0.95, seed=seed)
  visualize(image, stateless_random_brightness)

png

png

png

Cambiar aleatoriamente el contraste de la imagen

Aleatoriamente cambiar el contraste de image , proporcionando una gama de contraste y seed . El rango de contraste es elegido aleatoriamente en el intervalo [lower, upper] y se asocia con el dado seed .

for i in range(3):
  seed = (i, 0)  # tuple of size (2,)
  stateless_random_contrast = tf.image.stateless_random_contrast(
      image, lower=0.1, upper=0.9, seed=seed)
  visualize(image, stateless_random_contrast)

png

png

png

Recortar una imagen al azar

Aleatoriamente recortar image proporcionando objetivo size y seed . La porción que consigue recortada de image está en un elegido compensar al azar y se asocia con el dado seed .

for i in range(3):
  seed = (i, 0)  # tuple of size (2,)
  stateless_random_crop = tf.image.stateless_random_crop(
      image, size=[210, 300, 3], seed=seed)
  visualize(image, stateless_random_crop)

png

png

png

Aplicar aumento a un conjunto de datos

Primero descarguemos el conjunto de datos de imágenes nuevamente en caso de que se modifiquen en las secciones anteriores.

(train_datasets, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

Definamos una función de utilidad para cambiar el tamaño y la escala de las imágenes. Esta función se utilizará para unificar el tamaño y la escala de las imágenes en el conjunto de datos:

def resize_and_rescale(image, label):
  image = tf.cast(image, tf.float32)
  image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
  image = (image / 255.0)
  return image, label

También vamos a definir augment función que puede aplicar las transformaciones aleatorias a las imágenes. Esta función se utilizará en el conjunto de datos en el siguiente paso.

def augment(image_label, seed):
  image, label = image_label
  image, label = resize_and_rescale(image, label)
  image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6)
  # Make a new seed
  new_seed = tf.random.experimental.stateless_split(seed, num=1)[0, :]
  # Random crop back to the original size
  image = tf.image.stateless_random_crop(
      image, size=[IMG_SIZE, IMG_SIZE, 3], seed=seed)
  # Random brightness
  image = tf.image.stateless_random_brightness(
      image, max_delta=0.5, seed=new_seed)
  image = tf.clip_by_value(image, 0, 1)
  return image, label

Opción 1: Uso tf.data.experimental.Counter()

Crear un tf.data.experimental.Counter() objeto (llamémoslo counter ) y zip el conjunto de datos con (counter, counter) . Esto asegurará que cada imagen en el conjunto de datos se asocia con un valor único (de forma (2,) ) basado en counter que más tarde puede conseguir pasado en el augment función que la seed valor para las transformaciones aleatorias.

# Create counter and zip together with train dataset
counter = tf.data.experimental.Counter()
train_ds = tf.data.Dataset.zip((train_datasets, (counter, counter)))

Mapa del augment función para la formación de datos.

train_ds = (
    train_ds
    .shuffle(1000)
    .map(augment, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
val_ds = (
    val_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
test_ds = (
    test_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

Opción 2: Uso tf.random.Generator

Crear un tf.random.Generator objeto con un intial de seed de valor. Llamar make_seeds funcionan en el mismo objeto generador devuelve una nueva, única seed valor siempre. Definir una función de contenedor que 1) llama make_seeds función y que 2) pasa el recién generado seed valor en el augment función para transformaciones aleatorias.

# Create a generator
rng = tf.random.Generator.from_seed(123, alg='philox')
# A wrapper function for updating seeds
def f(x, y):
  seed = rng.make_seeds(2)[0]
  image, label = augment((x, y), seed)
  return image, label

Mapear la función de contenedor f al conjunto de datos de entrenamiento.

train_ds = (
    train_datasets
    .shuffle(1000)
    .map(f, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
val_ds = (
    val_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
test_ds = (
    test_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

Estos conjuntos de datos ahora se pueden usar para entrenar un modelo como se mostró anteriormente.

Próximos pasos

Este tutorial demostrado aumento de datos utilizando capas Keras de preprocesamiento y tf.image . Para aprender cómo incluir capas de pre-procesamiento dentro de su modelo, consulte la Clasificación de la imagen tutorial. Usted también puede estar interesado en aprender cómo las capas de pre-procesamiento que pueden ayudar a clasificar el texto, como se muestra en el texto básico de clasificación tutorial. Usted puede aprender más sobre tf.data en esta guía , y se puede aprender cómo configurar sus tuberías de entrada para un rendimiento aquí .