Tăng dữ liệu

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Tổng quat

Hướng dẫn này trình bày khả năng tăng dữ liệu: một kỹ thuật để tăng tính đa dạng của tập huấn luyện của bạn bằng cách áp dụng các phép biến đổi ngẫu nhiên (nhưng thực tế), chẳng hạn như xoay hình ảnh.

Bạn sẽ học cách áp dụng tăng dữ liệu theo hai cách:

Thành lập

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras import layers

Tải xuống tập dữ liệu

Hướng dẫn này sử dụng tập dữ liệu tf_flowers . Để thuận tiện, hãy tải xuống tập dữ liệu bằng TensorFlow Datasets . Nếu bạn muốn tìm hiểu về các cách nhập dữ liệu khác, hãy xem hướng dẫn tải hình ảnh .

(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

Tập dữ liệu hoa có năm lớp.

num_classes = metadata.features['label'].num_classes
print(num_classes)
5

Hãy lấy một hình ảnh từ tập dữ liệu và sử dụng nó để chứng minh việc tăng dữ liệu.

get_label_name = metadata.features['label'].int2str

image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
2022-01-26 05:09:18.712477: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

Sử dụng các lớp tiền xử lý của Keras

Thay đổi kích thước và thay đổi tỷ lệ

Bạn có thể sử dụng các lớp tiền xử lý của Keras để thay đổi kích thước hình ảnh của mình thành một hình dạng nhất quán (với tf.keras.layers.Resizing ) và thay đổi tỷ lệ giá trị pixel (với tf.keras.layers.Rescaling ).

IMG_SIZE = 180

resize_and_rescale = tf.keras.Sequential([
  layers.Resizing(IMG_SIZE, IMG_SIZE),
  layers.Rescaling(1./255)
])

Bạn có thể hình dung kết quả của việc áp dụng các lớp này vào một hình ảnh.

result = resize_and_rescale(image)
_ = plt.imshow(result)

png

Xác minh rằng các pixel nằm trong phạm vi [0, 1] :

print("Min and max pixel values:", result.numpy().min(), result.numpy().max())
Min and max pixel values: 0.0 1.0

Tăng dữ liệu

Bạn cũng có thể sử dụng các lớp tiền xử lý Keras để tăng dữ liệu, chẳng hạn như tf.keras.layers.RandomFliptf.keras.layers.RandomRotation .

Hãy tạo một vài lớp tiền xử lý và áp dụng chúng nhiều lần vào cùng một hình ảnh.

data_augmentation = tf.keras.Sequential([
  layers.RandomFlip("horizontal_and_vertical"),
  layers.RandomRotation(0.2),
])
# Add the image to a batch.
image = tf.expand_dims(image, 0)
plt.figure(figsize=(10, 10))
for i in range(9):
  augmented_image = data_augmentation(image)
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(augmented_image[0])
  plt.axis("off")
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

Có nhiều lớp tiền xử lý khác nhau mà bạn có thể sử dụng để tăng dữ liệu bao gồm tf.keras.layers.RandomContrast , tf.keras.layers.RandomCrop , tf.keras.layers.RandomZoom và các lớp khác.

Hai tùy chọn để sử dụng các lớp tiền xử lý của Keras

Có hai cách bạn có thể sử dụng các lớp tiền xử lý này, với những đánh đổi quan trọng.

Tùy chọn 1: Đặt các lớp tiền xử lý thành một phần của mô hình của bạn

model = tf.keras.Sequential([
  # Add the preprocessing layers you created earlier.
  resize_and_rescale,
  data_augmentation,
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  # Rest of your model.
])

Có hai điểm quan trọng cần lưu ý trong trường hợp này:

  • Tăng dữ liệu sẽ chạy trên thiết bị, đồng bộ với các lớp còn lại của bạn và được hưởng lợi từ việc tăng tốc GPU.

  • Khi bạn xuất mô hình của mình bằng cách sử dụng model.save , các lớp tiền xử lý sẽ được lưu cùng với phần còn lại của mô hình của bạn. Nếu sau này bạn triển khai mô hình này, nó sẽ tự động chuẩn hóa hình ảnh (theo cấu hình của các lớp của bạn). Điều này có thể giúp bạn tiết kiệm khỏi nỗ lực phải thực hiện lại phía máy chủ logic đó.

Tùy chọn 2: Áp dụng các lớp tiền xử lý cho tập dữ liệu của bạn

aug_ds = train_ds.map(
  lambda x, y: (resize_and_rescale(x, training=True), y))

Với cách tiếp cận này, bạn sử dụng Dataset.map để tạo tập dữ liệu mang lại các lô hình ảnh tăng cường. Trong trường hợp này:

  • Việc tăng dữ liệu sẽ diễn ra không đồng bộ trên CPU và không bị chặn. Bạn có thể chồng chéo quá trình đào tạo mô hình của mình trên GPU với xử lý trước dữ liệu, sử dụng Dataset.prefetch , được hiển thị bên dưới.
  • Trong trường hợp này, các lớp tiền xử lý sẽ không được xuất cùng với mô hình khi bạn gọi Model.save . Bạn sẽ cần đính kèm chúng vào mô hình của mình trước khi lưu hoặc thực hiện lại chúng ở phía máy chủ. Sau khi đào tạo, bạn có thể đính kèm các lớp tiền xử lý trước khi xuất.

Bạn có thể tìm thấy một ví dụ về tùy chọn đầu tiên trong hướng dẫn phân loại Hình ảnh . Hãy chứng minh tùy chọn thứ hai ở đây.

Áp dụng các lớp tiền xử lý cho tập dữ liệu

Định cấu hình tập dữ liệu đào tạo, xác thực và kiểm tra với các lớp tiền xử lý Keras mà bạn đã tạo trước đó. Bạn cũng sẽ định cấu hình bộ dữ liệu cho hiệu suất, sử dụng đọc song song và tìm nạp trước trong bộ đệm để mang lại các lô từ đĩa mà không bị chặn I / O. (Tìm hiểu thêm về hiệu suất tập dữ liệu trong Hiệu suất tốt hơn với hướng dẫn API tf.data .)

batch_size = 32
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds, shuffle=False, augment=False):
  # Resize and rescale all datasets.
  ds = ds.map(lambda x, y: (resize_and_rescale(x), y), 
              num_parallel_calls=AUTOTUNE)

  if shuffle:
    ds = ds.shuffle(1000)

  # Batch all datasets.
  ds = ds.batch(batch_size)

  # Use data augmentation only on the training set.
  if augment:
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), 
                num_parallel_calls=AUTOTUNE)

  # Use buffered prefetching on all datasets.
  return ds.prefetch(buffer_size=AUTOTUNE)
