Conjuntos de dados

Em muitos modelos de aprendizado de máquina, especialmente no aprendizado supervisionado, os conjuntos de dados são uma parte vital do processo de treinamento. Swift para TensorFlow fornece wrappers para vários conjuntos de dados comuns no módulo Datasets no repositório de modelos . Esses wrappers facilitam o uso de conjuntos de dados comuns com modelos baseados em Swift e se integram bem ao loop de treinamento generalizado do Swift for TensorFlow.

Wrappers de conjunto de dados fornecidos

Estes são os wrappers de conjunto de dados fornecidos atualmente no repositório de modelos:

Para usar um desses wrappers de conjunto de dados em um projeto Swift, adicione Datasets como uma dependência ao seu destino Swift e importe o módulo:

import Datasets

A maioria dos wrappers de conjuntos de dados são projetados para produzir lotes de dados rotulados embaralhados aleatoriamente. Por exemplo, para usar o conjunto de dados CIFAR-10, primeiro inicialize-o com o tamanho de lote desejado:

let dataset = CIFAR10(batchSize: 100)

No primeiro uso, os wrappers do conjunto de dados do Swift para TensorFlow farão download automaticamente do conjunto de dados original para você, extrairão e analisarão todos os arquivos relevantes e, em seguida, armazenarão o conjunto de dados processado em um diretório de cache local do usuário. Os usos subsequentes do mesmo conjunto de dados serão carregados diretamente do cache local.

Para configurar um loop de treinamento manual envolvendo esse conjunto de dados, você usaria algo como o seguinte:

for (epoch, epochBatches) in dataset.training.prefix(100).enumerated() {
  Context.local.learningPhase = .training
  ...
  for batch in epochBatches {
    let (images, labels) = (batch.data, batch.label)
    ...
  }
}

O acima configura um iterador através de 100 épocas ( .prefix(100) ) e retorna o índice numérico da época atual e uma sequência mapeada preguiçosamente sobre lotes embaralhados que compõem essa época. Dentro de cada época de treinamento, os lotes são iterados e extraídos para processamento. No caso do wrapper do conjunto de dados CIFAR10 , cada lote é um LabeledImage , que fornece um Tensor<Float> contendo todas as imagens desse lote e um Tensor<Int32> com seus rótulos correspondentes.

No caso do CIFAR-10, todo o conjunto de dados é pequeno e pode ser carregado na memória de uma só vez, mas para outros conjuntos de dados maiores, os lotes são carregados lentamente do disco e processados ​​no ponto onde cada lote é obtido. Isso evita o esgotamento da memória com conjuntos de dados maiores.

A API Épocas

A maioria desses wrappers de conjunto de dados é construída em uma infraestrutura compartilhada que chamamos de API Epochs . Epochs fornece componentes flexíveis destinados a suportar uma ampla variedade de tipos de conjuntos de dados, de texto a imagens e muito mais.

Se você deseja criar seu próprio wrapper de conjunto de dados Swift, provavelmente desejará usar a API Epochs para fazer isso. No entanto, para casos comuns, como conjuntos de dados de classificação de imagens, é altamente recomendável começar com um modelo baseado em um dos wrappers de conjunto de dados existentes e modificá-lo para atender às suas necessidades específicas.

Como exemplo, vamos examinar o wrapper do conjunto de dados CIFAR-10 e como ele funciona. O núcleo do conjunto de dados de treinamento é definido aqui:

let trainingSamples = loadCIFARTrainingFiles(in: localStorageDirectory)
training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
  .lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
    return batches.lazy.map{
      makeBatch(samples: $0, mean: mean, standardDeviation: standardDeviation, device: device)
  }
}

O resultado da função loadCIFARTrainingFiles() é uma matriz de tuplas (data: [UInt8], label: Int32) para cada imagem no conjunto de dados de treinamento. Isso é então fornecido para TrainingEpochs(samples:batchSize:entropy:) para criar uma sequência infinita de épocas com lotes de batchSize . Você pode fornecer seu próprio gerador de números aleatórios nos casos em que desejar um comportamento determinístico de lote, mas por padrão o SystemRandomNumberGenerator é usado.

A partir daí, mapas lentos sobre os lotes culminam na função makeBatch(samples:mean:standardDeviation:device:) . Esta é uma função personalizada onde está localizado o pipeline de processamento de imagem real para o conjunto de dados CIFAR-10, então vamos dar uma olhada nisso:

fileprivate func makeBatch<BatchSamples: Collection>(
  samples: BatchSamples, mean: Tensor<Float>?, standardDeviation: Tensor<Float>?, device: Device
) -> LabeledImage where BatchSamples.Element == (data: [UInt8], label: Int32) {
  let bytes = samples.lazy.map(\.data).reduce(into: [], +=)
  let images = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: bytes, on: device)

  var imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))
  imageTensor /= 255.0
  if let mean = mean, let standardDeviation = standardDeviation {
    imageTensor = (imageTensor - mean) / standardDeviation
  }

  let labels = Tensor<Int32>(samples.map(\.label), on: device)
  return LabeledImage(data: imageTensor, label: labels)
}

As duas linhas desta função concatenam todos os bytes data dos BatchSamples recebidos em um Tensor<UInt8> que corresponde ao layout de bytes das imagens dentro do conjunto de dados CIFAR-10 bruto. Em seguida, os canais de imagem são reordenados para corresponder aos esperados em nossos modelos de classificação de imagem padrão e os dados da imagem são reordenados em um Tensor<Float> para consumo do modelo.

Parâmetros de normalização opcionais podem ser fornecidos para ajustar ainda mais os valores do canal de imagem, um processo comum no treinamento de muitos modelos de classificação de imagens. O parâmetro de normalização Tensor s é criado uma vez na inicialização do conjunto de dados e depois passado para makeBatch() como uma otimização para evitar a criação repetida de pequenos tensores temporários com os mesmos valores.

Finalmente, os rótulos inteiros são colocados em um Tensor<Int32> e o par tensor imagem/rótulo retornado em um LabeledImage . Um LabeledImage é um caso específico de LabeledData , uma estrutura com dados e rótulos que estão em conformidade com o protocolo Collatable da API Eppch.

Para obter mais exemplos da API Epochs em diferentes tipos de conjuntos de dados, você pode examinar os outros wrappers de conjuntos de dados no repositório de modelos.