On-Device Training with TensorFlow Lite


When deploying a TensorFlow Lite machine learning model to a device or mobile app, you may want to enable the model to be improved or personalized based on input from the device or end user. Using on-device training techniques allows you to update a model without data leaving your users' devices, improving user privacy, and without requiring users to update the device software.

For example, you may have a model in your mobile app that recognizes fashion items, but you want users to get improved recognition performance over time based on their interests. With on-device training, a user who is interested in shoes can see the model get better at recognizing a particular style or brand of shoe the more often they use your app.

This tutorial shows you how to construct a TensorFlow Lite model that can be incrementally trained and improved within an installed Android app.


This tutorial uses Python to train and convert a TensorFlow model before incorporating it into an Android app. Get started by installing and importing the following packages.

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

print("TensorFlow version:", tf.__version__)
TensorFlow version: 2.8.0

Classify images of clothing

This example code uses the Fashion MNIST dataset to train a neural network model for classifying images of clothing. The dataset contains 70,000 small (28 x 28 pixel) grayscale images, split into 60,000 training images and 10,000 test images. The images cover 10 categories of fashion items, including dresses, shirts, and sandals.

Figure 1: Fashion-MNIST samples (by Zalando, MIT License).

You can explore this dataset in more depth in the Keras classification tutorial.

Build a model for on-device training

TensorFlow Lite models typically have only a single exposed function method (or signature) that allows you to call the model to run an inference. For a model to be trained and used on a device, you must be able to perform several separate operations on it: training, inference, and saving and restoring its weights. You can enable this functionality by first extending your TensorFlow model to have multiple functions, and then exposing those functions as signatures when you convert your model to the TensorFlow Lite model format.

The code example below shows you how to add the following functions to a TensorFlow model:

  • train function trains the model with training data.
  • infer function invokes the inference.
  • save function saves the trainable weights into the file system.
  • restore function loads the trainable weights from the file system.

IMG_SIZE = 28

class Model(tf.Module):

  def __init__(self):
    self.model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(IMG_SIZE, IMG_SIZE), name='flatten'),
        tf.keras.layers.Dense(128, activation='relu', name='dense_1'),
        tf.keras.layers.Dense(10, name='dense_2')
    ])

    # The last layer outputs logits, so the loss is told to apply softmax itself.
    self.model.compile(
        optimizer='sgd',
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True))

  # The `train` function takes a batch of input images and labels.
  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
      tf.TensorSpec([None, 10], tf.float32),
  ])
  def train(self, x, y):
    with tf.GradientTape() as tape:
      prediction = self.model(x)
      loss = self.model.loss(y, prediction)
    gradients = tape.gradient(loss, self.model.trainable_variables)
    self.model.optimizer.apply_gradients(
        zip(gradients, self.model.trainable_variables))
    result = {"loss": loss}
    return result

  # The `infer` function returns class probabilities and the raw logits.
  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
  ])
  def infer(self, x):
    logits = self.model(x)
    probabilities = tf.nn.softmax(logits, axis=-1)
    return {
        "output": probabilities,
        "logits": logits
    }

  # The `save` function writes every weight to a checkpoint file.
  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def save(self, checkpoint_path):
    tensor_names = [weight.name for weight in self.model.weights]
    tensors_to_save = [weight.read_value() for weight in self.model.weights]
    tf.raw_ops.Save(
        filename=checkpoint_path, tensor_names=tensor_names,
        data=tensors_to_save, name='save')
    return {
        "checkpoint_path": checkpoint_path
    }

  # The `restore` function reads each weight back and assigns it to the model.
  @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
  def restore(self, checkpoint_path):
    restored_tensors = {}
    for var in self.model.weights:
      restored = tf.raw_ops.Restore(
          file_pattern=checkpoint_path, tensor_name=var.name, dt=var.dtype,
          name='restore')
      # Without this assignment the model would keep its current weights.
      var.assign(restored)
      restored_tensors[var.name] = restored
    return restored_tensors

The train function in the code above uses the GradientTape class to record operations for automatic differentiation. For more information on how to use this class, see the Introduction to gradients and automatic differentiation.
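The sketch below isolates a single gradient step on one scalar weight, to make the train function easier to follow. It assumes a toy model y = w * x with a squared-error loss and made-up values. It also replaces the Keras optimizer with a hand-written gradient-descent update (w.assign_sub(learning_rate * grad)) so the arithmetic stays visible. The names w, x, y_true, and learning_rate are illustrative only.

```python
import tensorflow as tf

# Toy model y = w * x with a single trainable weight (illustrative values).
w = tf.Variable(1.0)
x = tf.constant(2.0)
y_true = tf.constant(6.0)
learning_rate = 0.1

# The tape records the forward pass so it can be differentiated afterwards.
with tf.GradientTape() as tape:
  loss = (w * x - y_true) ** 2  # (1*2 - 6)^2 = 16

# d(loss)/dw = 2 * (w*x - y_true) * x = 2 * (-4) * 2 = -16
grad = tape.gradient(loss, w)

# Plain gradient descent, standing in for optimizer.apply_gradients.
w.assign_sub(learning_rate * grad)  # w = 1 - 0.1 * (-16) = 2.6

print(float(loss), float(grad), float(w))
```

In the train function, the Keras optimizer performs the equivalent of that last line for every trainable variable at once.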

You could use the Model.train_step method of the Keras model here instead of a from-scratch implementation. Note, however, that the loss (and metrics) returned by Model.train_step is a running average, and should be reset regularly (typically each epoch). See Customize Model.fit for details.
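The next sketch shows why a running average needs resetting. It feeds three made-up batch losses to tf.keras.metrics.Mean. This illustrates a running mean in general; it is not a description of how Model.train_step is implemented internally.

```python
import tensorflow as tf

# A running mean over batches; the batch losses are made-up values.
running_loss = tf.keras.metrics.Mean()
for batch_loss in [0.9, 0.7, 0.5]:
  running_loss.update_state(batch_loss)

# The result averages every batch seen so far, not just the latest (0.5).
average = float(running_loss.result())
print(f"running average: {average:.2f}")

# Resetting (for example at the start of each epoch) drops the old batches.
running_loss.reset_state()
after_reset = float(running_loss.result())
```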

Prepare the data

Get the Fashion MNIST dataset for training your model.

fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

Preprocess the dataset

Pixel values in this dataset are between 0 and 255, and must be normalized to a value between 0 and 1 for processing by the model. Divide the values by 255 to make this adjustment.

train_images = (train_images / 255.0).astype(np.float32)
test_images = (test_images / 255.0).astype(np.float32)
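The astype(np.float32) call matters. Dividing a uint8 array by 255.0 produces a float64 array in NumPy, while the model's signatures declare tf.float32 inputs. Here is a small check on three made-up pixel values:

```python
import numpy as np

# Made-up pixel values in the raw uint8 range [0, 255].
raw = np.array([0, 51, 255], dtype=np.uint8)

scaled = raw / 255.0                     # NumPy promotes the result to float64
normalized = scaled.astype(np.float32)   # matches the tf.float32 TensorSpec

print(scaled.dtype, normalized.dtype, normalized)
```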

Convert the data labels to categorical values by performing one-hot encoding.

train_labels = tf.keras.utils.to_categorical(train_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)
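One-hot encoding turns each integer class label into a length-10 vector with a single 1.0 at the label's index. This matches the [None, 10] float32 shape that the train signature declares for y. The sketch below performs the same transformation in NumPy with an identity-matrix lookup instead of to_categorical, on a few made-up labels:

```python
import numpy as np

NUM_CLASSES = 10
labels = np.array([9, 0, 3])  # made-up integer class labels

# Row k of the identity matrix is the one-hot vector for class k.
one_hot = np.eye(NUM_CLASSES, dtype=np.float32)[labels]
print(one_hot.shape)  # one row per label, one column per class
print(one_hot[2])     # label 3: 1.0 at index 3, zeros elsewhere

# np.argmax reverses the encoding, as the tutorial does later for test_labels.
recovered = np.argmax(one_hot, axis=1)
```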

Train the model

Before converting and setting up your TensorFlow Lite model, complete the initial training of your model using the preprocessed dataset and the train signature method. The following code runs model training for 100 epochs, processing batches of 100 images at a time, and displaying the loss value after every 10 epochs. Since this training run is processing quite a bit of data, it may take a few minutes to finish.

NUM_EPOCHS = 100
BATCH_SIZE = 100
epochs = np.arange(1, NUM_EPOCHS + 1, 1)
losses = np.zeros([NUM_EPOCHS])
m = Model()

train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_ds = train_ds.batch(BATCH_SIZE)

for i in range(NUM_EPOCHS):
  for x,y in train_ds:
    result = m.train(x, y)

  # Record the loss of the last batch in this epoch.
  losses[i] = result['loss']
  if (i + 1) % 10 == 0:
    print(f"Finished {i+1} epochs")
    print(f"  loss: {losses[i]:.3f}")

# Save the trained weights to a checkpoint.
m.save('/tmp/model.ckpt')
Finished 10 epochs
  loss: 0.428
Finished 20 epochs
  loss: 0.378
Finished 30 epochs
  loss: 0.344
Finished 40 epochs
  loss: 0.317
Finished 50 epochs
  loss: 0.299
Finished 60 epochs
  loss: 0.283
Finished 70 epochs
  loss: 0.266
Finished 80 epochs
  loss: 0.252
Finished 90 epochs
  loss: 0.240
Finished 100 epochs
  loss: 0.230
{'checkpoint_path': <tf.Tensor: shape=(), dtype=string, numpy=b'/tmp/model.ckpt'>}
plt.plot(epochs, losses, label='Pre-training')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch')
plt.ylabel('Loss [Cross Entropy]')
plt.legend()
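The running time follows from the batch arithmetic. With 60,000 training images and batches of 100, each epoch makes 600 calls to train, so 100 epochs make 60,000 calls. Note also that losses[i] stores only the loss of the last batch in each epoch, not an epoch average. The counting, using the same constants:

```python
import math

NUM_TRAIN_IMAGES = 60_000  # size of the Fashion MNIST training split
BATCH_SIZE = 100
NUM_EPOCHS = 100

# Batches per epoch; ceil accounts for a final partial batch (none here).
steps_per_epoch = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)
total_train_calls = steps_per_epoch * NUM_EPOCHS

print(f"{steps_per_epoch} train calls per epoch, {total_train_calls} in total")
```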


Convert model to TensorFlow Lite format

After you have extended your TensorFlow model to enable additional functions for on-device training and completed initial training of the model, you can convert it to TensorFlow Lite format. The following code first saves your model as a SavedModel that includes the set of signatures you use with the TensorFlow Lite model on a device (train, infer, save, restore), and then converts that SavedModel to the TensorFlow Lite format.

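SAVED_MODEL_DIR = "saved_model"

# Export the model with all four functions exposed as named signatures.
tf.saved_model.save(m, SAVED_MODEL_DIR, signatures={
    'train': m.train.get_concrete_function(),
    'infer': m.infer.get_concrete_function(),
    'save': m.save.get_concrete_function(),
    'restore': m.restore.get_concrete_function(),
})

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()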

Set up the TensorFlow Lite signatures

The TensorFlow Lite model you converted in the previous step contains several function signatures. You can access them through the tf.lite.Interpreter class and invoke each restore, train, save, and infer signature separately.

interpreter = tf.lite.Interpreter(model_content=tflite_model)

infer = interpreter.get_signature_runner("infer")
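To check which signatures a converted model exposes, you can call the interpreter's get_signature_list() method. The self-contained sketch below defines a hypothetical two-function module (TwoOps, with signatures named double and negate, invented for illustration). It exports and converts the module the same way as the tutorial model, then lists and calls its signatures.

```python
import tempfile

import numpy as np
import tensorflow as tf


class TwoOps(tf.Module):
  """Hypothetical module with two stateless functions."""

  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def double(self, x):
    return {"y": 2.0 * x}

  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def negate(self, x):
    return {"y": -x}


module = TwoOps()
export_dir = tempfile.mkdtemp()
# Each entry in `signatures` becomes a named signature in the SavedModel.
tf.saved_model.save(module, export_dir, signatures={
    "double": module.double.get_concrete_function(),
    "negate": module.negate.get_concrete_function(),
})

converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
toy_interpreter = tf.lite.Interpreter(model_content=converter.convert())

# Maps each signature key to its input and output names.
signature_list = toy_interpreter.get_signature_list()
print(signature_list)

double_runner = toy_interpreter.get_signature_runner("double")
doubled = double_runner(x=np.array([1.0, 2.5], dtype=np.float32))["y"]
```

For the tutorial model, interpreter.get_signature_list() should report the train, infer, save, and restore signatures in the same way.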

Compare the output of the original model and the converted TensorFlow Lite model:

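logits_original = m.infer(x=train_images[:1])['logits'][0].numpy()
logits_lite = infer(x=train_images[:1])['logits'][0]

# The two sets of logits should agree up to floating-point rounding.
print("Total difference:", np.sum(np.abs(logits_original - logits_lite)))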

Above, you can see that the behavior of the model is not changed by the conversion to TFLite.

Retrain the model on a device

After converting your model to TensorFlow Lite and deploying it with your app, you can retrain the model on a device using new data and the train signature method of your model. Each training run generates a new set of weights that you can save for re-use and further improvement of the model, as shown in the next section.

On Android, you can perform on-device training with TensorFlow Lite using either Java or C++ APIs. In Java, use the Interpreter class to load a model and drive model training tasks. The following example shows how to run the training procedure using the runSignature method:

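import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.tensorflow.lite.Interpreter;

// The following runs inside a method; modelBuffer is assumed to hold the converted .tflite model.
try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    int NUM_EPOCHS = 100;
    int BATCH_SIZE = 100;
    int IMG_HEIGHT = 28;
    int IMG_WIDTH = 28;
    int NUM_TRAININGS = 60000;
    int NUM_BATCHES = NUM_TRAININGS / BATCH_SIZE;

    List<FloatBuffer> trainImageBatches = new ArrayList<>(NUM_BATCHES);
    List<FloatBuffer> trainLabelBatches = new ArrayList<>(NUM_BATCHES);

    // Prepare training batches. Direct buffers are sized in bytes (4 per float).
    for (int i = 0; i < NUM_BATCHES; ++i) {
        FloatBuffer trainImages = ByteBuffer.allocateDirect(BATCH_SIZE * IMG_HEIGHT * IMG_WIDTH * 4)
                .order(ByteOrder.nativeOrder()).asFloatBuffer();
        FloatBuffer trainLabels = ByteBuffer.allocateDirect(BATCH_SIZE * 10 * 4)
                .order(ByteOrder.nativeOrder()).asFloatBuffer();

        // Fill the data values...

        // Reset the read position before handing the buffers to the interpreter.
        trainImages.rewind();
        trainLabels.rewind();
        trainImageBatches.add(trainImages);
        trainLabelBatches.add(trainLabels);
    }

    // Run training for NUM_EPOCHS epochs.
    float[] losses = new float[NUM_EPOCHS];
    for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
        for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
            Map<String, Object> inputs = new HashMap<>();
            inputs.put("x", trainImageBatches.get(batchIdx));
            inputs.put("y", trainLabelBatches.get(batchIdx));

            Map<String, Object> outputs = new HashMap<>();
            FloatBuffer loss = FloatBuffer.allocate(1);
            outputs.put("loss", loss);

            interpreter.runSignature(inputs, outputs, "train");

            // Record the last loss.
            if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
        }

        // Print the loss output for every 10 epochs.
        if ((epoch + 1) % 10 == 0) {
            System.out.println(
                "Finished " + (epoch + 1) + " epochs, current loss: " + losses[epoch]);
        }
    }

    // ...
}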

You can see a complete code example of model retraining inside an Android app in the model personalization demo app.
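The Java example leaves the "Fill the data values..." step open. However the buffers are filled, each one holds a whole batch as a flat array. Assuming the row-major layout TensorFlow Lite uses for tensors, pixel (row, col) of image i sits at index i * 28 * 28 + row * 28 + col. The one-hot label of image i occupies the 10 values starting at i * 10. The NumPy sketch below checks this index arithmetic on a random stand-in batch (it does not load real data). It also shows that a batch of 100 images takes 100 * 28 * 28 * 4 bytes as raw float32 data.

```python
import numpy as np

BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES = 100, 28, 28, 10
rng = np.random.default_rng(0)

# Stand-in batch: random pixel values in [0, 1) and random one-hot labels.
images = rng.random((BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH), dtype=np.float32)
class_ids = rng.integers(0, NUM_CLASSES, size=BATCH_SIZE)
labels = np.eye(NUM_CLASSES, dtype=np.float32)[class_ids]

# Flatten in row-major (C) order, assumed to match the FloatBuffer layout.
flat_images = images.reshape(-1)
flat_labels = labels.reshape(-1)

# Pixel (row, col) of image i sits at i * H * W + row * W + col.
i, row, col = 7, 3, 20
index = i * IMG_HEIGHT * IMG_WIDTH + row * IMG_WIDTH + col

# Raw bytes in native byte order: 4 bytes per float32 value.
image_bytes = flat_images.tobytes()
print(flat_images.size, flat_labels.size, len(image_bytes))
```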

Run training for a few epochs to improve or personalize the model. In practice, you would run this additional training using data collected on the device. For simplicity, this example uses the same training data as the previous training step.

train = interpreter.get_signature_runner("train")

NUM_EPOCHS = 50
more_epochs = np.arange(epochs[-1]+1, epochs[-1] + NUM_EPOCHS + 1, 1)
more_losses = np.zeros([NUM_EPOCHS])

for i in range(NUM_EPOCHS):
  for x,y in train_ds:
    result = train(x=x, y=y)
  more_losses[i] = result['loss']
  if (i + 1) % 10 == 0:
    print(f"Finished {i+1} epochs")
    print(f"  loss: {more_losses[i]:.3f}")
Finished 10 epochs
  loss: 0.223
Finished 20 epochs
  loss: 0.216
Finished 30 epochs
  loss: 0.210
Finished 40 epochs
  loss: 0.204
Finished 50 epochs
  loss: 0.198
plt.plot(epochs, losses, label='Pre-training')
plt.plot(more_epochs, more_losses, label='On device')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch')
plt.ylabel('Loss [Cross Entropy]')
plt.legend()


Above you can see that the on-device training picks up exactly where the pretraining stopped.

Save the trained weights

When you complete a training run on a device, the model updates the set of weights it is using in memory. Using the save signature method you created in your TensorFlow Lite model, you can save these weights to a checkpoint file for later reuse and further improvement of your model.

save = interpreter.get_signature_runner("save")

save(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.bytes_))
{'checkpoint_path': array(b'/tmp/model.ckpt', dtype=object)}
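The save and restore signatures declare their argument as tf.TensorSpec(shape=[], dtype=tf.string), a scalar string. Wrapping the path in np.array with a bytes dtype produces a matching 0-dimensional bytes array (np.bytes_ is the same type that older NumPy versions also exposed as np.string_). Here is a quick look at what that call produces, compared with a one-element list:

```python
import numpy as np

# A 0-dimensional bytes array: shape (), matching the scalar shape [] in the signature.
path = np.array("/tmp/model.ckpt", dtype=np.bytes_)
print(path.shape, path.dtype, path.item())

# By contrast, a one-element list gives shape (1,), not a scalar.
wrong_shape = np.array(["/tmp/model.ckpt"], dtype=np.bytes_)
```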

In your Android application, you can store the generated weights as a checkpoint file in the internal storage space allocated for your app.

try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    // Conduct the training jobs.

    // Export the trained weights as a checkpoint file.
    File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("checkpoint_path", outputFile.getAbsolutePath());
    Map<String, Object> outputs = new HashMap<>();
    interpreter.runSignature(inputs, outputs, "save");
}

Restore the trained weights

Any time you create an interpreter from a TFLite model, the interpreter will initially load the original model weights.

So after you've done some training and saved a checkpoint file, you'll need to run the restore signature method to load the checkpoint.

A good rule is "Anytime you create an Interpreter for a model, if the checkpoint exists, load it". If you need to reset the model to the baseline behavior, just delete the checkpoint and create a fresh interpreter.
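Restoring has two parts: reading the stored tensors and writing them back into the model. tf.raw_ops.Restore only returns the stored tensor, so a restore function has to call assign on each variable with the restored value before the model's behavior changes. Below is a minimal eager-mode sketch of the Save/Restore roundtrip on one toy variable. The name toy_weight and the temporary path are illustrative.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

ckpt_path = os.path.join(tempfile.mkdtemp(), "toy.ckpt")
toy_weight = tf.Variable([1.0, 2.0], name="toy_weight")

# Write the current value under the tensor name "toy_weight".
tf.raw_ops.Save(filename=ckpt_path, tensor_names=["toy_weight"],
                data=[toy_weight.read_value()])

# Simulate a fresh start: the variable no longer holds the saved values.
toy_weight.assign([0.0, 0.0])

# Restore returns the saved tensor but leaves the variable untouched...
restored = tf.raw_ops.Restore(file_pattern=ckpt_path,
                              tensor_name="toy_weight", dt=tf.float32)
before_assign = toy_weight.numpy()

# ...until the value is assigned back explicitly.
toy_weight.assign(restored)
after_assign = toy_weight.numpy()
```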

another_interpreter = tf.lite.Interpreter(model_content=tflite_model)

infer = another_interpreter.get_signature_runner("infer")
restore = another_interpreter.get_signature_runner("restore")
logits_before = infer(x=train_images[:1])['logits'][0]

# Restore the trained weights from /tmp/model.ckpt
restore(checkpoint_path=np.array("/tmp/model.ckpt", dtype=np.bytes_))

logits_after = infer(x=train_images[:1])['logits'][0]

# Applying the checkpoint should change the logits.
print("Total difference:", np.sum(np.abs(logits_before - logits_after)))


The checkpoint was generated by training and saving with TFLite. Above you can see that applying the checkpoint updates the behavior of the model.

In your Android app, you can restore the serialized, trained weights from the checkpoint file you stored earlier.

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
    // Load the trained weights from the checkpoint file.
    File outputFile = new File(getFilesDir(), "checkpoint.ckpt");
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("checkpoint_path", outputFile.getAbsolutePath());
    Map<String, Object> outputs = new HashMap<>();
    anotherInterpreter.runSignature(inputs, outputs, "restore");
}

Run inference using trained weights

Once you have loaded previously saved weights from a checkpoint file, the infer method runs your original model architecture with those weights, so its predictions reflect the additional training. After loading the saved weights, you can use the infer signature method as shown below.

infer = another_interpreter.get_signature_runner("infer")
result = infer(x=test_images)
predictions = np.argmax(result["output"], axis=1)

true_labels = np.argmax(test_labels, axis=1)
print(result["output"].shape)
(10000, 10)

Plot the predicted labels.

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

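def plot(images, predictions, true_labels):
  # Show the first 25 images in a 5x5 grid, labeling each with its predicted
  # class: blue when the prediction is correct, red when it is wrong.
  plt.figure(figsize=(10, 10))
  for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(images[i], cmap=plt.cm.binary)
    color = 'b' if predictions[i] == true_labels[i] else 'r'
    plt.xlabel(class_names[predictions[i]], color=color)
  plt.show()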

plot(test_images, predictions, true_labels)
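The plot shows only 25 images. To summarize all predictions in one number, you can compare predicted and true class indices and take the mean. The sketch below does this on small made-up logits rather than the real test set. It also shows that the argmax of the softmax probabilities equals the argmax of the raw logits, because softmax preserves the ordering of the scores.

```python
import numpy as np

def softmax(logits):
  # Subtract the row maximum for numerical stability; the result is unchanged.
  shifted = logits - logits.max(axis=1, keepdims=True)
  exp = np.exp(shifted)
  return exp / exp.sum(axis=1, keepdims=True)

# Made-up logits for 4 images over 3 classes, and their true labels.
toy_logits = np.array([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0],
                       [1.0, 4.0, 0.0],
                       [0.0, 0.0, 1.0]])
toy_labels = np.array([0, 2, 0, 2])

toy_probabilities = softmax(toy_logits)
toy_predictions = np.argmax(toy_probabilities, axis=1)
toy_accuracy = np.mean(toy_predictions == toy_labels)
print(f"accuracy: {toy_accuracy:.2f}")
```

With the tutorial's arrays, the same idea is np.mean(predictions == true_labels).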



In your Android application, after restoring the trained weights, run inference on the loaded data.

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
    // Restore the weights from the checkpoint file.

    int NUM_TESTS = 10;
    // Direct buffers are sized in bytes: 4 bytes per float.
    FloatBuffer testImages = ByteBuffer.allocateDirect(NUM_TESTS * 28 * 28 * 4)
            .order(ByteOrder.nativeOrder()).asFloatBuffer();
    FloatBuffer output = ByteBuffer.allocateDirect(NUM_TESTS * 10 * 4)
            .order(ByteOrder.nativeOrder()).asFloatBuffer();

    // Fill the test data.

    // Run the inference.
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("x", testImages.rewind());
    Map<String, Object> outputs = new HashMap<>();
    outputs.put("output", output);
    anotherInterpreter.runSignature(inputs, outputs, "infer");

    // Process the result to get the final category values (argmax per test image).
    int[] testLabels = new int[NUM_TESTS];
    for (int i = 0; i < NUM_TESTS; ++i) {
        int index = 0;
        for (int j = 1; j < 10; ++j) {
            if (output.get(i * 10 + index) < output.get(i * 10 + j)) index = j;
        }
        testLabels[i] = index;
    }
}
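The loop at the end of the Java example is meant to compute an argmax over each group of 10 scores in the flat output buffer. The sketch below ports that loop to Python and checks it against NumPy's argmax on random stand-in scores, which makes the index arithmetic easy to verify. The function name flat_argmax is hypothetical.

```python
import numpy as np

NUM_CLASSES = 10

def flat_argmax(output, num_tests):
  """Mirror of the Java loop: best class index for each test in a flat buffer."""
  test_labels = [0] * num_tests
  for i in range(num_tests):
    index = 0
    for j in range(1, NUM_CLASSES):
      # Scores for test i occupy output[i * 10] through output[i * 10 + 9].
      if output[i * NUM_CLASSES + index] < output[i * NUM_CLASSES + j]:
        index = j
    test_labels[i] = index
  return test_labels

rng = np.random.default_rng(1)
scores = rng.random(5 * NUM_CLASSES)  # stand-in for the flat "output" buffer
labels_from_loop = flat_argmax(scores, 5)
labels_from_numpy = np.argmax(scores.reshape(5, NUM_CLASSES), axis=1)
```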

Congratulations! You have now built a TensorFlow Lite model that supports on-device training. For more coding details, check out the example implementation in the model personalization demo app.

If you are interested in learning more about image classification, check out the Keras classification tutorial in the official TensorFlow guide. This tutorial is based on that exercise, which covers classification in more depth.