
Training loop

When training a machine learning model, it's common to have a loop where training data is ingested (or generated), batches run through a model, gradients obtained, and the model updated via an optimizer. While you can write a training loop of your own for each training application, Swift for TensorFlow provides an experimental training loop abstraction that may simplify this process.
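
For comparison, here is a minimal sketch of the kind of hand-written loop described above, in plain Swift with no TensorFlow dependency. It fits a one-variable linear model, y = w * x + b, to synthetic, noise-free data using mini-batch stochastic gradient descent. The names `LinearModel` and `train` are illustrative only. The gradients of the mean squared error are written out by hand here, rather than obtained through automatic differentiation as they would be in Swift for TensorFlow.

```swift
/// A one-variable linear model, y = w * x + b.
struct LinearModel {
  var w = 0.0
  var b = 0.0
  func callAsFunction(_ x: Double) -> Double { w * x + b }
}

/// Trains `model` with mini-batch SGD on the mean squared error and returns
/// the mean loss observed during the final epoch.
func train(
  _ model: inout LinearModel,
  on data: [(x: Double, y: Double)],
  epochs: Int,
  batchSize: Int,
  learningRate: Double
) -> Double {
  var lastEpochLoss = 0.0
  for _ in 0..<epochs {
    var epochLoss = 0.0
    // Ingest the data one batch at a time.
    for start in stride(from: 0, to: data.count, by: batchSize) {
      let batch = data[start..<min(start + batchSize, data.count)]
      // Forward pass, loss, and hand-derived gradients of the squared error.
      var gradW = 0.0
      var gradB = 0.0
      for (x, y) in batch {
        let error = model(x) - y
        epochLoss += error * error
        gradW += 2 * error * x
        gradB += 2 * error
      }
      // Optimizer step: plain SGD on the batch-averaged gradients.
      let n = Double(batch.count)
      model.w -= learningRate * gradW / n
      model.b -= learningRate * gradB / n
    }
    lastEpochLoss = epochLoss / Double(data.count)
  }
  return lastEpochLoss
}

// Synthetic, noise-free data drawn from y = 3x + 1.
let data: [(x: Double, y: Double)] = (0..<20).map { i -> (x: Double, y: Double) in
  let x = Double(i) / 10
  return (x: x, y: 3 * x + 1)
}
var model = LinearModel()
let finalLoss = train(&model, on: data, epochs: 1000, batchSize: 5, learningRate: 0.1)
print(model.w, model.b, finalLoss)
```

Every step here (batching, the forward pass, gradient computation, and the optimizer update) is code you would otherwise rewrite for each project, which is the repetition a generalized training loop aims to remove.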

The TrainingLoop module within the models repository contains the current version of this experimental generalized training loop. It is structured in such a way as to integrate with dataset wrappers that conform to the Epochs API for easy data ingestion, and to automate the interaction of models, datasets, and optimizers with accelerator backends to achieve optimal performance. Heavy customization of the training process can be achieved through the use of callbacks.

Most image-based examples in the models repository, as well as the supervised text model training examples, have been converted to use this training loop abstraction. However, in its current design the training loop may not be appropriate for all machine learning models.

The implementation of Swift for TensorFlow's generalized training loop is heavily influenced by fastai's Learner. For more on their design, please refer to "fastai: A Layered API for Deep Learning" and Sylvain Gugger's presentation " - An infinitely customizable training loop".


The ResNet-CIFAR10 example provides a good demonstration of how to use this training loop in practice. First, import the module:

import TrainingLoop

then choose an accelerator backend by setting up a Device. In this case, we'll select the X10 XLA-based backend and use the first available accelerator:

let device = Device.defaultXLA

The next step is to configure the dataset, model, and optimizer to use with your training loop:

let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)

and then set up the training loop:

var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy])

The training loop assumes that the dataset you're using conforms to the Epochs API, and it lets you specify which splits within the dataset to use for training and validation. Any loss function can be used once it is placed into a compatible wrapper, as softmaxCrossEntropy is here.
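
As an illustration of what a loss like softmaxCrossEntropy computes, the following sketch implements softmax cross-entropy for plain Swift arrays of logits and integer labels. It is not the TensorFlow implementation and does not show the wrapper signature that TrainingLoop expects; the function names `softmaxCrossEntropyLoss` and `meanSoftmaxCrossEntropy` are hypothetical. It uses the log-sum-exp trick so that large logits do not overflow.

```swift
import Foundation

/// Softmax cross-entropy for one example, computed as
/// logSumExp(logits) - logits[label].
func softmaxCrossEntropyLoss(logits: [Double], label: Int) -> Double {
  precondition(logits.indices.contains(label), "label out of range")
  // Subtracting the largest logit before exponentiating avoids overflow.
  let maxLogit = logits.max()!
  let sumOfExps = logits.map { exp($0 - maxLogit) }.reduce(0, +)
  return maxLogit + log(sumOfExps) - logits[label]
}

/// Mean loss over a batch, the usual reduction during training.
func meanSoftmaxCrossEntropy(logits: [[Double]], labels: [Int]) -> Double {
  precondition(!labels.isEmpty && logits.count == labels.count)
  let total = zip(logits, labels)
    .map { softmaxCrossEntropyLoss(logits: $0, label: $1) }
    .reduce(0, +)
  return total / Double(labels.count)
}

// Equal logits over four classes give log(4) for any label.
print(softmaxCrossEntropyLoss(logits: [0, 0, 0, 0], label: 2))  // ≈ 1.386
// A large logit on the correct class does not overflow and gives a loss near 0.
print(softmaxCrossEntropyLoss(logits: [1000, 0, -5], label: 0))
```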

The current metrics that can be captured include:

  • loss
  • accuracy
  • top5Accuracy
  • matthewsCorrelationCoefficient
  • perplexity

Finally, to perform training, you call the following:

try! trainingLoop.fit(&model, epochs: 10, on: device)

This will train the model for 10 epochs using the accelerator backend we specified. During training, statistics are displayed in the console using an animated prompt.


Customization of this generalized training loop occurs via the use of callbacks. These callbacks can be hooked into various points within the loop.

Several built-in callbacks provide functionality that can be added to any training loop. These include:

  • Logging statistics to comma-separated-value (CSV) files
  • Adjusting the learning rate according to a custom schedule
  • Monitoring and graphing training progress via TensorBoard

In addition to these, you can create your own custom callbacks to add a range of additional functionality to a standard training loop.

CSV logging

The CSVLogger class encapsulates a callback that will write out training statistics in a comma-separated-value format to a file of your choosing. This file will start with columns labeled epoch, batch, and whatever metrics you have enabled within your training loop. One row will then be written for each batch, with the current values of those columns.
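
The following sketch reproduces the file layout described above in memory: a header row with epoch, batch, and one column per enabled metric, followed by one row per batch. `CSVRowWriter` is a hypothetical type for illustration, not part of the TrainingLoop module; the real CSVLogger writes directly to the path you give it, and its exact separators and number formatting may differ.

```swift
/// Collects CSV rows in memory: a header of epoch, batch, and the metric
/// names, then one row per batch. Hypothetical; not the TrainingLoop API.
struct CSVRowWriter {
  let metricNames: [String]
  private(set) var lines: [String]

  init(metricNames: [String]) {
    self.metricNames = metricNames
    self.lines = [(["epoch", "batch"] + metricNames).joined(separator: ",")]
  }

  /// Appends one row; a metric missing from `metrics` becomes an empty cell.
  mutating func log(epoch: Int, batch: Int, metrics: [String: Double]) {
    let values = metricNames.map { name in metrics[name].map { String($0) } ?? "" }
    lines.append(([String(epoch), String(batch)] + values).joined(separator: ","))
  }

  /// The full file contents, one line per row.
  var contents: String { lines.joined(separator: "\n") + "\n" }
}

var writer = CSVRowWriter(metricNames: ["loss", "accuracy"])
writer.log(epoch: 1, batch: 1, metrics: ["loss": 2.25, "accuracy": 0.5])
writer.log(epoch: 1, batch: 2, metrics: ["loss": 1.75, "accuracy": 0.625])
print(writer.contents, terminator: "")
// epoch,batch,loss,accuracy
// 1,1,2.25,0.5
// 1,2,1.75,0.625
```

To persist output like this yourself, Foundation's `String.write(toFile:atomically:encoding:)` can write `writer.contents` to disk.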

To add CSV logging to your training loop, add something like the following to an array of callbacks provided to the callbacks: parameter for your TrainingLoop:

try! CSVLogger(path: "file.csv").log

As an example, the LeNet-MNIST sample uses this within its training loop.

Learning rate schedules

It's common when training a model to change the learning rate provided to an optimizer during the training process. This can be as simple as a linear decrease over time, or as complex as warmup and decline cycles described by complicated functions.

The learningRateScheduler callback provides the means of describing learning rate schedules composed of different segments, each with their own distinct shape. This is accomplished by defining a LearningRateSchedule composed of ScheduleSegments that each have a Shape defined by a function, an initial learning rate, and a final learning rate.

For example, the BERT-CoLA sample uses a linear increase in the learning rate during a warmup period and a linear decrease after that. To do this, the learning rate schedule callback is defined as follows:

learningRateScheduler(
  schedule: makeSchedule([
    ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
    ScheduleSegment(shape: linear, endRate: 0)
  ]))

The two ScheduleSegments define a learning rate that starts at 0 and increases linearly to peakLearningRate over 10 discrete steps. The second segment then starts at the final learning rate of the first segment and decreases linearly to 0 by the end of the training process.
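
To make the segment arithmetic concrete, here is a plain-Swift sketch of a piecewise schedule in which a segment without a start rate continues from the previous segment's end rate, and a segment without a step count runs until the end of training. `Segment`, `Shape`, `linear`, and `learningRate(at:totalSteps:segments:)` are hypothetical stand-ins, not the TrainingLoop module's ScheduleSegment or makeSchedule, and the exact step at which each segment begins or ends may differ from the real implementation.

```swift
/// A shape maps progress through a segment (0...1) to an interpolation weight (0...1).
typealias Shape = (Double) -> Double
let linear: Shape = { progress in progress }

/// One piece of a schedule; nil fields are filled in as described above.
struct Segment {
  var shape: Shape
  var startRate: Double?  // nil: continue from the previous segment's end rate
  var endRate: Double
  var stepCount: Int?     // nil: run until the end of training
}

/// Learning rate for a 0-based `step` out of `totalSteps`.
func learningRate(at step: Int, totalSteps: Int, segments: [Segment]) -> Double {
  var segmentStart = 0
  var previousEnd = 0.0
  for segment in segments {
    let start = segment.startRate ?? previousEnd
    let length = segment.stepCount ?? (totalSteps - segmentStart)
    if step < segmentStart + length {
      let progress = Double(step - segmentStart) / Double(length)
      return start + (segment.endRate - start) * segment.shape(progress)
    }
    segmentStart += length
    previousEnd = segment.endRate
  }
  // Past the last segment, hold the final rate.
  return previousEnd
}

let peakLearningRate = 0.01
let schedule = [
  Segment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
  Segment(shape: linear, endRate: 0),
]
for step in [0, 5, 10, 55, 99] {
  print(step, learningRate(at: step, totalSteps: 100, segments: schedule))
}
```

With 100 total steps, this sketch gives a rate of 0 at step 0, 0.005 at step 5, the peak of 0.01 at step 10, 0.005 again at step 55, and a value close to 0 at the last step.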

TensorBoard integration

TensorBoard is a powerful visualization tool for monitoring model training, analyzing training when completed, or comparing training runs. Swift for TensorFlow supports TensorBoard visualization through the use of the TensorBoard module in the models repository, which provides callbacks that log training metrics.

The GPT2-WikiText2 sample illustrates how to add TensorBoard logging to your model training. First, import the TensorBoard module. Then it's as simple as adding tensorBoardStatisticsLogger() to your TrainingLoop's callbacks: array.

By default, this logs each training run within a run/tensorboard/stats directory. To view the logs in TensorBoard, run

tensorboard --logdir ./run/tensorboard/stats

and TensorBoard should start a local server where you can view your training metrics. Training and validation results should be shown separately, and each run has a unique timestamp to allow for easy comparison between multiple runs of the same model.

The design of the Swift for TensorFlow TensorBoard integration was inspired by tensorboardX. The TensorBoard callbacks directly create the appropriate event and summary protocol buffers and write them within a log file during training.

Custom callbacks

In addition to the built-in callbacks described above, you have the ability to customize the function of training loops by creating your own callbacks. These callbacks are functions that have a signature similar to the following:

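func customCallback<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
  // Act only at the point in the loop where a model update begins.
  if event == .updateStart { /* inspect or modify `loop` here */ }
}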

The training loop and associated state are passed in as the first parameter. The current part of the loop that the callback is responding to is provided via event. The training loop event has one of the following states, each corresponding to a different point in the loop's life cycle:

  • fitStart
  • fitEnd
  • epochStart
  • epochEnd
  • trainingStart
  • trainingEnd
  • validationStart
  • validationEnd
  • batchStart
  • batchEnd
  • updateStart
  • inferencePredictionEnd

Your callback function can choose to activate its logic on any combination of the above states, which allows for extracting data from or otherwise controlling the training loop in many ways.
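
The sketch below shows this event-dispatch pattern in miniature: a toy loop fires events at points in its life cycle, and each callback receives the loop as an inout parameter plus the current event, mirroring the signature above. `ToyLoop` and `LoopEvent` are hypothetical and cover only a subset of the events; they are not the TrainingLoop module's types. One callback halves the learning rate at the start of each epoch after the first, and another requests an early stop after the second epoch.

```swift
/// A subset of life-cycle events, modeled loosely on the list above.
enum LoopEvent {
  case fitStart, fitEnd, epochStart, epochEnd, batchStart, batchEnd, updateStart
}

/// A toy loop that fires events around each phase of "training".
struct ToyLoop {
  typealias Callback = (inout ToyLoop, LoopEvent) throws -> Void

  var epochCount: Int
  var batchesPerEpoch: Int
  var learningRate: Double
  var callbacks: [Callback] = []
  var currentEpoch = 0
  var updatesPerformed = 0
  var stopRequested = false

  /// Invokes every registered callback with the loop itself and the event.
  mutating func fire(_ event: LoopEvent) throws {
    for callback in callbacks {
      try callback(&self, event)
    }
  }

  mutating func fit() throws {
    try fire(.fitStart)
    for epoch in 0..<epochCount {
      if stopRequested { break }
      currentEpoch = epoch
      try fire(.epochStart)
      for _ in 0..<batchesPerEpoch {
        try fire(.batchStart)
        try fire(.updateStart)
        updatesPerformed += 1  // stands in for an optimizer step
        try fire(.batchEnd)
      }
      try fire(.epochEnd)
    }
    try fire(.fitEnd)
  }
}

/// Halves the learning rate at the start of every epoch after the first.
func halveLearningRate(_ loop: inout ToyLoop, event: LoopEvent) throws {
  if event == .epochStart && loop.currentEpoch > 0 {
    loop.learningRate /= 2
  }
}

/// Requests an early stop once the second epoch (index 1) has finished.
func stopAfterTwoEpochs(_ loop: inout ToyLoop, event: LoopEvent) throws {
  if event == .epochEnd && loop.currentEpoch == 1 {
    loop.stopRequested = true
  }
}

var loop = ToyLoop(epochCount: 10, batchesPerEpoch: 3, learningRate: 0.1)
loop.callbacks = [halveLearningRate, stopAfterTwoEpochs]
try! loop.fit()
print(loop.currentEpoch, loop.updatesPerformed, loop.learningRate)  // 1 6 0.05
```

Because each callback receives the loop as inout, it can both read state (such as the current epoch) and change it (such as the learning rate or a stop flag), which is what lets a small set of hooks customize the loop's behavior.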