Наборы данных

Во многих моделях машинного обучения, особенно в контролируемом обучении, наборы данных являются жизненно важной частью процесса обучения. Swift для TensorFlow предоставляет оболочки для нескольких распространенных наборов данных в модуле Datasets в репозитории моделей . Эти оболочки упрощают использование общих наборов данных с моделями на основе Swift и хорошо интегрируются с обобщенным циклом обучения Swift для TensorFlow.

Предоставленные оболочки наборов данных

В настоящее время в репозитории моделей представлены оболочки наборов данных:

Чтобы использовать одну из этих оболочек набора данных в проекте Swift, добавьте Datasets в качестве зависимости к цели Swift и импортируйте модуль:

import Datasets

Большинство оболочек наборов данных предназначены для создания случайно перетасованных пакетов помеченных данных. Например, чтобы использовать набор данных CIFAR-10, вы сначала инициализируете его с желаемым размером пакета:

let dataset = CIFAR10(batchSize: 100)

При первом использовании оболочки набора данных Swift для TensorFlow автоматически загрузят для вас исходный набор данных, извлекут и проанализируют все соответствующие архивы, а затем сохранят обработанный набор данных в локальном каталоге кэша пользователя. Последующие использования того же набора данных будут загружаться непосредственно из локального кэша.

Чтобы настроить цикл ручного обучения с использованием этого набора данных, вы должны использовать что-то вроде следующего:

for (epoch, epochBatches) in dataset.training.prefix(100).enumerated() {
  Context.local.learningPhase = .training
  ...
  for batch in epochBatches {
    let (images, labels) = (batch.data, batch.label)
    ...
  }
}

Вышеупомянутое устанавливает итератор для 100 эпох ( .prefix(100) ) и возвращает числовой индекс текущей эпохи и лениво сопоставленную последовательность по перетасованным пакетам, составляющим эту эпоху. В течение каждой эпохи обучения пакеты повторяются и извлекаются для обработки. В случае оболочки набора данных CIFAR10 каждый пакет представляет собой LabeledImage , который предоставляет Tensor<Float> , содержащий все изображения из этого пакета, и Tensor<Int32> с соответствующими метками.

В случае CIFAR-10 весь набор данных невелик и может быть загружен в память за один раз, но для других более крупных наборов данных пакеты загружаются лениво с диска и обрабатываются в момент получения каждого пакета. Это предотвращает исчерпание памяти при работе с большими наборами данных.

API эпох

Большинство этих оболочек наборов данных построены на основе общей инфраструктуры, которую мы назвали Epochs API . Epochs предоставляет гибкие компоненты, предназначенные для поддержки самых разных типов наборов данных, от текста до изображений и многого другого.

Если вы хотите создать собственную оболочку набора данных Swift, вам, скорее всего, захочется использовать для этого API Epochs. Однако для распространенных случаев, таких как наборы данных классификации изображений, мы настоятельно рекомендуем начать с шаблона, основанного на одной из существующих оболочек набора данных, и изменить его в соответствии с вашими конкретными потребностями.

В качестве примера давайте рассмотрим оболочку набора данных CIFAR-10 и то, как она работает. Ядро набора обучающих данных определяется здесь:

let trainingSamples = loadCIFARTrainingFiles(in: localStorageDirectory)
training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
  .lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
    return batches.lazy.map{
      makeBatch(samples: $0, mean: mean, standardDeviation: standardDeviation, device: device)
  }
}

Результатом функции loadCIFARTrainingFiles() является массив кортежей (data: [UInt8], label: Int32) для каждого изображения в наборе обучающих данных. Затем это передается TrainingEpochs(samples:batchSize:entropy:) для создания бесконечной последовательности эпох с пакетами batchSize . Вы можете предоставить свой собственный генератор случайных чисел в тех случаях, когда вам может потребоваться детерминированное поведение пакетной обработки, но по умолчанию используется SystemRandomNumberGenerator .

Отсюда ленивые карты для пакетов завершаются функцией makeBatch(samples:mean:standardDeviation:device:) . Это пользовательская функция, в которой находится реальный конвейер обработки изображений для набора данных CIFAR-10, поэтому давайте взглянем на нее:

fileprivate func makeBatch<BatchSamples: Collection>(
  samples: BatchSamples, mean: Tensor<Float>?, standardDeviation: Tensor<Float>?, device: Device
) -> LabeledImage where BatchSamples.Element == (data: [UInt8], label: Int32) {
  let bytes = samples.lazy.map(\.data).reduce(into: [], +=)
  let images = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: bytes, on: device)

  var imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))
  imageTensor /= 255.0
  if let mean = mean, let standardDeviation = standardDeviation {
    imageTensor = (imageTensor - mean) / standardDeviation
  }

  let labels = Tensor<Int32>(samples.map(\.label), on: device)
  return LabeledImage(data: imageTensor, label: labels)
}

Две строки этой функции объединяют все байты data из входящих BatchSamples в Tensor<UInt8> , который соответствует расположению байтов изображений в необработанном наборе данных CIFAR-10. Затем каналы изображения переупорядочиваются, чтобы соответствовать ожидаемым в наших стандартных моделях классификации изображений, а данные изображения преобразуются в Tensor<Float> для использования в модели.

Дополнительные параметры нормализации могут быть предоставлены для дальнейшей настройки значений каналов изображения — процесса, который часто встречается при обучении многих моделей классификации изображений. Параметр нормализации Tensor создается один раз при инициализации набора данных, а затем передается в makeBatch() в качестве оптимизации, чтобы предотвратить повторное создание небольших временных тензоров с одинаковыми значениями.

Наконец, целочисленные метки помещаются в Tensor<Int32> , а пара тензоров изображение/метка возвращается в LabeledImage . LabeledImage — это частный случай LabeledData , структуры с данными и метками, которые соответствуют протоколу Collatable API Eppch.

Дополнительные примеры API Epochs в различных типах наборов данных можно найти в других оболочках наборов данных в репозитории моделей.