Save and load a model using a distribution strategy



This tutorial demonstrates how you can save and load models in a SavedModel format with tf.distribute.Strategy during or after training. There are two kinds of APIs for saving and loading a Keras model: high-level (tf.keras.Model.save and tf.keras.models.load_model) and low-level (tf.saved_model.save and tf.saved_model.load).

To learn about SavedModel and serialization in general, please read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example.

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf

Load and prepare the data with TensorFlow Datasets and, and create the model using tf.distribute.MirroredStrategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets = tfds.load(name='mnist', as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model

Train the model with tf.keras.Model.fit:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)

Save and load the model

Now that you have a simple model to work with, let's explore the saving/loading APIs. There are two kinds of APIs available:

- High-level (Keras): Model.save and tf.keras.models.load_model (.keras zip archive format)
- Low-level: tf.saved_model.save and tf.saved_model.load (TF SavedModel format)

The Keras API

Here is an example of saving and loading a model with the Keras API:

keras_model_path = '/tmp/keras_save.keras'
model.save(keras_model_path)

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)

After restoring the model, you can continue training on it, even without needing to call Model.compile again, since it was already compiled before saving. The model is saved in the Keras zip archive format, marked by the .keras extension. For more information, please refer to the guide on Keras saving.

Now, restore the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)

As the output shows, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy used before saving.

The tf.saved_model API

Saving the model with the lower-level API is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)

Loading can be done with tf.saved_model.load. However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used for inference. For example:

DEFAULT_FUNCTION_KEY = 'serving_default'
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. The "serving_default" key is the default key for the inference function with a saved Keras model. To do inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))

You can also load and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    result = another_strategy.run(inference_func, args=(batch,))

Calling the restored function is just a forward pass on the saved model (tf.keras.Model.predict). What if you want to continue training the loaded function? Or what if you need to embed the loaded function into a bigger model? A common practice is to wrap this loaded object into a Keras layer to achieve this. Luckily, TF Hub has hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)

In the above example, TensorFlow Hub's hub.KerasLayer wraps the result loaded back from tf.saved_model.load into a Keras layer that is used to build another model. This is very useful for transfer learning.

Which API should I use?

For saving, if you are working with a Keras model, use the Keras API unless you need the additional control allowed by the low-level API. If what you are saving is not a Keras model, then the lower-level API, tf.saved_model.save, is your only choice.

For loading, your API choice depends on what you want to get from the model loading API. If you cannot (or do not want to) get a Keras model, then use tf.saved_model.load. Otherwise, use tf.keras.models.load_model. Note that you can get a Keras model back only if you saved a Keras model.
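The decision above can be sketched as a small helper that tries the high-level loader first and falls back to the low-level one. This is a hedged sketch: the name load_any is our own, not a TensorFlow API, and the exception types caught are an assumption about what tf.keras.models.load_model raises for non-Keras artifacts.

```python
import tensorflow as tf

def load_any(path):
  """Try to get a Keras model back; fall back to a raw SavedModel object."""
  try:
    return tf.keras.models.load_model(path)
  except (ValueError, OSError):
    # Not loadable as a Keras model: use the low-level API instead.
    return tf.saved_model.load(path)
```

In practice you usually know which format you saved, so preferring the explicit call over a fallback like this keeps failure modes clearer.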

It is possible to mix and match the APIs. You can save a Keras model with Model.save, and load a non-Keras model with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras `Model.save`
model.save(saved_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)

Saving/Loading from a local device

When saving and loading from a local I/O device while training on remote devices—for example, when using a Cloud TPU—you must use the option experimental_io_device in tf.saved_model.SaveOptions and tf.saved_model.LoadOptions to set the I/O device to localhost. For example:

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)


One special case is when you create Keras models in certain ways, and then save them before training. For example:

class SubclassedModel(tf.keras.Model):
  """Example model defined by subclassing `tf.keras.Model`."""

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
try:
  my_model.save(saved_model_path)
except ValueError as e:
  print(f'{type(e).__name__}: ', *e.args)

A SavedModel saves the tf.types.experimental.ConcreteFunction objects generated when you trace a tf.function (check When is a Function tracing? in the Introduction to graphs and tf.function guide to learn more). If you get a ValueError like this, it's because Model.save was not able to find or create a traced ConcreteFunction.

Caution: You shouldn't save a model without at least one ConcreteFunction, since the low-level API will otherwise generate a SavedModel with no ConcreteFunction signatures:

tf.saved_model.save(my_model, saved_model_path)
x = tf.saved_model.load(saved_model_path)
x.signatures

Usually the model's forward pass—the call method—will be traced automatically when the model is called for the first time, often via the Keras Model.fit method. A ConcreteFunction can also be generated by the Keras Sequential and Functional APIs, if you set the input shape, for example, by making the first layer either a tf.keras.layers.InputLayer or another layer type, and passing it the input_shape keyword argument.
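The tracing mechanism itself can be seen with a plain tf.Module (a sketch of our own, not from the tutorial): giving tf.function an input_signature produces a ConcreteFunction without the function ever being called, so low-level saving succeeds before any training.

```python
import tensorflow as tf

class Doubler(tf.Module):
  # The input_signature lets TensorFlow trace a ConcreteFunction up front.
  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def __call__(self, x):
    return 2.0 * x

module = Doubler()
tf.saved_model.save(module, '/tmp/doubler')  # works without a prior call
loaded = tf.saved_model.load('/tmp/doubler')
print(loaded(tf.constant([1.0, 2.0])).numpy())  # -> [2. 4.]
```

Keras does the equivalent for you when it knows the input shape, which is why setting it up front makes untrained models saveable.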

To verify if your model has any traced ConcreteFunctions, check if Model.save_spec is None:

print(my_model.save_spec() is None)

Let's use Model.fit to train the model, and notice that the save_spec gets defined and model saving will work:

BATCH_SIZE_PER_REPLICA = 2
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))
    ).repeat(dataset_size).batch(BATCH_SIZE)

my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)

print(my_model.save_spec() is None)