On-Device Training in TensorFlow Lite


To use TensorFlow Lite, a developer needs to prepare a TensorFlow model, use the converter to convert it to the TensorFlow Lite model format, and run the model with the TensorFlow Lite runtime on device. This is true for inference use cases, and a similar flow applies to training.

The following code illustrates the high-level flow: preparing a TensorFlow training model, converting it to a TensorFlow Lite model, and running it in the TensorFlow Lite runtime for a training use case.

The implementation is based on the Keras classification example in the TensorFlow official guide page.
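As a minimal sketch of that prepare/convert/run flow (using a trivial stand-in model, not the training model built below), conversion and loading look roughly like this:

```python
import tensorflow as tf

# A trivial stand-in model; the sections below build the real training model.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=(4,))
])

# Steps 1-2: prepare the TensorFlow model and convert it to the
# TensorFlow Lite flatbuffer format.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Step 3: load the converted model into the TensorFlow Lite runtime.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
```

The rest of this document follows the same three steps, but with a model that exposes training, saving, and restoring as additional signatures.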


pip uninstall -y tensorflow keras
pip install tf-nightly
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

Classify images of clothing

This Colab trains a neural network model to classify images of clothing, like sneakers and shirts.

Here, 60,000 images are used to train the network and 10,000 images to evaluate how accurately it learned to classify images. Import and load the Fashion MNIST dataset directly from TensorFlow:

Loading the dataset returns four NumPy arrays:

  • The train_images and train_labels arrays are the training set—the data the model uses to learn.
  • The model is tested against the test set—the test_images and test_labels arrays.
fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
40960/29515 [=========================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
26435584/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
16384/5148 [===============================================================================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
4431872/4422102 [==============================] - 0s 0us/step

The images are 28x28 NumPy arrays, with pixel values ranging from 0 to 255; for example, train_images.shape is (60000, 28, 28). The labels are an array of integers, ranging from 0 to 9. These correspond to the class of clothing the image represents:
Label Class
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot

Each image is mapped to a single label. Since the class names are not included with the dataset, store them here to use later when plotting the images:

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

Scale these values to a range of 0 to 1 before feeding them to the neural network model. To do so, divide the values by 255. It's important that the training set and the testing set be preprocessed in the same way:

train_images = train_images / 255.0
test_images = test_images / 255.0

# Display the first 25 training images in a 5x5 grid.
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
plt.show()


TensorFlow Model for Training

Instead of converting a single TensorFlow model or tf.function with a single entry point to a TensorFlow Lite model, we can convert multiple tf.functions into one TensorFlow Lite model. To make this possible, the TensorFlow Lite converter and runtime have been extended to handle multiple signatures.

Preparing a TensorFlow model. The code constructs a tf.Module with four tf.functions:

  • The train function trains the model with training data.
  • The infer function invokes the inference.
  • The save function saves the trainable weights into the file system.
  • The restore function loads the trainable weights from the file system.

The weights are serialized in the TensorFlow v1 checkpoint file format.


IMG_SIZE = 28

class Model(tf.Module):

  def __init__(self):
    self.model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(IMG_SIZE, IMG_SIZE)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    self._LOSS_FN = tf.keras.losses.CategoricalCrossentropy()
    self._OPTIM = tf.optimizers.SGD()

  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
      tf.TensorSpec([None, 10], tf.float32),
  ])
  def train(self, x, y):
    with tf.GradientTape() as tape:
      prediction = self.model(x)
      loss = self._LOSS_FN(prediction, y)
    gradients = tape.gradient(loss, self.model.trainable_variables)
    self._OPTIM.apply_gradients(
        zip(gradients, self.model.trainable_variables))
    result = {"loss": loss}
    for grad in gradients:
      result[grad.name] = grad
    return result

  @tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
  def predict(self, x):
    return {
        "output": self.model(x)
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def save(self, checkpoint_path):
    tensor_names = [weight.name for weight in self.model.weights]
    tensors_to_save = [weight.read_value() for weight in self.model.weights]
    tf.raw_ops.Save(
        filename=checkpoint_path, tensor_names=tensor_names,
        data=tensors_to_save, name='save')
    return {
        "checkpoint_path": checkpoint_path
    }

  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def restore(self, checkpoint_path):
    restored_tensors = {}
    for var in self.model.weights:
      restored = tf.raw_ops.Restore(
          file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype,
          name='restore')
      var.assign(restored)
      restored_tensors[var.name] = restored
    return restored_tensors

Converting to TensorFlow Lite format.

# Export the TensorFlow model to a saved model with the four signatures
SAVED_MODEL_DIR = "saved_model"
m = Model()
tf.saved_model.save(
    m,
    SAVED_MODEL_DIR,
    signatures={
        'train': m.train.get_concrete_function(),
        'infer': m.predict.get_concrete_function(),
        'save': m.save.get_concrete_function(),
        'restore': m.restore.get_concrete_function(),
    })

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()
INFO:tensorflow:Assets written to: saved_model/assets
WARNING:absl:Importing a function (__inference_internal_grad_fn_1051) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_internal_grad_fn_1079) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.

Executing in TensorFlow Lite runtime.

TensorFlow Lite's Interpreter capability has been extended to support multiple signatures as well. Developers can invoke the restore, train, save, and infer signatures separately.

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

train = interpreter.get_signature_runner("train")
infer = interpreter.get_signature_runner("infer")
save = interpreter.get_signature_runner("save")
restore = interpreter.get_signature_runner("restore")

On Android, TensorFlow Lite on-device training can be performed using either the Java or C++ APIs. In this document, we describe how the above TensorFlow Lite model works with the Java API.

Training with a dataset

Training can be done with the train signature method.

NUM_EPOCHS = 100

# Generate one-hot training labels
processed_train_labels = []
for i in range(len(train_images)):
  train_label = [0] * 10
  train_label[train_labels[i]] = 1
  processed_train_labels.append(train_label)

# Run training for a few steps
for i in range(NUM_EPOCHS):
  train(
      x=tf.constant(train_images, shape=(len(train_images), IMG_SIZE, IMG_SIZE), dtype=tf.float32),
      y=tf.constant(processed_train_labels, shape=(len(train_images), 10), dtype=tf.float32))
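The one-hot encoding loop above can also be written as a vectorized NumPy sketch; the np.eye indexing trick here is an equivalent alternative, not part of the original example:

```python
import numpy as np

# Example integer labels in [0, 9] (hypothetical values).
labels = np.array([3, 0, 9])

# Row i of np.eye(10) is the one-hot vector for class i,
# so fancy indexing produces the full one-hot matrix at once.
one_hot = np.eye(10, dtype=np.float32)[labels]

# Each row contains a single 1.0 at the label's index, e.g.
# one_hot[0] → [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```

This produces the same (num_examples, 10) float matrix that the loop builds, and avoids Python-level iteration over 60,000 labels.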

In Java, you'll use the Interpreter class to load a model and drive model training tasks. The following example shows how to run the training procedure by using the runSignature method:

try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    int NUM_EPOCHS = 100;
    int NUM_TRAININGS = 60000;
    float[][][] trainImages = new float[NUM_TRAININGS][28][28];
    float[][] trainLabels = new float[NUM_TRAININGS][10];

    // Fill the data values.

    // Run training for a few steps.
    for (int i = 0; i < NUM_EPOCHS; ++i) {
        Map<String, Object> inputs = new HashMap<>();
        inputs.put("x", trainImages);
        inputs.put("y", trainLabels);
        Map<String, Object> outputs = new HashMap<>();
        FloatBuffer loss = FloatBuffer.allocate(1);
        outputs.put("loss", loss);
        interpreter.runSignature(inputs, outputs, "train");
    }

    // Do the other work.
}

Exporting the trained weights to the checkpoint file

The checkpoint file can be generated through the save signature method.

# Export the trained weights to /tmp/model.ckpt
save(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.string_))
{'checkpoint_path': array(b'/tmp/model.ckpt', dtype=object)}

In Java, you can store the trained weights in checkpoint format in the application's internal storage. Training tasks are usually performed occasionally in a background process at idle time (for example, at night).

try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    // Conduct the training jobs.

    // Export the trained weights as a checkpoint file.
    File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("checkpoint_path", outputFile.getAbsolutePath());
    Map<String, Object> outputs = new HashMap<>();
    interpreter.runSignature(inputs, outputs, "save");
}

Restoring the trained weights from the checkpoint file

The exported checkpoint file can be restored through the restore signature method.

another_interpreter = tf.lite.Interpreter(model_content=tflite_model)
another_interpreter.allocate_tensors()

train = another_interpreter.get_signature_runner("train")
infer = another_interpreter.get_signature_runner("infer")
save = another_interpreter.get_signature_runner("save")
restore = another_interpreter.get_signature_runner("restore")

# Restore the trained weights from /tmp/model.ckpt
restore(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.string_))
{'dense/bias:0': array([-1.57285843e-03, -4.36150981e-03,  1.48836458e-02,  1.35667417e-02,
         7.46194832e-03, -4.83886991e-03,  6.57455437e-03, -4.73781518e-04,
         4.12419299e-03,  2.58954018e-02,  1.61975506e-03, -1.17487879e-02,
        -9.89729655e-04,  2.91385991e-03, -1.68202328e-03, -7.42188143e-03,
        -5.06101223e-03,  6.53741765e-04,  1.69041492e-02, -1.53265726e-02,
        -2.11372157e-03, -5.78631414e-04, -2.74046906e-04,  3.44356429e-03,
        -6.28103316e-03, -5.93029847e-03, -9.84243676e-03,  6.41197478e-03,
         1.40506737e-02,  1.76685315e-03, -4.20411956e-03,  1.15378276e-02,
        -3.02422279e-03,  2.73351930e-03, -5.04991924e-03, -1.11118972e-03,
         5.69828087e-03,  1.89777557e-02, -1.45693170e-03,  9.47452988e-03,
        -1.30979321e-03, -3.83566157e-03,  1.10030295e-02,  5.07608103e-03,
        -2.20827921e-03,  4.95553203e-03, -3.35005258e-04,  8.24076496e-03,
        -1.69422454e-03, -1.87482947e-05, -4.49026795e-03,  4.35718242e-03,
         1.78324934e-02,  1.42959831e-02, -2.30042241e-03,  5.18994220e-03,
         6.62731845e-03,  7.24743400e-03,  5.25178947e-03,  3.92448641e-02,
         3.73577769e-03,  8.25242326e-03, -4.64917533e-03,  2.43719462e-02,
         6.26876531e-03,  1.22555560e-02, -8.83383770e-03,  1.11999724e-03,
         1.49061400e-02, -4.74620517e-03,  2.67512929e-02,  1.11020561e-02,
        -1.90901093e-03,  2.44263411e-02,  7.79864145e-03, -2.52184668e-03,
        -6.87671639e-03,  3.68663599e-03,  5.41868430e-05,  2.91476957e-03,
        -4.00473690e-03, -1.32534467e-03,  9.50045511e-03,  6.97401725e-03,
         1.05041859e-03, -1.72942970e-03,  1.67397149e-02,  1.84007604e-02,
         1.70802865e-02, -2.74509331e-03,  1.19104227e-02, -7.76856346e-03,
         3.96032445e-03,  1.08306114e-04, -2.09370907e-03, -5.75081864e-03,
         2.46251584e-03,  4.41297237e-03,  3.82803148e-03, -7.36459717e-03,
        -4.46036365e-03,  1.71174686e-02, -5.10382559e-03,  3.27243935e-03,
        -8.88920447e-04, -1.15249865e-03,  2.40680855e-02,  7.59835262e-03,
        -6.81193219e-03,  9.22085746e-05, -8.08919687e-03, -3.34995426e-03,
        -5.73148811e-03,  1.89565066e-02,  2.22169347e-02,  8.48335121e-03,
         3.62862530e-03, -7.96046108e-03, -8.44814070e-03, -8.45994428e-03,
         3.14374454e-03, -3.04440828e-03,  5.37578622e-03,  1.18150758e-02,
         1.89486868e-03, -3.66321346e-03,  5.59679279e-03,  9.31394898e-05],
 'dense/kernel:0': array([[-0.08031233,  0.03281811, -0.06192348, ...,  0.04006809,
         -0.06608431, -0.00679768],
        [-0.07204922,  0.0010613 ,  0.04190496, ..., -0.05824471,
         -0.08109791,  0.00590287],
        [-0.01371243,  0.05592427, -0.05328341, ..., -0.02204874,
         -0.05019723,  0.01568204],
        [-0.00936218,  0.05286681, -0.04807849, ..., -0.02278425,
          0.04561274,  0.03458098],
        [ 0.0703925 , -0.00397872,  0.01036303, ..., -0.07454848,
          0.03762647,  0.03301041],
        [ 0.07140217,  0.03024638,  0.06050078, ..., -0.00729908,
          0.03561355, -0.01380337]], dtype=float32),
 'dense_1/bias:0': array([-0.01057897,  0.02798155, -0.02721532, -0.01178902, -0.03261909,
         0.03824002, -0.01674107,  0.06479462, -0.00622304, -0.02584666],
 'dense_1/kernel:0': array([[ 0.20532946,  0.04895966,  0.04806477, ...,  0.1842476 ,
          0.04001875,  0.02829899],
        [-0.11564036,  0.01309861, -0.1085465 , ...,  0.03441835,
          0.18818673, -0.10964528],
        [ 0.16526467,  0.27772224, -0.13799857, ..., -0.17034456,
         -0.08155934, -0.18946697],
        [-0.16586629, -0.14551383, -0.13446577, ..., -0.23064311,
         -0.12630565, -0.0663368 ],
        [-0.13030651,  0.16436552,  0.16922349, ...,  0.0384584 ,
         -0.11609874,  0.09569876],
        [ 0.17911045,  0.12671752, -0.12468987, ...,  0.05676334,
         -0.07585928,  0.1164245 ]], dtype=float32)}

In Java, you can restore the serialized trained weights from the checkpoint file stored in internal storage. When the application restarts, the trained weights usually need to be restored before running inference.

try (Interpreter another_interpreter = new Interpreter(modelBuffer)) {
    // Load the trained weights from the checkpoint file.
    File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("checkpoint_path", outputFile.getAbsolutePath());
    Map<String, Object> outputs = new HashMap<>();
    another_interpreter.runSignature(inputs, outputs, "restore");
}

Run inference with the trained weights

Developers can use the trained model to run inference through the infer signature method.

# Run inference on the test images
result = infer(
    x=tf.constant(test_images, shape=(len(test_images), IMG_SIZE, IMG_SIZE), dtype=tf.float32))

# Pick the class with the highest score for each image.
predicted_labels = np.argmax(result["output"], axis=1)

# Display the first 25 test images in a 5x5 grid.
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.imshow(test_images[i], cmap=plt.cm.binary)
plt.show()


In Java, after restoring the trained weights, developers can run inference on the loaded data.

try (Interpreter another_interpreter = new Interpreter(modelBuffer)) {
    // Restore the weights from the checkpoint file.

    int NUM_TESTS = 10;
    float[][][] testImages = new float[NUM_TESTS][28][28];
    float[][] output = new float[NUM_TESTS][10];

    // Fill the test data.

    // Run the inference.
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("x", testImages);
    Map<String, Object> outputs = new HashMap<>();
    outputs.put("output", output);
    another_interpreter.runSignature(inputs, outputs, "infer");

    // Process the result to get the final category values.
    int[] testLabels = new int[NUM_TESTS];
    for (int i = 0; i < NUM_TESTS; ++i) {
        int index = 0;
        for (int j = 1; j < 10; ++j) {
            if (output[i][index] < output[i][j]) index = j;
        }
        testLabels[i] = index;
    }
}
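The manual maximum search in the inner Java loop is the same computation as NumPy's argmax over the class axis; a small sketch (with made-up scores) illustrates the equivalence:

```python
import numpy as np

# Hypothetical model outputs for 2 test images over 10 classes.
output = np.array([
    [0.10, 0.05, 0.60, 0.05, 0.05, 0.05, 0.02, 0.03, 0.03, 0.02],
    [0.01, 0.01, 0.01, 0.01, 0.90, 0.02, 0.01, 0.01, 0.01, 0.01],
])

# Equivalent of the inner Java loop: the index of the largest
# score in each row is the predicted class.
test_labels = np.argmax(output, axis=1)
# → array([2, 4])
```

In Java there is no built-in argmax for float arrays, which is why the tutorial walks the scores by hand; in Python the same result is a single vectorized call.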