train_ds = prepare(train_ds, shuffle=True, augment=True)
val_ds = prepare(val_ds)
test_ds = prepare(test_ds)

Đào tạo một người mẫu

Để hoàn thiện, bây giờ bạn sẽ đào tạo một mô hình bằng cách sử dụng các bộ dữ liệu bạn vừa chuẩn bị.

Mô hình Tuần tự bao gồm ba khối tích chập ( tf.keras.layers.Conv2D ) với một lớp tổng hợp tối đa ( tf.keras.layers.MaxPooling2D ) trong mỗi khối. Có một lớp được kết nối đầy đủ ( tf.keras.layers.Dense ) với 128 đơn vị trên cùng được kích hoạt bởi chức năng kích hoạt ReLU ( 'relu' ). Mô hình này chưa được điều chỉnh về độ chính xác (mục đích là để cho bạn thấy cơ chế hoạt động).

model = tf.keras.Sequential([
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(num_classes)
])

Chọn trình tối ưu hóa tf.keras.optimizers.Adam và chức năng mất mát tf.keras.losses.SparseCategoricalCrossentropy . Để xem độ chính xác của quá trình đào tạo và xác thực cho từng kỷ nguyên đào tạo, hãy chuyển đối metrics vào Model.compile .

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

Huấn luyện trong một vài kỷ nguyên:

