
Model training walkthrough


This guide introduces Swift for TensorFlow by building a machine learning model that categorizes iris flowers by species. It uses Swift for TensorFlow to:

  1. Build a model,
  2. Train this model on example data, and
  3. Use the model to make predictions about unknown data.

TensorFlow programming

This guide uses these high-level Swift for TensorFlow concepts:

  • Import data with the Datasets API.
  • Build models using Swift abstractions.
  • Use Python libraries through Swift's Python interoperability when pure Swift libraries are not available.

This tutorial is structured like many TensorFlow programs:

  1. Import and parse the data sets.
  2. Select the type of model.
  3. Train the model.
  4. Evaluate the model's effectiveness.
  5. Use the trained model to make predictions.

Setup program

Configure imports

Import TensorFlow and some useful Python modules.

import TensorFlow

import Python
%include "EnableIPythonDisplay.swift"
IPythonDisplay.shell.enable_matplotlib("inline")
let plt = Python.import("matplotlib.pyplot")

// Download a helper file that helps us work around some temporary limitations
// in the dataset API.
let path = Python.import("os.path")
let urllib = Python.import("urllib.request")
let helperFiles = ["TutorialDatasetCSVAPI.swift"]
let helperFileBaseURL = "https://raw.githubusercontent.com/tensorflow/swift/master/docs/site/tutorials/"
for helperFile in helperFiles {
    if !Bool(path.isfile(helperFile))! {
        print("Downloading \(helperFile)")
        urllib.urlretrieve(helperFileBaseURL + helperFile, filename: helperFile)
    } else {
        print("Not downloading \(helperFile): already exists")
    }
}
Downloading TutorialDatasetCSVAPI.swift

The iris classification problem

Imagine you are a botanist seeking an automated way to categorize each iris flower you find. Machine learning provides many algorithms to classify flowers statistically. For instance, a sophisticated machine learning program could classify flowers based on photographs. Our ambitions are more modest—we're going to classify iris flowers based on the length and width measurements of their sepals and petals.

The Iris genus comprises about 300 species, but our program will only classify the following three:

  • Iris setosa
  • Iris virginica
  • Iris versicolor
Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor
Figure 1. Iris setosa (by Radomil, CC BY-SA 3.0), Iris versicolor (by Dlanglois, CC BY-SA 3.0), and Iris virginica (by Frank Mayfield, CC BY-SA 2.0).
 

Fortunately, someone has already created a data set of 120 iris flowers with the sepal and petal measurements. This is a classic dataset that is popular for beginner machine learning classification problems.

Import and parse the training dataset

Download the dataset file and convert it into a structure that can be used by this Swift program.

Download the dataset

Download the training dataset file from http://download.tensorflow.org/data/iris_training.csv. We use a Python library to demonstrate Swift's Python interoperability, which makes it easy and natural to import and use Python libraries from Swift code.

let urllib = Python.import("urllib.request")
let downloadResult = urllib.urlretrieve(
    "http://download.tensorflow.org/data/iris_training.csv",
    "iris_training.csv")
let trainDataFilename = String(downloadResult[0])!
trainDataFilename
"iris_training.csv"

Inspect the data

This dataset, iris_training.csv, is a plain text file that stores tabular data formatted as comma-separated values (CSV). Let's look at the first 5 entries.

let f = Python.open(trainDataFilename)
for _ in 0..<5 {
    print(Python.next(f).strip())
}
f.close()
120,4,setosa,versicolor,virginica
6.4,2.8,5.6,2.2,2
5.0,2.3,3.3,1.0,1
4.9,2.5,4.5,1.7,2
4.9,3.1,1.5,0.1,0


From this view of the dataset, notice the following:

  1. The first line is a header containing information about the dataset:
    • There are 120 total examples. Each example has four features and one of three possible label names.
  2. Subsequent rows are data records, one example per line, where:
    • The first four fields are features: these are characteristics of an example. Here, the fields hold floating-point numbers representing flower measurements.
    • The last column is the label: this is the value we want to predict. For this dataset, it's an integer value of 0, 1, or 2 that corresponds to a flower name.

Let's write that out in code:

let featureNames = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
let labelName = "species"
let columnNames = featureNames + [labelName]

print("Features: \(featureNames)")
print("Label: \(labelName)")
Features: ["sepal_length", "sepal_width", "petal_length", "petal_width"]
Label: species
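
To make the row layout described above concrete, here is a minimal sketch in plain Swift that splits one data row into four features and a label. It is not how TutorialDatasetCSVAPI.swift works; the names IrisExample and parseIrisRow are hypothetical and exist only for this illustration.

```swift
import Foundation

/// One parsed example: four measurements plus an integer class label.
/// (Hypothetical type, used only in this sketch.)
struct IrisExample {
    let features: [Float]
    let label: Int32
}

/// Splits a data row such as "6.4,2.8,5.6,2.2,2" into features and a label.
/// Returns nil when the row does not hold four numbers followed by an integer.
func parseIrisRow(_ row: String) -> IrisExample? {
    let fields = row.split(separator: ",").map {
        $0.trimmingCharacters(in: .whitespaces)
    }
    guard fields.count == 5 else { return nil }
    let features = fields[0..<4].compactMap { Float($0) }
    guard features.count == 4, let label = Int32(fields[4]) else { return nil }
    return IrisExample(features: features, label: label)
}

// The first data row shown earlier.
if let example = parseIrisRow("6.4,2.8,5.6,2.2,2") {
    print(example.features, example.label)
}
```

The header line has five fields too, but only two of its first four fields are numeric, so this sketch rejects it rather than treating it as an example.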

Each label is associated with a string name (for example, "setosa"), but machine learning typically relies on numeric values. The label numbers are mapped to a named representation, such as:

  • 0: Iris setosa
  • 1: Iris versicolor
  • 2: Iris virginica

For more information about features and labels, see the ML Terminology section of the Machine Learning Crash Course.

let classNames = ["Iris setosa", "Iris versicolor", "Iris virginica"]

Create a Dataset

Swift for TensorFlow's Dataset API is a high-level API for reading data and transforming it into a form used for training.

Eventually, the Dataset API will be able to load data from many file formats. The Dataset API is currently very incomplete, so we include some glue code from "TutorialDatasetCSVAPI.swift" that helps the Dataset API load data from the CSV file into the IrisBatch struct.

let batchSize = Int64(32)

/// A batch of examples from the iris dataset.
struct IrisBatch {
    /// [batchSize, featureCount] tensor of features.
    let features: Tensor<Float>

    /// [batchSize] tensor of labels.
    let labels: Tensor<Int32>
}

%include "TutorialDatasetCSVAPI.swift"

let trainDataset: Dataset<IrisBatch> = Dataset(
    contentsOfCSVFile: trainDataFilename, hasHeader: true,
    featureColumns: [0, 1, 2, 3], labelColumns: [4]
).batched(batchSize)

The Dataset<IrisBatch> is a sequence of IrisBatches. Let's look at the first element of the dataset.

let firstTrainExamples = trainDataset.first!
let firstTrainFeatures = firstTrainExamples.features
let firstTrainLabels = firstTrainExamples.labels
print("First batch of features: \(firstTrainFeatures)")
print("First batch of labels: \(firstTrainLabels)")
First batch of features: [[6.4, 2.8, 5.6, 2.2], [5.0, 2.3, 3.3, 1.0], [4.9, 2.5, 4.5, 1.7], [4.9, 3.1, 1.5, 0.1], [5.7, 3.8, 1.7, 0.3], [4.4, 3.2, 1.3, 0.2], [5.4, 3.4, 1.5, 0.4], [6.9, 3.1, 5.1, 2.3], [6.7, 3.1, 4.4, 1.4], [5.1, 3.7, 1.5, 0.4], [5.2, 2.7, 3.9, 1.4], [6.9, 3.1, 4.9, 1.5], [5.8, 4.0, 1.2, 0.2], [5.4, 3.9, 1.7, 0.4], [7.7, 3.8, 6.7, 2.2], [6.3, 3.3, 4.7, 1.6], [6.8, 3.2, 5.9, 2.3], [7.6, 3.0, 6.6, 2.1], [6.4, 3.2, 5.3, 2.3], [5.7, 4.4, 1.5, 0.4], [6.7, 3.3, 5.7, 2.1], [6.4, 2.8, 5.6, 2.1], [5.4, 3.9, 1.3, 0.4], [6.1, 2.6, 5.6, 1.4], [7.2, 3.0, 5.8, 1.6], [5.2, 3.5, 1.5, 0.2], [5.8, 2.6, 4.0, 1.2], [5.9, 3.0, 5.1, 1.8], [5.4, 3.0, 4.5, 1.5], [6.7, 3.0, 5.0, 1.7], [6.3, 2.3, 4.4, 1.3], [5.1, 2.5, 3.0, 1.1]]
First batch of labels: [2, 1, 2, 0, 0, 0, 0, 2, 1, 0, 1, 1, 0, 0, 2, 1, 2, 2, 2, 0, 2, 2, 0, 2, 2, 0, 1, 2, 1, 1, 1, 1]

Notice that the features for the first batchSize examples are grouped together (or batched) into firstTrainFeatures, and that the labels for the first batchSize examples are batched into firstTrainLabels.
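
Batching itself is a simple idea, and a rough sketch in plain Swift may help. The helper below (batches(of:batchSize:) is a hypothetical name, not part of the Dataset API) cuts an array into consecutive chunks. Whether the Dataset API handles the final partial chunk the same way is an assumption this sketch does not verify.

```swift
/// Splits `items` into consecutive chunks of at most `batchSize` elements.
func batches<Element>(of items: [Element], batchSize: Int) -> [[Element]] {
    precondition(batchSize > 0, "batchSize must be positive")
    return stride(from: 0, to: items.count, by: batchSize).map { start in
        Array(items[start..<min(start + batchSize, items.count)])
    }
}

// 120 stand-in example indices, matching the size of the training set.
let exampleIndices = Array(0..<120)
let chunks = batches(of: exampleIndices, batchSize: 32)
print(chunks.map { $0.count })  // [32, 32, 32, 24]
```

With 120 examples and a batch size of 32, plain chunking produces three full batches and one smaller batch of 24.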

You can start to see some clusters by plotting a few features from the batch, using Python's matplotlib:

let firstTrainFeaturesTransposed = firstTrainFeatures.transposed()
let petalLengths = firstTrainFeaturesTransposed[2].scalars
let sepalLengths = firstTrainFeaturesTransposed[0].scalars

plt.scatter(petalLengths, sepalLengths, c: firstTrainLabels.array.scalars)
plt.xlabel("Petal length")
plt.ylabel("Sepal length")
plt.show()

[Figure: scatter plot of sepal length against petal length for the first batch, colored by label]

Select the type of model

Why model?

A model is a relationship between features and the label. For the iris classification problem, the model defines the relationship between the sepal and petal measurements and the predicted iris species. Some simple models can be described with a few lines of algebra, but complex machine learning models have a large number of parameters that are difficult to summarize.

Could you determine the relationship between the four features and the iris species without using machine learning? That is, could you use traditional programming techniques (for example, a lot of conditional statements) to create a model? Perhaps, if you analyzed the dataset long enough to determine the relationships between the petal and sepal measurements and a particular species. But this becomes difficult, maybe impossible, on more complicated datasets. A good machine learning approach determines the model for you. If you feed enough representative examples into the right machine learning model type, the program will figure out the relationships for you.

Select the model

We need to select the kind of model to train. There are many types of models and picking a good one takes experience. This tutorial uses a neural network to solve the iris classification problem. Neural networks can find complex relationships between features and the label. A neural network is a highly structured graph, organized into one or more hidden layers. Each hidden layer consists of one or more neurons. There are several categories of neural networks and this program uses a dense, or fully connected, neural network: the neurons in one layer receive input connections from every neuron in the previous layer. For example, Figure 2 illustrates a dense neural network consisting of an input layer, two hidden layers, and an output layer:

A diagram of the network architecture: Inputs, 2 hidden layers, and outputs
Figure 2. A neural network with features, hidden layers, and predictions.
 

When the model from Figure 2 is trained and fed an unlabeled example, it yields three predictions: the likelihood that this flower is the given iris species. This prediction is called inference. For this example, the sum of the output predictions is 1.0. In Figure 2, this prediction breaks down as: 0.02 for Iris setosa, 0.95 for Iris versicolor, and 0.03 for Iris virginica. This means that the model predicts—with 95% probability—that an unlabeled example flower is an Iris versicolor.

Create a model using the Swift for TensorFlow Deep Learning Library

The Swift for TensorFlow Deep Learning Library defines primitive layers and conventions for wiring them together, which makes it easy to build models and experiment.

A model is a struct that conforms to Layer, which means that it defines an applied(to:in:) method that maps input Tensors to output Tensors. The applied(to:in:) method often simply sequences the input through sublayers. Let's define an IrisModel that sequences the input through three Dense sublayers.

import TensorFlow

let hiddenSize: Int = 10
struct IrisModel: Layer {
    var layer1 = Dense<Float>(inputSize: 4, outputSize: hiddenSize, activation: relu)
    var layer2 = Dense<Float>(inputSize: hiddenSize, outputSize: hiddenSize, activation: relu)
    var layer3 = Dense<Float>(inputSize: hiddenSize, outputSize: 3)
    
    @differentiable
    func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
        return input.sequenced(in: context, through: layer1, layer2, layer3)
    }
}

var model = IrisModel()

The activation function determines the output of each node in the layer. These non-linearities are important: without them, the model would be equivalent to a single layer. There are many available activations, but ReLU is common for hidden layers.
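
To illustrate what a dense layer with an activation computes, here is a minimal sketch in plain Swift arrays. It is not the Dense implementation from the Deep Learning Library; toyRelu, denseForward, and the weights are hypothetical values chosen for the example.

```swift
/// ReLU passes positive values through and clamps negatives to zero.
func toyRelu(_ x: Float) -> Float {
    return max(0, x)
}

/// Forward pass of one dense layer: activation(weights * input + bias).
/// `weights[j][i]` connects input i to output neuron j.
func denseForward(
    input: [Float], weights: [[Float]], bias: [Float],
    activation: (Float) -> Float = { $0 }
) -> [Float] {
    precondition(weights.count == bias.count, "one bias per output neuron")
    return zip(weights, bias).map { (pair: ([Float], Float)) -> Float in
        let (row, b) = pair
        precondition(row.count == input.count, "one weight per input")
        // Weighted sum of the inputs, starting from the bias.
        let weightedSum = zip(row, input).reduce(b) { $0 + $1.0 * $1.1 }
        return activation(weightedSum)
    }
}

// Hypothetical weights for a layer with 2 inputs and 2 neurons.
let toyWeights: [[Float]] = [[1, -1], [-2, 1]]
let toyBias: [Float] = [0, 0.5]
let toyOutput = denseForward(
    input: [1, 2], weights: toyWeights, bias: toyBias, activation: toyRelu)
print(toyOutput)  // [0.0, 0.5]
```

The first neuron's weighted sum is -1, which ReLU clamps to 0. The second neuron's sum is 0.5, which passes through unchanged.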

The ideal number of hidden layers and neurons depends on the problem and the dataset. Like many aspects of machine learning, picking the best shape of the neural network requires a mixture of knowledge and experimentation. As a rule of thumb, increasing the number of hidden layers and neurons typically creates a more powerful model, which requires more data to train effectively.

Using the model

Let's have a quick look at what this model does to a batch of features:

// A `Context` tells the model what it's being used for. The `IrisModel` does
// not actually use the `Context`, but the `Layer` API requires `Context` so
// that it can support more complex models that depend on context. Here, we
// construct an "inference" context because we're using the model to infer
// labels.
let inferenceContext = Context(learningPhase: .inference)

// Apply the model to a batch of features.
let firstTrainPredictions = model.applied(to: firstTrainFeatures, in: inferenceContext)
firstTrainPredictions[0..<5]
[[-3.739069, 3.533737, 7.177187], [-2.4676516, 2.456597, 4.8783565], [-3.027687, 2.795834, 5.7705793], [-2.1089425, 2.004953, 3.863989], [-2.6216679, 2.3670707, 4.5700254]]

Here, each example returns a logit for each class.

To convert these logits to a probability for each class, use the softmax function:

softmax(firstTrainPredictions[0..<5])
[[1.7697288e-05, 0.025494482, 0.9744879], [0.00059221254, 0.08148013, 0.9179277], [0.00014363867, 0.048572946, 0.9512834], [0.0021985762, 0.13451903, 0.8632824], [0.00067746005, 0.09941807, 0.8999045]]
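
For reference, softmax exponentiates each logit and divides by the sum of the exponentials. The following is a minimal sketch on a plain Swift array, not the library's implementation; plainSoftmax is a hypothetical name.

```swift
import Foundation

/// Converts logits to probabilities: exp(z_i) / sum_j exp(z_j).
/// Subtracting the largest logit first keeps exp() from overflowing
/// and does not change the result.
func plainSoftmax(_ logits: [Float]) -> [Float] {
    guard let largest = logits.max() else { return [] }
    let exps = logits.map { exp($0 - largest) }
    let total = exps.reduce(0, +)
    return exps.map { $0 / total }
}

// Logits of the first example in the output above.
let probabilities = plainSoftmax([-3.739069, 3.533737, 7.177187])
print(probabilities)
```

The result should agree, up to rounding, with the first row printed above, and its entries sum to 1.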

Taking the argmax across classes gives us the predicted class index. However, the model hasn't been trained yet, so these aren't good predictions.

print("Prediction: \(firstTrainPredictions.argmax(squeezingAxis: 1))")
print("    Labels: \(firstTrainLabels)")
Prediction: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
    Labels: [2, 1, 2, 0, 0, 0, 0, 2, 1, 0, 1, 1, 0, 0, 2, 1, 2, 2, 2, 0, 2, 2, 0, 2, 2, 0, 1, 2, 1, 1, 1, 1]

Train the model

Training is the stage of machine learning when the model is gradually optimized, or the model learns the dataset. The goal is to learn enough about the structure of the training dataset to make predictions about unseen data. If you learn too much about the training dataset, then the predictions only work for the data it has seen and will not be generalizable. This problem is called overfitting—it's like memorizing the answers instead of understanding how to solve a problem.

The iris classification problem is an example of supervised machine learning: the model is trained from examples that contain labels. In unsupervised machine learning, the examples don't contain labels. Instead, the model typically finds patterns among the features.

Choose a loss function

Both the training and evaluation stages need to calculate the model's loss. This measures how far a model's predictions are from the desired label, in other words, how badly the model is performing. We want to minimize, or optimize, this value.

Our model will calculate its loss using the softmaxCrossEntropy(logits:labels:) function, which takes the model's logits (unnormalized class scores) and the desired labels, and returns the average loss across the examples.
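
As a rough picture of what such a loss computes, the sketch below implements mean softmax cross-entropy on plain Swift arrays. It is an illustration under the usual textbook definition, not the library's code; meanSoftmaxCrossEntropy is a hypothetical name.

```swift
import Foundation

/// Mean softmax cross-entropy over a batch.
/// For one example the loss is -log(softmax(logits)[label]), computed here
/// as log(sum_j exp(z_j)) - z_label for numerical stability.
func meanSoftmaxCrossEntropy(logits: [[Float]], labels: [Int]) -> Float {
    precondition(logits.count == labels.count && !logits.isEmpty)
    var total: Float = 0
    for (row, label) in zip(logits, labels) {
        let largest = row.max()!
        let sumOfExps = row.map { exp($0 - largest) }.reduce(0, +)
        let logSumExp = largest + log(sumOfExps)
        total += logSumExp - row[label]
    }
    return total / Float(logits.count)
}

// Equal logits mean a uniform guess over 3 classes, so the loss is log(3).
let uniformLoss = meanSoftmaxCrossEntropy(logits: [[0, 0, 0]], labels: [1])
// A confident correct prediction gives a loss close to zero.
let confidentLoss = meanSoftmaxCrossEntropy(logits: [[10, 0, 0]], labels: [0])
print(uniformLoss, confidentLoss)
```

A model that guesses uniformly among three classes scores about 1.0986, so a loss well above that suggests confident but wrong predictions.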

Let's calculate the loss for the current untrained model:

let trainingContext = Context(learningPhase: .training)

let untrainedLogits = model.applied(to: firstTrainFeatures, in: trainingContext)
let untrainedLoss = softmaxCrossEntropy(logits: untrainedLogits, labels: firstTrainLabels)
print("Loss test: \(untrainedLoss)")
Loss test: 3.1096826

Create an optimizer

An optimizer applies the computed gradients to the model's variables to minimize the loss function. You can think of the loss function as a curved surface (see Figure 3) and we want to find its lowest point by walking around. The gradients point in the direction of steepest ascent, so we'll travel the opposite way and move down the hill. By iteratively calculating the loss and gradient for each batch, we'll adjust the model during training. Gradually, the model will find the best combination of weights and biases to minimize loss. And the lower the loss, the better the model's predictions.

Figure 3. Optimization algorithms visualized over time in 3D space.
(Source: Stanford class CS231n, MIT License, Image credit: Alec Radford)

Swift for TensorFlow has many optimization algorithms available for training. This model uses the SGD optimizer that implements the stochastic gradient descent (SGD) algorithm. The learningRate sets the step size to take for each iteration down the hill. This is a hyperparameter that you'll commonly adjust to achieve better results.

let optimizer = SGD<IrisModel, Float>(learningRate: 0.01)
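
The effect of the learning rate is easiest to see on a one-parameter problem. The sketch below runs plain gradient descent on a hypothetical loss, loss(w) = (w - 3)^2, whose minimum is at w = 3. It does not use the SGD type; every name in it is made up for the illustration.

```swift
/// Hypothetical loss surface with a single parameter: loss(w) = (w - 3)^2.
func toyLoss(_ w: Float) -> Float { return (w - 3) * (w - 3) }

/// Its derivative, 2 * (w - 3), points in the direction of steepest ascent.
func toyGradient(_ w: Float) -> Float { return 2 * (w - 3) }

let toyLearningRate: Float = 0.1
var w: Float = 0
let lossBefore = toyLoss(w)

// One gradient descent step: move against the gradient.
w -= toyLearningRate * toyGradient(w)
let lossAfter = toyLoss(w)
print("w: \(w), loss before: \(lossBefore), loss after: \(lossAfter)")

// Repeating the step walks w toward the minimum at 3.
for _ in 0..<100 {
    w -= toyLearningRate * toyGradient(w)
}
print("w after 101 steps: \(w)")
```

Each step here shrinks the distance to the minimum by a factor of 0.8. A much larger learning rate would overshoot, and a much smaller one would need many more steps.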

Let's use optimizer to take a single gradient descent step. First, we compute the gradient of the loss with respect to the model:

let (loss, grads) = model.valueWithGradient { model -> Tensor<Float> in
    let logits = model.applied(to: firstTrainFeatures, in: trainingContext)
    return softmaxCrossEntropy(logits: logits, labels: firstTrainLabels)
}
print("Current loss: \(loss)")
Current loss: 3.1096826

Next, we pass the gradient that we just calculated to the optimizer, which updates the model's differentiable variables accordingly:

optimizer.update(&model.allDifferentiableVariables, along: grads)

If we calculate the loss again, it should be smaller, because gradient descent steps (usually) decrease the loss:

let logitsAfterOneStep = model.applied(to: firstTrainFeatures, in: trainingContext)
let lossAfterOneStep = softmaxCrossEntropy(logits: logitsAfterOneStep, labels: firstTrainLabels)
print("Next loss: \(lossAfterOneStep)")
Next loss: 2.565462

Training loop

With all the pieces in place, the model is ready for training! A training loop feeds the dataset examples into the model to help it make better predictions. The following code block sets up these training steps:

  1. Iterate each epoch. An epoch is one pass through the dataset.
  2. Within an epoch, iterate over each batch in the training Dataset, grabbing its features and labels.
  3. Using the example's features, make a prediction and compare it with the label. Measure the inaccuracy of the prediction and use that to calculate the model's loss and gradients.
  4. Use gradient descent to update the model's variables.
  5. Keep track of some stats for visualization.
  6. Repeat for each epoch.

The epochCount variable is the number of times to loop over the dataset collection. Counter-intuitively, training a model longer does not guarantee a better model. epochCount is a hyperparameter that you can tune. Choosing the right number usually requires both experience and experimentation.

let epochCount = 500
var trainAccuracyResults: [Float] = []
var trainLossResults: [Float] = []
func accuracy(predictions: Tensor<Int32>, truths: Tensor<Int32>) -> Float {
    return Tensor<Float>(predictions .== truths).mean().scalarized()
}

for epoch in 1...epochCount {
    var epochLoss: Float = 0
    var epochAccuracy: Float = 0
    var batchCount: Int = 0
    for batch in trainDataset {
        let (loss, grad) = model.valueWithGradient { (model: IrisModel) -> Tensor<Float> in
            let logits = model.applied(to: batch.features, in: trainingContext)
            return softmaxCrossEntropy(logits: logits, labels: batch.labels)
        }
        optimizer.update(&model.allDifferentiableVariables, along: grad)
        
        let logits = model.applied(to: batch.features, in: trainingContext)
        epochAccuracy += accuracy(predictions: logits.argmax(squeezingAxis: 1), truths: batch.labels)
        epochLoss += loss.scalarized()
        batchCount += 1
    }
    epochAccuracy /= Float(batchCount)
    epochLoss /= Float(batchCount)
    trainAccuracyResults.append(epochAccuracy)
    trainLossResults.append(epochLoss)
    if epoch % 50 == 0 {
        print("Epoch \(epoch): Loss: \(epochLoss), Accuracy: \(epochAccuracy)")
    }
}
Epoch 50: Loss: 0.13649629, Accuracy: 0.9921875
Epoch 100: Loss: 0.12866701, Accuracy: 0.9921875
Epoch 150: Loss: 0.12233954, Accuracy: 0.9921875
Epoch 200: Loss: 0.11711606, Accuracy: 0.9921875
Epoch 250: Loss: 0.11268301, Accuracy: 0.9921875
Epoch 300: Loss: 0.108866386, Accuracy: 0.9921875
Epoch 350: Loss: 0.10561609, Accuracy: 0.9921875
Epoch 400: Loss: 0.102745734, Accuracy: 0.9921875
Epoch 450: Loss: 0.10018992, Accuracy: 0.9921875
Epoch 500: Loss: 0.097879, Accuracy: 0.9921875
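
The same loop structure (epochs, batches, loss and gradients, update, statistics) can be sketched without TensorFlow on a much smaller problem. The example below fits a line to noise-free toy data with hand-written gradients. The data, names, and hyperparameters are assumptions chosen only to keep the sketch short.

```swift
// Toy data drawn from y = 2x + 1; the loop should recover slope 2, intercept 1.
let xs: [Float] = (0..<8).map { Float($0) / 4 }
let ys: [Float] = xs.map { 2 * $0 + 1 }

var slope: Float = 0
var intercept: Float = 0
let stepSize: Float = 0.1
let toyBatchSize = 4
var lossHistory: [Float] = []

for _ in 1...200 {                       // one pass over the data = one epoch
    var epochLoss: Float = 0
    var batchCount = 0
    for start in stride(from: 0, to: xs.count, by: toyBatchSize) {
        let end = min(start + toyBatchSize, xs.count)
        var batchLoss: Float = 0
        var slopeGrad: Float = 0
        var interceptGrad: Float = 0
        // Squared-error loss and its gradients, summed over the batch.
        for i in start..<end {
            let error = slope * xs[i] + intercept - ys[i]
            batchLoss += error * error
            slopeGrad += 2 * error * xs[i]
            interceptGrad += 2 * error
        }
        let n = Float(end - start)
        // Gradient descent update using the batch-mean gradients.
        slope -= stepSize * slopeGrad / n
        intercept -= stepSize * interceptGrad / n
        epochLoss += batchLoss / n
        batchCount += 1
    }
    // Keep the mean batch loss per epoch for later plotting.
    lossHistory.append(epochLoss / Float(batchCount))
}
print("slope: \(slope), intercept: \(intercept)")
```

In the tutorial's loop, valueWithGradient computes the gradients automatically; here they are written out by hand because the model is only a line.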

Visualize the loss function over time

While it's helpful to print out the model's training progress, it's often more helpful to see this progress. We can create basic charts using Python's matplotlib module.

Interpreting these charts takes some experience, but you really want to see the loss go down and the accuracy go up.

plt.figure(figsize: [12, 8])

let accuracyAxes = plt.subplot(2, 1, 1)
accuracyAxes.set_ylabel("Accuracy")
accuracyAxes.plot(trainAccuracyResults)

let lossAxes = plt.subplot(2, 1, 2)
lossAxes.set_ylabel("Loss")
lossAxes.set_xlabel("Epoch")
lossAxes.plot(trainLossResults)

plt.show()

[Figure: training accuracy (top) and loss (bottom) plotted against epoch]

Note that the y-axes of the graphs are not zero-based.

Evaluate the model's effectiveness

Now that the model is trained, we can get some statistics on its performance.

Evaluating means determining how effectively the model makes predictions. To determine the model's effectiveness at iris classification, pass some sepal and petal measurements to the model and ask the model to predict what iris species they represent. Then compare the model's prediction against the actual label. For example, a model that picked the correct species on half the input examples has an accuracy of 0.5. Figure 4 shows a slightly more effective model, getting 4 out of 5 predictions correct at 80% accuracy:

Example features        Label   Model prediction
5.9  3.0  4.3  1.5      1       1
6.9  3.1  5.4  2.1      2       2
5.1  3.3  1.7  0.5      0       0
6.0  3.4  4.5  1.6      1       2
5.5  2.5  4.0  1.3      1       1
Figure 4. An iris classifier that is 80% accurate.
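
The accuracy in Figure 4 can be reproduced with a few lines of plain Swift. This is a sketch on ordinary arrays rather than tensors, and fractionCorrect is a hypothetical helper name.

```swift
/// Fraction of predictions that match their labels.
func fractionCorrect(predictions: [Int], labels: [Int]) -> Float {
    precondition(predictions.count == labels.count && !labels.isEmpty)
    let correct = zip(predictions, labels).filter { $0.0 == $0.1 }.count
    return Float(correct) / Float(labels.count)
}

// Label and prediction columns from Figure 4.
let figureLabels = [1, 2, 0, 1, 1]
let figurePredictions = [1, 2, 0, 2, 1]
let figureAccuracy = fractionCorrect(
    predictions: figurePredictions, labels: figureLabels)
print(figureAccuracy)  // 0.8
```

Only the fourth example is misclassified, which gives 4 correct out of 5.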
 

Set up the test dataset

Evaluating the model is similar to training the model. The biggest difference is the examples come from a separate test set rather than the training set. To fairly assess a model's effectiveness, the examples used to evaluate a model must be different from the examples used to train the model.
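
This tutorial downloads a ready-made test file, but when only one dataset is available a common approach is to hold out part of it. The sketch below shows one way to do that in plain Swift. The seeded generator, the 150/30 sizes, and all names are assumptions for illustration, not something the tutorial's data pipeline does.

```swift
/// Tiny deterministic generator so the shuffle is reproducible.
/// (SplitMix64; illustrative only.)
struct SeededGenerator: RandomNumberGenerator {
    var state: UInt64
    mutating func next() -> UInt64 {
        state &+= 0x9E3779B97F4A7C15
        var z = state
        z = (z ^ (z >> 30)) &* 0xBF58476D1CE4E5B9
        z = (z ^ (z >> 27)) &* 0x94D049BB133111EB
        return z ^ (z >> 31)
    }
}

/// Shuffles `items` and holds out the last `testCount` of them for testing.
func holdoutSplit<Element>(
    _ items: [Element], testCount: Int, seed: UInt64
) -> (train: [Element], test: [Element]) {
    precondition(testCount >= 0 && testCount <= items.count)
    var generator = SeededGenerator(state: seed)
    let shuffled = items.shuffled(using: &generator)
    return (Array(shuffled.dropLast(testCount)), Array(shuffled.suffix(testCount)))
}

// 150 stand-in example indices, split into 120 for training and 30 for testing.
let split = holdoutSplit(Array(0..<150), testCount: 30, seed: 42)
print(split.train.count, split.test.count)  // 120 30
```

Because every index lands in exactly one of the two arrays, no example is used for both training and evaluation.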

The setup for the test Dataset is similar to the setup for the training Dataset. Download the test set from http://download.tensorflow.org/data/iris_test.csv:

let urllib = Python.import("urllib.request")
let downloadResult = urllib.urlretrieve(
    "http://download.tensorflow.org/data/iris_test.csv",
    "iris_test.csv")
let testDataFilename = String(downloadResult[0])!
testDataFilename
"iris_test.csv"

Now load it into a Dataset:

let testDataset: Dataset<IrisBatch> = Dataset(
    contentsOfCSVFile: testDataFilename, hasHeader: true,
    featureColumns: [0, 1, 2, 3], labelColumns: [4]
).batched(batchSize)

Evaluate the model on the test dataset

Unlike the training stage, the model only evaluates a single epoch of the test data. In the following code cell, we iterate over each example in the test set and compare the model's prediction against the actual label. This is used to measure the model's accuracy across the entire test set.

// NOTE: With `batchSize = 32` and 30 examples in the test dataset, only one batch will run in the loop.
for testBatch in testDataset {
    let logits = model.applied(to: testBatch.features, in: inferenceContext)
    let predictions = logits.argmax(squeezingAxis: 1)
    print("Test batch accuracy: \(accuracy(predictions: predictions, truths: testBatch.labels))")
}
Test batch accuracy: 0.96666664

We can see on the first batch, for example, the model is usually correct:

let firstTestBatch = testDataset.first!
let firstTestBatchLogits = model.applied(to: firstTestBatch.features, in: inferenceContext)
let firstTestBatchPredictions = firstTestBatchLogits.argmax(squeezingAxis: 1)

print(firstTestBatchPredictions)
print(firstTestBatch.labels)
[1, 2, 0, 1, 1, 1, 0, 1, 1, 2, 2, 0, 2, 1, 1, 0, 1, 0, 0, 2, 0, 1, 2, 1, 1, 1, 0, 1, 2, 1]
[1, 2, 0, 1, 1, 1, 0, 2, 1, 2, 2, 0, 2, 1, 1, 0, 1, 0, 0, 2, 0, 1, 2, 1, 1, 1, 0, 1, 2, 1]

Use the trained model to make predictions

We've trained a model and demonstrated that it's good—but not perfect—at classifying iris species. Now let's use the trained model to make some predictions on unlabeled examples; that is, on examples that contain features but not a label.

In real life, the unlabeled examples could come from lots of different sources including apps, CSV files, and data feeds. For now, we're going to manually provide three unlabeled examples to predict their labels. Recall, the label numbers are mapped to a named representation as:

  • 0: Iris setosa
  • 1: Iris versicolor
  • 2: Iris virginica
let unlabeledDataset: Tensor<Float> =
    [[5.1, 3.3, 1.7, 0.5],
     [5.9, 3.0, 4.2, 1.5],
     [6.9, 3.1, 5.4, 2.1]]

let unlabeledDatasetPredictions = model.applied(to: unlabeledDataset, in: inferenceContext)

for i in 0..<unlabeledDatasetPredictions.shape[0] {
    let logits = unlabeledDatasetPredictions[i]
    let classIdx = logits.argmax().scalar!
    print("Example \(i) prediction: \(classNames[Int(classIdx)]) (\(softmax(logits)))")
}
Example 0 prediction: Iris setosa ([0.9365366, 0.042833436, 0.020630043])
Example 1 prediction: Iris versicolor ([0.0003901447, 0.996633, 0.0029768792])
Example 2 prediction: Iris virginica ([1.0952108e-05, 0.12486785, 0.8751212])
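
The last step, turning a row of probabilities into a species name, is an argmax followed by an array lookup. Here is a minimal sketch on plain Swift arrays, reusing the probabilities printed above. The argmax helper and the other names are hypothetical and separate from the tensor method used in the tutorial.

```swift
let speciesNames = ["Iris setosa", "Iris versicolor", "Iris virginica"]

/// Index of the largest value, or nil for an empty array.
func argmax(_ values: [Float]) -> Int? {
    return values.indices.max { values[$0] < values[$1] }
}

// Probabilities printed above for the three unlabeled examples.
let printedProbabilities: [[Float]] = [
    [0.9365366, 0.042833436, 0.020630043],
    [0.0003901447, 0.996633, 0.0029768792],
    [1.0952108e-05, 0.12486785, 0.8751212],
]
for (i, row) in printedProbabilities.enumerated() {
    if let best = argmax(row) {
        print("Example \(i): \(speciesNames[best]) (p = \(row[best]))")
    }
}
```

The winning probability also says how confident the prediction is: the third example is assigned to Iris virginica with about 0.88, noticeably less certain than the second.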