W wielu modelach uczenia maszynowego, zwłaszcza w przypadku uczenia się nadzorowanego, zbiory danych stanowią istotną część procesu szkoleniowego. Swift dla TensorFlow zapewnia opakowania dla kilku popularnych zestawów danych w module Datasets w repozytorium modeli . Te opakowania ułatwiają korzystanie z typowych zestawów danych w modelach opartych na języku Swift i dobrze integrują się z uogólnioną pętlą szkoleniową Swift for TensorFlow.
Dostarczone opakowania zestawu danych
Oto aktualnie dostępne opakowania zbioru danych w repozytorium modeli:
- BostonMieszkanie
- CIFAR-10
- Pani COCO
- Cola
- ImageNet
- Obrazek
- Obrazwoof
- ModaMNIST
- KuzushijiMNIST
- MNIST
- Obiektyw filmowy
- Oxford-IIIT Pet
- WordSeg
Aby użyć jednego z tych opakowań zbioru danych w projekcie Swift, dodaj Datasets
jako zależność do celu Swift i zaimportuj moduł:
import Datasets
Większość opakowań zbiorów danych jest zaprojektowana tak, aby generować losowo przetasowane partie oznaczonych etykietami danych. Na przykład, aby użyć zbioru danych CIFAR-10, należy najpierw zainicjować go żądaną wielkością partii:
let dataset = CIFAR10(batchSize: 100)
Przy pierwszym użyciu opakowania zestawu danych Swift for TensorFlow automatycznie pobiorą oryginalny zestaw danych, wyodrębnią i przeanalizują wszystkie odpowiednie archiwa, a następnie zapiszą przetworzony zestaw danych w lokalnym katalogu pamięci podręcznej użytkownika. Kolejne użycia tego samego zestawu danych będą ładowane bezpośrednio z lokalnej pamięci podręcznej.
Aby skonfigurować ręczną pętlę szkoleniową obejmującą ten zbiór danych, możesz użyć czegoś takiego:
for (epoch, epochBatches) in dataset.training.prefix(100).enumerated() {
Context.local.learningPhase = .training
...
for batch in epochBatches {
let (images, labels) = (batch.data, batch.label)
...
}
}
Powyższe konfiguruje iterator przez 100 epok ( .prefix(100)
) i zwraca indeks liczbowy bieżącej epoki oraz leniwie odwzorowaną sekwencję na przetasowanych partiach tworzących tę epokę. W każdej epoce szkoleniowej partie są poddawane iteracji i wyodrębniane do przetworzenia. W przypadku opakowania zestawu danych CIFAR10
każda partia to LabeledImage
, który udostępnia Tensor<Float>
zawierający wszystkie obrazy z tej partii oraz Tensor<Int32>
z pasującymi etykietami.
W przypadku CIFAR-10 cały zbiór danych jest mały i można go załadować do pamięci na raz, natomiast w przypadku innych większych zbiorów danych partie są ładowane leniwie z dysku i przetwarzane w momencie uzyskania każdej partii. Zapobiega to wyczerpaniu pamięci w przypadku większych zestawów danych.
Interfejs API epok
Większość opakowań zbioru danych opiera się na współdzielonej infrastrukturze, którą nazwaliśmy interfejsem API Epochs . Epochs zapewnia elastyczne komponenty przeznaczone do obsługi szerokiej gamy typów zbiorów danych, od tekstu po obrazy i nie tylko.
Jeśli chcesz utworzyć własne opakowanie zbioru danych Swift, najprawdopodobniej będziesz chciał użyć do tego interfejsu API Epochs. Jednak w typowych przypadkach, takich jak zbiory danych klasyfikacji obrazów, zdecydowanie zalecamy rozpoczęcie od szablonu opartego na jednym z istniejących opakowań zbioru danych i zmodyfikowanie go w celu spełnienia konkretnych potrzeb.
Jako przykład przyjrzyjmy się opakowaniu zbioru danych CIFAR-10 i jego działaniu. Rdzeń zbioru danych szkoleniowych jest zdefiniowany tutaj:
let trainingSamples = loadCIFARTrainingFiles(in: localStorageDirectory)
training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
.lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
return batches.lazy.map{
makeBatch(samples: $0, mean: mean, standardDeviation: standardDeviation, device: device)
}
}
Wynikiem funkcji loadCIFARTrainingFiles()
jest tablica (data: [UInt8], label: Int32)
krotek dla każdego obrazu w zbiorze danych szkoleniowych. Wartość ta jest następnie przekazywana do TrainingEpochs(samples:batchSize:entropy:)
w celu utworzenia nieskończonej sekwencji epok z partiami batchSize
. Możesz udostępnić własny generator liczb losowych w przypadkach, gdy chcesz zachować deterministyczne zachowanie wsadowe, ale domyślnie używany jest SystemRandomNumberGenerator
.
Stamtąd leniwe mapy w partiach kończą się funkcją makeBatch(samples:mean:standardDeviation:device:)
. Jest to funkcja niestandardowa, w której zlokalizowany jest rzeczywisty potok przetwarzania obrazu dla zbioru danych CIFAR-10, więc przyjrzyjmy się temu:
fileprivate func makeBatch<BatchSamples: Collection>(
samples: BatchSamples, mean: Tensor<Float>?, standardDeviation: Tensor<Float>?, device: Device
) -> LabeledImage where BatchSamples.Element == (data: [UInt8], label: Int32) {
let bytes = samples.lazy.map(\.data).reduce(into: [], +=)
let images = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: bytes, on: device)
var imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))
imageTensor /= 255.0
if let mean = mean, let standardDeviation = standardDeviation {
imageTensor = (imageTensor - mean) / standardDeviation
}
let labels = Tensor<Int32>(samples.map(\.label), on: device)
return LabeledImage(data: imageTensor, label: labels)
}
Dwie linie tej funkcji łączą wszystkie bajty data
z przychodzących próbek BatchSamples
w Tensor<UInt8>
, który pasuje do układu bajtów obrazów w surowym zestawie danych CIFAR-10. Następnie kolejność kanałów obrazu jest zmieniana tak, aby odpowiadała oczekiwaniom w naszych standardowych modelach klasyfikacji obrazów, a dane obrazu są ponownie rzutowane do Tensor<Float>
w celu wykorzystania modelu.
Można udostępnić opcjonalne parametry normalizacji w celu dalszego dostosowania wartości kanałów obrazu, co jest procesem powszechnym podczas uczenia wielu modeli klasyfikacji obrazów. Parametr normalizacyjny Tensor
s jest tworzony raz podczas inicjalizacji zestawu danych, a następnie przekazywany do makeBatch()
w ramach optymalizacji, aby zapobiec wielokrotnemu tworzeniu małych tymczasowych tensorów o tych samych wartościach.
Na koniec etykiety całkowite są umieszczane w Tensor<Int32>
, a para tensorów obraz/etykieta zwracana w LabeledImage
. LabeledImage
to specyficzny przypadek LabeledData
, struktury zawierającej dane i etykiety zgodne z protokołem Collatable
interfejsu API Eppch.
Aby uzyskać więcej przykładów interfejsu API Epochs w różnych typach zestawów danych, możesz sprawdzić inne opakowania zestawu danych w repozytorium modeli.