epochs=5
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
Epoch 1/5
92/92 [==============================] - 13s 110ms/step - loss: 1.2768 - accuracy: 0.4622 - val_loss: 1.0929 - val_accuracy: 0.5640
Epoch 2/5
92/92 [==============================] - 3s 25ms/step - loss: 1.0579 - accuracy: 0.5749 - val_loss: 0.9711 - val_accuracy: 0.6349
Epoch 3/5
92/92 [==============================] - 3s 26ms/step - loss: 0.9677 - accuracy: 0.6291 - val_loss: 0.9764 - val_accuracy: 0.6431
Epoch 4/5
92/92 [==============================] - 3s 25ms/step - loss: 0.9150 - accuracy: 0.6468 - val_loss: 0.8906 - val_accuracy: 0.6431
Epoch 5/5
92/92 [==============================] - 3s 25ms/step - loss: 0.8636 - accuracy: 0.6604 - val_loss: 0.8233 - val_accuracy: 0.6730
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)
12/12 [==============================] - 5s 14ms/step - loss: 0.7922 - accuracy: 0.6948
Accuracy 0.6948229074478149

Tăng dữ liệu tùy chỉnh

Bạn cũng có thể tạo các lớp tăng dữ liệu tùy chỉnh.

Phần này của hướng dẫn cho thấy hai cách để làm như vậy:

  • Đầu tiên, bạn sẽ tạo một lớp tf.keras.layers.Lambda . Đây là một cách tốt để viết mã ngắn gọn.
  • Tiếp theo, bạn sẽ viết một lớp mới thông qua lớp con , lớp này cho phép bạn kiểm soát nhiều hơn.

Cả hai lớp sẽ đảo ngược màu một cách ngẫu nhiên trong một hình ảnh, theo một số xác suất.

def random_invert_img(x, p=0.5):
  if  tf.random.uniform([]) < p:
    x = (255-x)
  else:
    x
  return x
def random_invert(factor=0.5):
  return layers.Lambda(lambda x: random_invert_img(x, factor))

random_invert = random_invert()
plt.figure(figsize=(10, 10))
for i in range(9):
  augmented_image = random_invert(image)
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(augmented_image[0].numpy().astype("uint8"))
  plt.axis("off")
2022-01-26 05:09:53.045204: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module
2022-01-26 05:09:53.045264: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module
2022-01-26 05:09:53.045312: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module
2022-01-26 05:09:53.045369: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module
2022-01-26 05:09:53.045418: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module
2022-01-26 05:09:53.045467: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module
2022-01-26 05:09:53.045511: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module
2022-01-26 05:09:53.047630: W tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc:399] target triple not found in the module

png

Tiếp theo, triển khai một lớp tùy chỉnh bằng cách phân lớp :

class RandomInvert(layers.Layer):
  def __init__(self, factor=0.5, **kwargs):
    super().__init__(**kwargs)
    self.factor = factor

  def call(self, x):
    return random_invert_img(x)
_ = plt.imshow(RandomInvert()(image)[0])

png

Cả hai lớp này đều có thể được sử dụng như được mô tả trong tùy chọn 1 và 2 ở trên.

Sử dụng tf.image

Các tiện ích tiền xử lý Keras ở trên rất tiện lợi. Tuy nhiên, để kiểm soát tốt hơn, bạn có thể viết các đường ống hoặc lớp tăng dữ liệu của riêng mình bằng cách sử dụng tf.datatf.image . (Bạn cũng có thể muốn xem Hình ảnh bổ trợ TensorFlow: Hoạt độngTensorFlow I / O: Chuyển đổi không gian màu .)

Vì tập dữ liệu về hoa trước đây đã được định cấu hình bằng cách tăng dữ liệu, hãy nhập lại để bắt đầu làm mới:

(train_ds, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

Lấy một hình ảnh để làm việc với:

image, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))
2022-01-26 05:09:59.918847: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

Hãy sử dụng chức năng sau để trực quan hóa và so sánh song song hình ảnh gốc và hình ảnh tăng cường:

def visualize(original, augmented):
  fig = plt.figure()
  plt.subplot(1,2,1)
  plt.title('Original image')
  plt.imshow(original)

  plt.subplot(1,2,2)
  plt.title('Augmented image')
  plt.imshow(augmented)

Tăng dữ liệu

Lật một hình ảnh

Lật hình ảnh theo chiều dọc hoặc chiều ngang với tf.image.flip_left_right :

flipped = tf.image.flip_left_right(image)
visualize(image, flipped)

png

Thang độ xám một hình ảnh

Bạn có thể tạo thang độ xám cho hình ảnh với tf.image.rgb_to_grayscale :

