Tập dữ liệu

Trong nhiều mô hình học máy, đặc biệt là học có giám sát, bộ dữ liệu là một phần quan trọng của quá trình đào tạo. Swift cho TensorFlow cung cấp trình bao bọc cho một số bộ dữ liệu phổ biến trong mô-đun Bộ dữ liệu trong kho lưu trữ mô hình . Các trình bao bọc này giúp dễ dàng sử dụng các bộ dữ liệu phổ biến với các mô hình dựa trên Swift và tích hợp tốt với vòng đào tạo tổng quát của Swift cho TensorFlow.

Trình bao bọc tập dữ liệu được cung cấp

Đây là các trình bao bọc tập dữ liệu hiện được cung cấp trong kho lưu trữ mô hình:

Để sử dụng một trong các trình bao bọc tập dữ liệu này trong dự án Swift, hãy thêm Datasets làm phần phụ thuộc vào mục tiêu Swift của bạn và nhập mô-đun:

import Datasets

Hầu hết các trình bao bọc tập dữ liệu được thiết kế để tạo ra các lô dữ liệu được dán nhãn được xáo trộn ngẫu nhiên. Ví dụ: để sử dụng tập dữ liệu CIFAR-10, trước tiên bạn phải khởi tạo nó với kích thước lô mong muốn:

let dataset = CIFAR10(batchSize: 100)

Trong lần sử dụng đầu tiên, trình bao bọc tập dữ liệu Swift for TensorFlow sẽ tự động tải xuống tập dữ liệu gốc cho bạn, trích xuất và phân tích tất cả các kho lưu trữ có liên quan, sau đó lưu trữ tập dữ liệu đã xử lý trong thư mục bộ đệm cục bộ của người dùng. Những lần sử dụng tiếp theo của cùng một tập dữ liệu sẽ tải trực tiếp từ bộ đệm cục bộ.

Để thiết lập vòng đào tạo thủ công liên quan đến tập dữ liệu này, bạn sẽ sử dụng nội dung như sau:

for (epoch, epochBatches) in dataset.training.prefix(100).enumerated() {
  Context.local.learningPhase = .training
  ...
  for batch in epochBatches {
    let (images, labels) = (batch.data, batch.label)
    ...
  }
}

Ở trên thiết lập một trình lặp qua 100 kỷ nguyên ( .prefix(100) ) và trả về chỉ số số của kỷ nguyên hiện tại và một chuỗi được ánh xạ lười biếng trên các lô được xáo trộn tạo nên kỷ nguyên đó. Trong mỗi giai đoạn huấn luyện, các lô được lặp đi lặp lại và trích xuất để xử lý. Trong trường hợp trình bao bọc tập dữ liệu CIFAR10 , mỗi lô là một LabeledImage , cung cấp Tensor<Float> chứa tất cả hình ảnh từ lô đó và Tensor<Int32> với các nhãn phù hợp của chúng.

Trong trường hợp của CIFAR-10, toàn bộ tập dữ liệu có kích thước nhỏ và có thể được tải vào bộ nhớ cùng một lúc, nhưng đối với các tập dữ liệu lớn hơn khác, các lô được tải một cách lười biếng từ đĩa và được xử lý tại thời điểm thu được từng lô. Điều này ngăn ngừa tình trạng cạn kiệt bộ nhớ với các tập dữ liệu lớn hơn đó.

API kỷ nguyên

Hầu hết các trình bao bọc tập dữ liệu này được xây dựng trên cơ sở hạ tầng dùng chung mà chúng tôi gọi là API Epochs . Epochs cung cấp các thành phần linh hoạt nhằm hỗ trợ nhiều loại tập dữ liệu khác nhau, từ văn bản đến hình ảnh, v.v.

