
Save and load a model using a distribution strategy


Overview

This tutorial demonstrates how you can save and load models in the SavedModel format with tf.distribute.Strategy during or after training. There are two kinds of APIs for saving and loading a Keras model: the high-level (tf.keras.Model.save and tf.keras.models.load_model) and the low-level (tf.saved_model.save and tf.saved_model.load).

To learn about SavedModel and serialization in general, please read the saved model guide and the Keras model serialization guide. Let's start with a simple example.

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf

Load and prepare the data with TensorFlow Datasets and tf.data, and create the model using tf.distribute.MirroredStrategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets = tfds.load(name='mnist', as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset
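The input pipeline above normalizes pixel values and uses a global batch size that scales with the number of replicas. A plain-Python sketch of that arithmetic, assuming a hypothetical setup with 4 replicas in sync:

```python
# Each replica processes BATCH_SIZE_PER_REPLICA examples per step, so the
# global batch fed to Model.fit is the per-replica size times replica count.
BATCH_SIZE_PER_REPLICA = 64
num_replicas_in_sync = 4  # hypothetical; MirroredStrategy reports the real value
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * num_replicas_in_sync

# scale() maps uint8 pixel intensities in [0, 255] to floats in [0.0, 1.0].
def scale_pixel(value):
  return value / 255

print(GLOBAL_BATCH_SIZE)  # 256
print(scale_pixel(255))   # 1.0
```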

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Train the model with tf.keras.Model.fit:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
2022-04-22 06:50:47.309435: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
938/938 [==============================] - 9s 5ms/step - loss: 0.2031 - sparse_categorical_accuracy: 0.9420
Epoch 2/2
938/938 [==============================] - 4s 4ms/step - loss: 0.0676 - sparse_categorical_accuracy: 0.9798
<keras.callbacks.History at 0x7f952c00fe50>

Save and load the model

Now that you have a simple model to work with, let's explore the saving/loading APIs. There are two kinds of APIs available:

The Keras API

Here is an example of saving and loading a model with the Keras API:

keras_model_path = '/tmp/keras_save'
model.save(keras_model_path)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0485 - sparse_categorical_accuracy: 0.9849
Epoch 2/2
938/938 [==============================] - 3s 3ms/step - loss: 0.0344 - sparse_categorical_accuracy: 0.9893
<keras.callbacks.History at 0x7f952c00bbd0>

After restoring the model, you can continue training on it, even without needing to call Model.compile again, since it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, please refer to the guide to SavedModel format.

Now, restore the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
2022-04-22 06:51:07.485951: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2022-04-22 06:51:07.533281: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
938/938 [==============================] - 10s 11ms/step - loss: 0.0485 - sparse_categorical_accuracy: 0.9855
Epoch 2/2
938/938 [==============================] - 9s 10ms/step - loss: 0.0338 - sparse_categorical_accuracy: 0.9898

As the Model.fit output shows, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy used before saving.

The tf.saved_model API

Saving the model with the lower-level API is similar to the Keras API:

model = get_model()  # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load. However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:

DEFAULT_FUNCTION_KEY = 'serving_default'
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. For a saved Keras model, "serving_default" is the default key for the inference function. To do inference with this function:
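To see how signature keys behave end to end, here is a hedged, self-contained sketch (the Doubler module and the /tmp path are illustrative, not part of the tutorial's model): save a trivial tf.Module with an explicit signature, reload it, and look the function up by key.

```python
import tensorflow as tf

# A minimal module whose call is traced with a fixed input signature.
class Doubler(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def __call__(self, x):
    return 2.0 * x

module = Doubler()
demo_path = '/tmp/signature_demo'
# Passing a single concrete function saves it under the default
# 'serving_default' key.
tf.saved_model.save(module, demo_path,
                    signatures=module.__call__.get_concrete_function())

loaded_demo = tf.saved_model.load(demo_path)
print(list(loaded_demo.signatures.keys()))  # ['serving_default']
inference = loaded_demo.signatures['serving_default']
# Signature functions return a dict of named output tensors.
result = inference(tf.constant([1.0, 2.0]))
```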

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-1.66829824e-01, -6.57023340e-02,  1.00433722e-01,
         2.64007002e-01,  1.64257541e-01,  4.53347489e-02,
         4.71319817e-02, -5.05662560e-02,  7.54343793e-02,
         6.02332428e-02],
       ...,
       [-5.60427830e-02, -8.99625197e-03,  5.30303754e-02,
         1.72452822e-01,  1.58994213e-01,  2.95356996e-02,
        -2.59171817e-02, -1.37496695e-01,  7.89206550e-02,
         1.59071907e-02]], dtype=float32)>}
2022-04-22 06:51:28.467609: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

You can also load and do inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    result = another_strategy.run(inference_func, args=(batch,))
    print(result)
    break
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
2022-04-22 06:51:28.825714: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-1.66829824e-01, -6.57023340e-02,  1.00433722e-01,
         2.64007002e-01,  1.64257541e-01,  4.53347489e-02,
         4.71319817e-02, -5.05662560e-02,  7.54343793e-02,
         6.02332428e-02],
       [ 9.63039920e-02, -2.26380471e-02,  5.72972409e-02,
         2.80538891e-02,  2.00887904e-01, -5.30311391e-02,
         5.42474091e-02,  7.58884549e-02,  1.05634719e-01,
         7.72519186e-02],
       [ 1.77751277e-02,  7.26914182e-02, -1.18369143e-02,
         1.61694095e-01,  2.14668870e-01,  2.57224515e-02,
         1.70707107e-02, -1.21781714e-01,  1.54950470e-02,
         1.84745602e-02],
       [-6.13508523e-02,  1.84844248e-02,  6.67515621e-02,
         2.06116021e-01,  8.20632130e-02, -3.33720110e-02,
         1.61656104e-02, -1.08118489e-01,  1.02530338e-01,
        -6.31780773e-02],
       [-1.79943293e-02,  9.49065685e-02,  1.32015824e-01,
         2.03341365e-01,  1.74627632e-01, -2.10887715e-02,
        -4.94360551e-03, -6.89178705e-02,  8.37352127e-04,
        -3.13487388e-02],
       [-3.01171113e-02,  1.00814000e-01,  6.38750345e-02,
         2.31935561e-01,  2.14272052e-01,  1.15526393e-02,
         3.29931080e-02, -3.20239216e-02,  1.38624549e-01,
         3.15144658e-04],
       [ 8.33273977e-02,  1.49302296e-02,  6.27779365e-02,
         1.21633857e-01,  2.49505237e-01, -3.35178524e-03,
         6.13612197e-02,  1.41675472e-02,  8.82375091e-02,
         2.18972955e-02],
       [-1.37796290e-02,  5.32755479e-02,  7.28620868e-03,
         1.79180741e-01,  1.15639716e-01, -9.98490304e-03,
         2.81114317e-02, -5.76789007e-02,  8.22245702e-02,
        -5.04975729e-02],
       [-5.23501113e-02,  4.50184867e-02, -1.05090141e-01,
         2.56613493e-01,  2.54109889e-01,  1.20403230e-01,
         3.07827052e-02, -9.72324982e-02,  1.41814634e-01,
         2.62806714e-02],
       [-8.42924491e-02, -3.62931937e-02,  6.19453266e-02,
         1.97510898e-01,  1.01141125e-01,  4.92462069e-02,
         1.42575458e-01, -1.43614020e-02,  9.63314697e-02,
         7.52940252e-02],
       [-6.30331710e-02,  1.10456228e-01, -1.16325170e-03,
         2.14524224e-01,  1.68336466e-01,  5.55673912e-02,
         4.99481447e-02, -5.70731312e-02,  1.61831453e-01,
         3.46939489e-02],
       [-1.88177004e-02,  7.25489110e-02,  1.19085103e-01,
         1.92574352e-01,  2.01698959e-01, -3.57061811e-02,
         1.15002375e-02, -4.60720807e-02,  1.91118985e-01,
        -5.56473061e-03],
       [ 3.05368230e-02,  1.54953808e-01, -4.16156948e-02,
         1.48835093e-01,  2.12115660e-01,  5.11528924e-02,
         3.77153419e-02, -6.21527247e-02, -6.73634261e-02,
         8.61428082e-02],
       [ 1.02778897e-02, -6.76253531e-03,  1.17319196e-01,
         2.65516996e-01,  1.48470923e-01,  9.82715935e-03,
         7.03027099e-03, -1.74337670e-01,  5.36636263e-02,
        -2.72150040e-02],
       [-2.98051350e-02,  3.70360985e-02,  1.05787680e-01,
         2.19452754e-01,  1.52634084e-01, -6.46401569e-02,
        -5.00597581e-02, -3.35918702e-02,  1.23672239e-01,
        -1.06470473e-01],
       [-3.39470804e-03,  1.25394747e-01,  1.50448173e-01,
         1.67577818e-01,  1.63375765e-01, -7.10059553e-02,
        -3.06443200e-02, -6.04387000e-02,  1.64830744e-01,
        -1.66564062e-02],
       [-8.02340582e-02, -6.78222030e-02,  1.17694102e-02,
         1.29167646e-01,  1.53485805e-01, -3.83029729e-02,
         3.61376218e-02, -1.06063142e-01,  1.60429657e-01,
        -3.00653260e-02],
       [-1.88026577e-03,  7.66811967e-02,  4.24849242e-03,
         1.72812209e-01,  1.57790601e-01,  5.10438122e-02,
         1.13702655e-01, -4.89350706e-02,  1.02705240e-01,
         1.05789285e-02],
       [ 1.55825317e-02, -2.11698227e-02,  3.87617759e-03,
         1.78774536e-01,  1.79008842e-01,  7.60018677e-02,
        -2.61607729e-02, -7.24049509e-02,  4.00147513e-02,
         2.06921883e-02],
       [-1.05785936e-01,  2.69005522e-02,  1.30724147e-01,
         1.81785494e-01,  1.39355242e-01, -1.05513662e-01,
        -2.34467126e-02, -5.06407171e-02,  1.34027049e-01,
        -8.39241967e-02],
       [-1.06695294e-02, -1.62907988e-02,  5.93233183e-02,
         1.37603402e-01,  1.57422185e-01, -9.31300223e-03,
         4.45749126e-02, -6.58697337e-02,  1.89633757e-01,
        -1.28988009e-02],
       [ 2.25314498e-02,  1.04071319e-01,  8.57289135e-02,
         1.55806616e-01,  1.23344719e-01, -5.48473597e-02,
         1.39017040e-02, -7.37570971e-02,  1.03912190e-01,
        -8.88564065e-02],
       [-3.24133411e-02,  1.03348322e-01, -5.43242693e-03,
         1.79736599e-01,  2.25703120e-01,  3.06486823e-02,
         1.25571102e-01, -4.07368690e-02,  1.38380248e-02,
         3.04749720e-02],
       [ 8.31391141e-02, -6.75428063e-02, -5.24067879e-03,
         8.31556767e-02,  1.54603645e-01,  9.12544504e-03,
         5.84831685e-02,  5.06296791e-02,  1.49320439e-01,
         5.88927306e-02],
       [-1.28553689e-01, -2.58688368e-02,  1.99884214e-02,
         1.42506734e-01,  9.11594704e-02, -1.30070336e-02,
         1.22897610e-01, -4.75370884e-02,  1.82198301e-01,
         1.08608529e-01],
       [ 9.73133743e-03, -8.53723288e-02,  6.68865070e-03,
         3.07952106e-01,  1.85594574e-01,  1.11956030e-01,
         1.56234428e-01, -1.19408086e-01,  1.20581202e-01,
         2.91885622e-02],
       [-3.71664166e-02, -2.75016576e-02,  3.56700271e-02,
         1.48280978e-01,  7.42700398e-02, -3.85177433e-02,
         1.20765958e-02, -5.15548512e-02,  8.17354620e-02,
        -5.99596314e-02],
       [ 3.60207856e-02,  5.67628816e-03,  7.31534325e-03,
         1.46544740e-01,  1.17036000e-01, -4.46580499e-02,
        -9.43199359e-03, -1.45842852e-02,  2.42140442e-01,
        -3.73423174e-02],
       [-6.97065145e-03,  1.50239289e-01,  1.15167022e-01,
         2.41857678e-01,  1.36038661e-01,  4.13343385e-02,
         2.04018950e-02, -1.04828939e-01,  2.65558716e-02,
         3.01114097e-02],
       [ 4.90043722e-02,  1.25448748e-01,  2.49657929e-02,
         1.40089184e-01,  1.48549050e-01, -3.41142938e-02,
         2.23444495e-03, -1.28145814e-01,  2.14175545e-02,
        -1.98877789e-02],
       [-2.49088556e-02,  1.50664337e-02,  9.59001929e-02,
         1.00080341e-01,  1.10940441e-01, -1.27875566e-01,
        -1.44859590e-02, -2.92611606e-02,  7.40258470e-02,
        -9.12837908e-02],
       [ 1.17775667e-02, -7.65066221e-03, -3.96448374e-02,
         2.44866341e-01,  1.80622339e-01,  5.11005148e-02,
         3.65075245e-02, -7.85070285e-02,  9.95315090e-02,
        -1.36934295e-02],
       [ 1.13325082e-02,  9.65912417e-02,  4.68971953e-02,
         1.81209773e-01,  1.70031220e-01, -3.20439041e-03,
         7.47503042e-02, -2.08925754e-02,  1.45839527e-01,
        -7.53867999e-03],
       [-3.94523982e-03,  2.59891041e-02,  6.98198527e-02,
         8.87960121e-02,  6.84892684e-02,  4.49792109e-02,
        -5.35437837e-03, -4.54204381e-02,  3.87659445e-02,
        -3.18287611e-02],
       [ 1.54364463e-02,  4.33265343e-02, -6.54792711e-02,
         1.43275052e-01,  2.10045159e-01,  5.50200455e-02,
        -3.83140631e-02, -3.69091658e-03,  1.41641706e-01,
         2.69475691e-02],
       [ 5.23413718e-03,  5.95823936e-02,  1.35345966e-01,
         2.07676589e-01,  2.18142569e-01, -1.06628716e-01,
        -1.10288337e-02, -9.60890576e-03,  1.72279060e-01,
        -8.00692067e-02],
       [-3.24532464e-02, -8.53006691e-02, -2.28083786e-02,
         2.49700248e-01,  1.22932360e-01,  9.98823196e-02,
         7.98219740e-02, -1.98165178e-02,  1.98496342e-01,
         3.99748459e-02],
       [-1.65512562e-02,  4.26803418e-02,  1.10988386e-01,
         2.04267532e-01,  7.70107210e-02, -6.95031285e-02,
        -2.18536779e-02, -8.64365846e-02,  1.63327545e-01,
        -1.23858586e-01],
       [-3.46216410e-02,  3.69402841e-02,  4.44491208e-02,
         1.07214339e-01,  4.64262627e-02, -4.20800522e-02,
        -2.61899866e-02, -5.10727242e-02,  8.52320194e-02,
        -3.45907062e-02],
       [ 3.82258147e-02,  7.98343718e-02,  7.47422129e-02,
         2.44247720e-01,  1.19946495e-01, -1.43861547e-02,
         3.97951752e-02, -8.07641149e-02,  1.94615513e-01,
        -5.97534962e-02],
       [-1.49032548e-02,  9.34940651e-02, -1.41750425e-02,
         1.63116410e-01,  2.53371537e-01,  7.70481080e-02,
         5.22743240e-02, -2.09091008e-02,  8.33582431e-02,
         8.08733553e-02],
       [-5.74788116e-02,  1.24248944e-01,  4.88924086e-02,
         2.33521372e-01,  1.11080602e-01,  5.38572744e-02,
        -6.54475391e-03, -8.66869837e-02,  1.07936576e-01,
         3.27528305e-02],
       [ 1.36760920e-02,  6.37577698e-02,  7.05489367e-02,
         1.38153672e-01,  1.87819034e-01,  2.05195472e-02,
         4.16136235e-02, -1.10023558e-01,  1.15158200e-01,
         1.21188704e-02],
       [-1.67976394e-02,  6.43584058e-02,  4.97872606e-02,
         2.00935900e-01,  1.02977067e-01, -4.43888381e-02,
         1.69629604e-02, -8.47131386e-02,  1.77591056e-01,
         5.20711020e-03],
       [-1.38747431e-02,  7.62893781e-02,  1.24057420e-01,
         1.09222092e-01,  2.30513036e-01, -2.30862051e-02,
         4.70650718e-02, -5.65604120e-02,  5.07352166e-02,
         2.03110874e-02],
       [-1.19683027e-01, -5.53436130e-02, -1.58019997e-02,
         2.33960822e-01,  6.76316917e-02,  6.41214848e-03,
         2.65618265e-02, -1.01669870e-01,  2.77549416e-01,
         8.90393555e-03],
       [-6.03858158e-02,  9.49390382e-02,  4.75453585e-02,
         1.90521061e-01,  8.18673521e-02, -1.70915686e-02,
         1.14757046e-02, -1.42176598e-01,  4.70287539e-03,
         3.87623981e-02],
       [ 6.30005300e-02,  5.80105782e-02,  6.39670193e-02,
         1.92505836e-01,  3.18967581e-01,  4.36524674e-02,
         8.30903947e-02, -3.29740420e-02,  1.20322190e-01,
         6.76450431e-02],
       [ 2.49748677e-03,  2.24482343e-02,  9.36928466e-02,
         2.22687796e-01,  1.27259031e-01, -9.75219458e-02,
        -7.78435841e-02, -6.33900538e-02,  6.02504648e-02,
        -1.22365154e-01],
       [-5.69571182e-02,  3.24623324e-02,  3.16162184e-02,
         2.58825004e-01,  3.25450040e-02, -3.07036564e-03,
         6.44948483e-02, -9.86602530e-02,  1.12803228e-01,
         1.18407831e-02],
       [ 1.90220550e-02,  6.24117628e-03,  2.09219232e-02,
         1.88787565e-01,  2.12144852e-01,  7.06500411e-02,
         1.06127352e-01, -5.55426367e-02,  2.15229660e-01,
         8.16711038e-02],
       [ 1.42297149e-03,  1.25140235e-01,  4.28500026e-03,
         1.83053583e-01,  1.59013152e-01,  3.55741754e-02,
         3.27162668e-02, -8.87800902e-02, -2.34700181e-02,
         5.96281476e-02],
       [-3.87744904e-02,  1.92414634e-02,  4.32254001e-02,
         1.53217122e-01,  1.00404315e-01, -1.79725401e-02,
         6.01330772e-03, -1.00095481e-01, -9.74661298e-03,
        -1.61799528e-02],
       [ 1.34596229e-02,  3.01905517e-02,  2.71790177e-02,
         1.40652016e-01,  1.00337476e-01, -1.43815279e-02,
        -2.28257477e-03, -2.32312549e-02,  8.89860541e-02,
        -5.03262430e-02],
       [-8.23000669e-02,  7.25591630e-02,  9.52396244e-02,
         1.23983532e-01,  1.54516488e-01, -9.35958922e-02,
         3.23653817e-02, -5.89253604e-02,  4.97971401e-02,
        -6.04775250e-02],
       [ 9.20941308e-03,  9.72152501e-02,  7.12111145e-02,
         2.28228480e-01,  2.53475666e-01,  2.13885531e-02,
         1.02964509e-03, -7.49186948e-02,  2.16152027e-01,
         4.82185371e-02],
       [ 1.22462027e-02, -2.49450672e-02, -1.11322328e-02,
         1.29344121e-01,  1.73525229e-01,  3.91557105e-02,
        -5.27686253e-03, -8.26972947e-02,  1.03697672e-01,
         6.55747429e-02],
       [-8.74861330e-03,  9.42332298e-02,  7.50117898e-02,
         2.08360299e-01,  1.10086218e-01, -3.53808403e-02,
        -4.18840423e-02, -9.85472724e-02,  1.30476773e-01,
        -5.43486439e-02],
       [-6.40377104e-02,  7.37114474e-02,  1.13857999e-01,
         1.52148321e-01,  1.00814626e-01, -8.46893489e-02,
        -1.34603903e-02, -1.28374472e-02,  1.68438375e-01,
        -7.58913308e-02],
       [ 5.89922555e-02,  8.93255249e-02, -1.58086307e-02,
         1.09041616e-01,  1.76710129e-01, -4.74173501e-02,
        -2.83682551e-02, -7.42465854e-02,  2.14172080e-01,
         3.09006460e-02],
       [ 1.40921958e-03,  6.13394827e-02,  2.23291628e-02,
         1.53173387e-01,  1.54445231e-01,  1.49792843e-02,
         3.21674161e-04, -7.95597583e-02,  5.11914585e-03,
         5.77842854e-02],
       [ 3.09692323e-03,  1.08940132e-01,  1.10969543e-01,
         1.38779700e-01,  2.14474663e-01, -4.47251610e-02,
         6.32301420e-02, -6.82619065e-02,  1.15532689e-01,
        -3.01051997e-02],
       [-9.17289406e-03, -9.95587558e-03,  1.15491107e-01,
         1.75497115e-01,  1.20890349e-01, -2.09938101e-02,
         5.45932874e-02, -1.52857117e-02,  8.73069540e-02,
        -4.77676019e-02],
       [-5.60427830e-02, -8.99625197e-03,  5.30303754e-02,
         1.72452822e-01,  1.58994213e-01,  2.95356996e-02,
        -2.59171817e-02, -1.37496695e-01,  7.89206550e-02,
         1.59071907e-02]], dtype=float32)>}
2022-04-22 06:51:29.787107: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Calling the restored function is just a forward pass on the saved model (like tf.keras.Model.predict). What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap the loaded object into a Keras layer. Conveniently, TensorFlow Hub provides hub.KerasLayer for this purpose, shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap what's loaded to a KerasLayer
  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, keras_layer)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
Epoch 1/2
2022-04-22 06:51:30.412009: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
938/938 [==============================] - 6s 4ms/step - loss: 0.2015 - sparse_categorical_accuracy: 0.9423
Epoch 2/2
938/938 [==============================] - 4s 4ms/step - loss: 0.0696 - sparse_categorical_accuracy: 0.9796

In the example above, TensorFlow Hub's hub.KerasLayer wraps the result loaded back from tf.saved_model.load into a Keras layer that is used to build another model. This is very useful for transfer learning.

Which API should I use?

For saving, if you are working with a Keras model, use the Keras Model.save API unless you need the additional control allowed by the low-level API. If what you are saving is not a Keras model, then the lower-level API, tf.saved_model.save, is your only choice.

For loading, your API choice depends on what you want to get from the model loading API. If you cannot (or do not want to) get a Keras model, then use tf.saved_model.load. Otherwise, use tf.keras.models.load_model. Note that you can get a Keras model back only if you saved a Keras model.

It is possible to mix and match the APIs. You can save a Keras model with Model.save, and then load it as a non-Keras object with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras `Model.save`
model.save(keras_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Saving/Loading from a local device

When saving and loading from a local I/O device while training on remote devices—for example, when using a Cloud TPU—you must use the option experimental_io_device in tf.saved_model.SaveOptions and tf.saved_model.LoadOptions to set the I/O device to localhost. For example:

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)

Caveats

A special case arises when you create a Keras model in certain ways, such as by subclassing tf.keras.Model, and then save it before training. For example:

class SubclassedModel(tf.keras.Model):
  """Example model defined by subclassing `tf.keras.Model`."""

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
try:
  my_model.save(keras_model_path)
except ValueError as e:
  print(f'{type(e).__name__}: ', *e.args)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f9468f4b310>, because it is not built.
ValueError:  Model <__main__.SubclassedModel object at 0x7f9468f4b310> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.

A SavedModel saves the tf.types.experimental.ConcreteFunction objects generated when you trace a tf.function (check When is a Function tracing? in the Introduction to graphs and tf.function guide to learn more). If you get a ValueError like the one above, it's because Model.save was not able to find or create a traced ConcreteFunction.

tf.saved_model.save(my_model, saved_model_path)
x = tf.saved_model.load(saved_model_path)
x.signatures
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f9468f4b310>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f9468e63310>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
_SignatureMap({})

Usually the model's forward pass (the call method) is traced automatically when the model is called for the first time, often via the Keras Model.fit method. A ConcreteFunction can also be generated by the Keras Sequential and Functional APIs if you set the input shape, for example, by making the first layer either a tf.keras.layers.InputLayer or another layer type and passing it the input_shape keyword argument.
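As a minimal sketch of the point above (the TinyModel class, units, and shapes here are hypothetical, not part of this tutorial), calling a subclassed model once on sample data builds it and traces its forward pass, which is what makes saving possible later:

```python
import tensorflow as tf

class TinyModel(tf.keras.Model):
  """Hypothetical subclassed model with a single Dense layer."""

  def __init__(self):
    super().__init__()
    self._dense = tf.keras.layers.Dense(5)

  def call(self, inputs):
    return self._dense(inputs)

model = TinyModel()
# Before the first call, the model has no input shape and is not built.
assert not model.built

# Calling the model on sample data builds it and traces `call`.
outputs = model(tf.zeros([2, 3]))
assert model.built
assert outputs.shape == (2, 5)
```

The same effect is achieved by Model.fit or Model.predict, since both invoke the forward pass internally.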

To verify whether your model has any traced ConcreteFunctions, check whether Model.save_spec returns None:

print(my_model.save_spec() is None)
True

Let's use tf.keras.Model.fit to train the model, and notice that save_spec gets defined, so model saving will work:

BATCH_SIZE_PER_REPLICA = 4
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))
    ).repeat(dataset_size).batch(BATCH_SIZE)

my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)

print(my_model.save_spec() is None)
my_model.save(keras_model_path)
Epoch 1/2
25/25 [==============================] - 0s 2ms/step - loss: 14.0348
Epoch 2/2
25/25 [==============================] - 0s 2ms/step - loss: 12.3347
False
INFO:tensorflow:Assets written to: /tmp/keras_save/assets