grayscaled = tf.image.rgb_to_grayscale(image)
visualize(image, tf.squeeze(grayscaled))
_ = plt.colorbar()

png

Làm bão hòa hình ảnh

Làm bão hòa hình ảnh bằng tf.image.adjust_saturation bằng cách cung cấp hệ số bão hòa:

saturated = tf.image.adjust_saturation(image, 3)
visualize(image, saturated)

png

Thay đổi độ sáng hình ảnh

Thay đổi độ sáng của hình ảnh với tf.image.adjust_brightness bằng cách cung cấp hệ số độ sáng:

bright = tf.image.adjust_brightness(image, 0.4)
visualize(image, bright)

png

Cắt giữa một hình ảnh

Cắt hình ảnh từ giữa lên đến phần hình ảnh bạn muốn bằng cách sử dụng tf.image.central_crop :

cropped = tf.image.central_crop(image, central_fraction=0.5)
visualize(image, cropped)

png

Xoay hình ảnh

Xoay hình ảnh 90 độ với tf.image.rot90 :

rotated = tf.image.rot90(image)
visualize(image, rotated)

png

Các phép biến đổi ngẫu nhiên

Việc áp dụng các phép biến đổi ngẫu nhiên cho hình ảnh có thể giúp tổng quát hóa và mở rộng tập dữ liệu hơn nữa. API tf.image hiện tại cung cấp tám hoạt động hình ảnh ngẫu nhiên như vậy (hoạt động):

Các hoạt động hình ảnh ngẫu nhiên này hoàn toàn có chức năng: đầu ra chỉ phụ thuộc vào đầu vào. Điều này làm cho chúng đơn giản để sử dụng trong các đường ống đầu vào xác định, hiệu suất cao. Họ yêu cầu một giá trị seed được nhập vào mỗi bước. Cho cùng một seed , chúng trả về kết quả giống nhau, không phụ thuộc vào số lần chúng được gọi.

Trong các phần sau, bạn sẽ:

  1. Xem qua các ví dụ về việc sử dụng các phép toán hình ảnh ngẫu nhiên để biến đổi hình ảnh.
  2. Trình bày cách áp dụng các phép biến đổi ngẫu nhiên cho tập dữ liệu huấn luyện.

Thay đổi độ sáng hình ảnh một cách ngẫu nhiên

Thay đổi ngẫu nhiên độ sáng của image bằng tf.image.stateless_random_brightness bằng cách cung cấp hệ số độ sáng và seed . Hệ số độ sáng được chọn ngẫu nhiên trong phạm vi [-max_delta, max_delta) và được liên kết với seed đã cho.

for i in range(3):
  seed = (i, 0)  # tuple of size (2,)
  stateless_random_brightness = tf.image.stateless_random_brightness(
      image, max_delta=0.95, seed=seed)
  visualize(image, stateless_random_brightness)

png

png

png

Thay đổi ngẫu nhiên độ tương phản của hình ảnh

Thay đổi ngẫu nhiên độ tương phản của image bằng cách sử dụng tf.image.stateless_random_contrast bằng cách cung cấp dải tương phản và seed . Phạm vi tương phản được chọn ngẫu nhiên trong khoảng [lower, upper] và được liên kết với seed đã cho.

for i in range(3):
  seed = (i, 0)  # tuple of size (2,)
  stateless_random_contrast = tf.image.stateless_random_contrast(
      image, lower=0.1, upper=0.9, seed=seed)
  visualize(image, stateless_random_contrast)

png

png

png

Cắt ngẫu nhiên một hình ảnh

Cắt ngẫu nhiên image bằng cách sử dụng tf.image.stateless_random_crop bằng cách cung cấp size mục tiêu và seed . Phần bị cắt ra khỏi image ở một khoảng bù được chọn ngẫu nhiên và được liên kết với seed đã cho.

for i in range(3):
  seed = (i, 0)  # tuple of size (2,)
  stateless_random_crop = tf.image.stateless_random_crop(
      image, size=[210, 300, 3], seed=seed)
  visualize(image, stateless_random_crop)

png

png

png

Áp dụng tăng cường cho tập dữ liệu

Đầu tiên chúng ta hãy tải xuống lại tập dữ liệu hình ảnh trong trường hợp chúng được sửa đổi trong các phần trước.