Nếu bạn muốn tạo trình bao bọc tập dữ liệu Swift của riêng mình, rất có thể bạn sẽ muốn sử dụng API Epochs để làm điều đó. Tuy nhiên, đối với các trường hợp phổ biến, chẳng hạn như tập dữ liệu phân loại hình ảnh, chúng tôi khuyên bạn nên bắt đầu từ mẫu dựa trên một trong các trình bao bọc tập dữ liệu hiện có và sửa đổi mẫu đó để đáp ứng nhu cầu cụ thể của bạn.

Ví dụ: hãy kiểm tra trình bao bọc tập dữ liệu CIFAR-10 và cách thức hoạt động của nó. Cốt lõi của tập dữ liệu huấn luyện được xác định ở đây:

let trainingSamples = loadCIFARTrainingFiles(in: localStorageDirectory)
training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
  .lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
    return batches.lazy.map{
      makeBatch(samples: $0, mean: mean, standardDeviation: standardDeviation, device: device)
  }
}

Kết quả từ hàm loadCIFARTrainingFiles() là một mảng các bộ (data: [UInt8], label: Int32) cho mỗi hình ảnh trong tập dữ liệu huấn luyện. Sau đó, điều này được cung cấp cho TrainingEpochs(samples:batchSize:entropy:) để tạo ra một chuỗi vô hạn các kỷ nguyên với các lô batchSize . Bạn có thể cung cấp trình tạo số ngẫu nhiên của riêng mình trong trường hợp bạn muốn có hành vi phân khối xác định, nhưng theo mặc định, SystemRandomNumberGenerator được sử dụng.

Từ đó, các bản đồ lười biếng theo các lô đạt đến đỉnh điểm trong hàm makeBatch(samples:mean:standardDeviation:device:) . Đây là một chức năng tùy chỉnh chứa đường dẫn xử lý hình ảnh thực tế cho tập dữ liệu CIFAR-10, vì vậy chúng ta hãy xem xét điều đó:

fileprivate func makeBatch<BatchSamples: Collection>(
  samples: BatchSamples, mean: Tensor<Float>?, standardDeviation: Tensor<Float>?, device: Device
) -> LabeledImage where BatchSamples.Element == (data: [UInt8], label: Int32) {
  let bytes = samples.lazy.map(\.data).reduce(into: [], +=)
  let images = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: bytes, on: device)

  var imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))
  imageTensor /= 255.0
  if let mean = mean, let standardDeviation = standardDeviation {
    imageTensor = (imageTensor - mean) / standardDeviation
  }

  let labels = Tensor<Int32>(samples.map(\.label), on: device)
  return LabeledImage(data: imageTensor, label: labels)
}

Hai dòng của hàm này nối tất cả byte data từ BatchSamples đến thành Tensor<UInt8> khớp với bố cục byte của hình ảnh trong tập dữ liệu CIFAR-10 thô. Tiếp theo, các kênh hình ảnh được sắp xếp lại để khớp với những kênh dự kiến ​​trong các mô hình phân loại hình ảnh tiêu chuẩn của chúng tôi và dữ liệu hình ảnh được chuyển thành Tensor<Float> để sử dụng cho mô hình.

Các tham số chuẩn hóa tùy chọn có thể được cung cấp để điều chỉnh thêm các giá trị kênh hình ảnh, một quy trình phổ biến khi đào tạo nhiều mô hình phân loại hình ảnh. Tham số chuẩn hóa Tensor s được tạo một lần khi khởi tạo tập dữ liệu và sau đó được chuyển vào makeBatch() dưới dạng tối ưu hóa để ngăn chặn việc tạo lặp lại các tensor nhỏ tạm thời có cùng giá trị.

Cuối cùng, các nhãn số nguyên được đặt trong Tensor<Int32> và cặp tensor hình ảnh/nhãn được trả về trong LabeledImage . LabeledImage là một trường hợp cụ thể của LabeledData , một cấu trúc có dữ liệu và nhãn tuân theo giao thức Collatable của API Eppch.

Để biết thêm ví dụ về API Epochs trong các loại tập dữ liệu khác nhau, bạn có thể kiểm tra các trình bao bọc tập dữ liệu khác trong kho lưu trữ mô hình.