(train_datasets, val_ds, test_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

Tiếp theo, xác định một chức năng tiện ích để thay đổi kích thước và thay đổi tỷ lệ hình ảnh. Chức năng này sẽ được sử dụng để thống nhất kích thước và tỷ lệ của hình ảnh trong tập dữ liệu:

def resize_and_rescale(image, label):
  image = tf.cast(image, tf.float32)
  image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
  image = (image / 255.0)
  return image, label

Hãy cũng xác định hàm augment có thể áp dụng các phép biến đổi ngẫu nhiên cho hình ảnh. Hàm này sẽ được sử dụng trên tập dữ liệu trong bước tiếp theo.

def augment(image_label, seed):
  image, label = image_label
  image, label = resize_and_rescale(image, label)
  image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6)
  # Make a new seed.
  new_seed = tf.random.experimental.stateless_split(seed, num=1)[0, :]
  # Random crop back to the original size.
  image = tf.image.stateless_random_crop(
      image, size=[IMG_SIZE, IMG_SIZE, 3], seed=seed)
  # Random brightness.
  image = tf.image.stateless_random_brightness(
      image, max_delta=0.5, seed=new_seed)
  image = tf.clip_by_value(image, 0, 1)
  return image, label

Tùy chọn 1: Sử dụng tf.data.experimental.Counter

Tạo một đối tượng tf.data.experimental.Counter (chúng ta hãy gọi nó là bộ counter ) và Dataset.zip tập dữ liệu với (counter, counter) . Điều này sẽ đảm bảo rằng mỗi hình ảnh trong tập dữ liệu được liên kết với một giá trị duy nhất (hình dạng (2,) ) dựa trên bộ counter mà sau này có thể được chuyển vào hàm augment seed trị gốc cho các phép biến đổi ngẫu nhiên.

# Create a `Counter` object and `Dataset.zip` it together with the training set.
counter = tf.data.experimental.Counter()
train_ds = tf.data.Dataset.zip((train_datasets, (counter, counter)))

Ánh xạ chức năng augment với tập dữ liệu đào tạo:

train_ds = (
    train_ds
    .shuffle(1000)
    .map(augment, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
val_ds = (
    val_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
test_ds = (
    test_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

Tùy chọn 2: Sử dụng tf.random.Generator

  • Tạo một đối tượng tf.random.Generator với seed giá trị ban đầu. Việc gọi hàm make_seeds trên cùng một đối tượng trình tạo luôn trả về một giá trị seed mới, duy nhất.
  • Định nghĩa một hàm trình bao bọc mà: 1) gọi hàm make_seeds ; và 2) chuyển giá trị seed mới được tạo vào hàm augment cho các phép biến đổi ngẫu nhiên.
# Create a generator.
rng = tf.random.Generator.from_seed(123, alg='philox')
# Create a wrapper function for updating seeds.
def f(x, y):
  seed = rng.make_seeds(2)[0]
  image, label = augment((x, y), seed)
  return image, label

Ánh xạ hàm trình bao bọc f tới tập dữ liệu huấn luyện và hàm resize_and_rescale — với tập xác thực và kiểm tra:

train_ds = (
    train_datasets
    .shuffle(1000)
    .map(f, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
val_ds = (
    val_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
test_ds = (
    test_ds
    .map(resize_and_rescale, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

Các bộ dữ liệu này hiện có thể được sử dụng để đào tạo một mô hình như được hiển thị trước đó.

Bước tiếp theo

Hướng dẫn này đã chứng minh việc tăng dữ liệu bằng cách sử dụng các lớp tiền xử lý của Keras và tf.image .

  • Để tìm hiểu cách bao gồm các lớp tiền xử lý bên trong mô hình của bạn, hãy tham khảo hướng dẫn phân loại Hình ảnh .
  • Bạn cũng có thể quan tâm đến việc tìm hiểu cách xử lý trước các lớp có thể giúp bạn phân loại văn bản, như được hiển thị trong hướng dẫn Phân loại văn bản cơ bản .
  • Bạn có thể tìm hiểu thêm về tf.data trong hướng dẫn này và bạn có thể tìm hiểu cách định cấu hình đường ống đầu vào của mình để đạt được hiệu suất tại đây .