Migrate the SavedModel workflow


Once you have migrated your model from TensorFlow 1's graphs and sessions to TensorFlow 2 APIs, such as tf.function, tf.Module, and tf.keras.Model, you can migrate the model saving and loading code. This notebook provides examples of how you can save and load in the SavedModel format in TensorFlow 1 and TensorFlow 2. Here is a quick overview of the related API changes for migration from TensorFlow 1 to TensorFlow 2:

TensorFlow 1 APIs and their TensorFlow 2 migration targets:

  • Saving. TensorFlow 1: tf.compat.v1.saved_model.Builder, tf.compat.v1.saved_model.simple_save. TensorFlow 2: tf.saved_model.save; Keras: tf.keras.models.save_model.
  • Loading. TensorFlow 1: tf.compat.v1.saved_model.load. TensorFlow 2: tf.saved_model.load; Keras: tf.keras.models.load_model.
  • Signatures (a set of input and output tensors that can be used to run the model). TensorFlow 1: generated using the *.signature_def utils (e.g. tf.compat.v1.saved_model.predict_signature_def). TensorFlow 2: write a tf.function and export it using the signatures argument in tf.saved_model.save.
  • Classification and regression (special types of signatures). TensorFlow 1: generated with tf.compat.v1.saved_model.classification_signature_def, tf.compat.v1.saved_model.regression_signature_def, and certain Estimator exports. TensorFlow 2: these two signature types have been removed. If the serving library requires these method names, use tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater.

For a more in-depth explanation of the mapping, refer to the Changes from TensorFlow 1 to TensorFlow 2 section below.

Setup

The examples below show how to export and load the same dummy TensorFlow model (defined as add_two below) in the SavedModel format using the TensorFlow 1 and TensorFlow 2 APIs. Start by setting up the imports and utility functions:

import tensorflow as tf
import tensorflow.compat.v1 as tf1
import shutil

def remove_dir(path):
  # Delete a previous export; errors such as a missing directory are ignored.
  shutil.rmtree(path, ignore_errors=True)

def add_two(input):
  return input + 2

TensorFlow 1: Save and export a SavedModel

In TensorFlow 1, you use the tf.compat.v1.saved_model.Builder, tf.compat.v1.saved_model.simple_save, and tf.estimator.Estimator.export_saved_model APIs to build, save, and export the TensorFlow graph and session:

1. Save the graph as a SavedModel with SavedModelBuilder

remove_dir("saved-model-builder")

with tf.Graph().as_default() as g:
  with tf1.Session() as sess:
    input = tf1.placeholder(tf.float32, shape=[])
    output = add_two(input)
    print("add two output: ", sess.run(output, {input: 3.}))

    # Save with SavedModelBuilder
    builder = tf1.saved_model.Builder('saved-model-builder')
    sig_def = tf1.saved_model.predict_signature_def(
        inputs={'input': input},
        outputs={'output': output})
    builder.add_meta_graph_and_variables(
        sess, tags=["serve"], signature_def_map={
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
    })
    builder.save()
!saved_model_cli run --dir saved-model-builder --tag_set serve \
 --signature_def serving_default --input_exprs input=10
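The ! prefix above is a notebook shell escape. A minimal sketch, assuming you want to run the same inspection from a plain Python script: the hypothetical helper saved_model_cli_run_args only assembles the argument list for saved_model_cli run, using just the flags shown in this guide (--dir, --tag_set, --signature_def, --input_exprs). Pass the list to subprocess.run to actually execute it.

```python
import shlex
from typing import List


def saved_model_cli_run_args(saved_model_dir: str,
                             input_expr: str,
                             tag_set: str = "serve",
                             signature_def: str = "serving_default") -> List[str]:
    """Build the argument list for `saved_model_cli run` (hypothetical helper)."""
    return [
        "saved_model_cli", "run",
        "--dir", saved_model_dir,
        "--tag_set", tag_set,
        "--signature_def", signature_def,
        "--input_exprs", input_expr,
    ]


args = saved_model_cli_run_args("saved-model-builder", "input=10")
# Show the equivalent shell command; execute it with subprocess.run(args, check=True).
print(shlex.join(args))
```

Building a list rather than a single string avoids shell quoting problems: shlex.join quotes an expression such as input=[10] automatically when you print it.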

2. Build a SavedModel for serving with simple_save

remove_dir("simple-save")

with tf.Graph().as_default() as g:
  with tf1.Session() as sess:
    input = tf1.placeholder(tf.float32, shape=[])
    output = add_two(input)
    print("add_two output: ", sess.run(output, {input: 3.}))

    tf1.saved_model.simple_save(
        sess, 'simple-save',
        inputs={'input': input},
        outputs={'output': output})
!saved_model_cli run --dir simple-save --tag_set serve \
 --signature_def serving_default --input_exprs input=10

3. Export the Estimator inference graph as a SavedModel

In the definition of the Estimator model_fn (defined below), you can define signatures in your model by returning export_outputs in the tf.estimator.EstimatorSpec. There are different types of outputs:

  • tf.estimator.export.ClassificationOutput
  • tf.estimator.export.RegressionOutput
  • tf.estimator.export.PredictOutput

These will produce classification, regression, and prediction signature types, respectively.

When the estimator is exported with tf.estimator.Estimator.export_saved_model, these signatures will be saved with the model.
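Inside the SavedModel, each of these signature types is identified by a method-name string. The plain-Python table below is an illustrative sketch: the dictionary and the helpers signature_type and dropped_in_tf2 are hypothetical, while the strings are the values of TensorFlow's tf.saved_model.CLASSIFY_METHOD_NAME, REGRESS_METHOD_NAME, and PREDICT_METHOD_NAME constants. It can help when auditing a TensorFlow 1 export for signatures that a TensorFlow 2 export will not reproduce.

```python
# Method-name strings recorded in a SignatureDef for each Estimator export
# output type (equal to TensorFlow's *_METHOD_NAME constants).
EXPORT_OUTPUT_METHOD_NAMES = {
    "ClassificationOutput": "tensorflow/serving/classify",
    "RegressionOutput": "tensorflow/serving/regress",
    "PredictOutput": "tensorflow/serving/predict",
}

SERVING_PREFIX = "tensorflow/serving/"


def signature_type(method_name: str) -> str:
    """Map a method-name string to 'classify', 'regress' or 'predict'."""
    if not method_name.startswith(SERVING_PREFIX):
        raise ValueError(f"Unrecognized method name: {method_name!r}")
    return method_name[len(SERVING_PREFIX):]


def dropped_in_tf2(method_name: str) -> bool:
    """True for the signature types that TensorFlow 2 exports no longer produce."""
    return signature_type(method_name) in ("classify", "regress")


for output_type, method_name in EXPORT_OUTPUT_METHOD_NAMES.items():
    note = "(removed in TF2)" if dropped_in_tf2(method_name) else ""
    print(output_type, "->", signature_type(method_name), note)
```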

def model_fn(features, labels, mode):
  output = add_two(features['input'])
  step = tf1.train.get_global_step()
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=output,
      train_op=step.assign_add(1),
      loss=tf.constant(0.),
      export_outputs={
          tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              tf.estimator.export.PredictOutput({'output': output})})
est = tf.estimator.Estimator(model_fn, 'estimator-checkpoints')

# Train for one step to create a checkpoint.
def train_fn():
  return tf.data.Dataset.from_tensors({'input': 3.})
est.train(train_fn, steps=1)

# This utility function `build_raw_serving_input_receiver_fn` takes in raw
# tensor features and builds an "input serving receiver function", which
# creates placeholder inputs to the model.
serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
    {'input': tf.constant(3.)})  # Pass in a dummy input batch.
estimator_path = est.export_saved_model('exported-estimator', serving_input_fn)

# Estimator's export_saved_model creates a time stamped directory. Move this
# to a set path so it can be inspected with `saved_model_cli` in the cell below.
!rm -rf estimator-model
shutil.move(estimator_path, 'estimator-model')
!saved_model_cli run --dir estimator-model --tag_set serve \
 --signature_def serving_default --input_exprs input=[10]
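If you export repeatedly, timestamped directories accumulate under the export base. A minimal sketch, assuming each completed export is a subdirectory named by an integer timestamp (as the comment above describes): the hypothetical helper latest_export_dir picks the newest one with the standard library only, which is handy when you did not keep the path returned by export_saved_model.

```python
import os
import tempfile


def latest_export_dir(export_dir_base: str) -> str:
    """Return the newest timestamped export under export_dir_base.

    Only subdirectories whose names are all digits are considered, so other
    entries (for example temporary directories) are skipped. Names are
    compared as integers, not as strings.
    """
    candidates = [
        name for name in os.listdir(export_dir_base)
        if name.isdigit() and os.path.isdir(os.path.join(export_dir_base, name))
    ]
    if not candidates:
        raise FileNotFoundError(f"No timestamped exports in {export_dir_base!r}")
    return os.path.join(export_dir_base, max(candidates, key=int))


# Usage with a throwaway directory standing in for 'exported-estimator'.
with tempfile.TemporaryDirectory() as base:
    for name in ("1650000000", "1650000100", "temp-1650000200"):
        os.mkdir(os.path.join(base, name))
    newest = latest_export_dir(base)
    print(os.path.basename(newest))  # 1650000100
```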

TensorFlow 2: Save and export a SavedModel

Save and export a SavedModel defined with tf.Module

To export your model in TensorFlow 2, you must define a tf.Module or a tf.keras.Model to hold all of your model's variables and functions. Then, you can call tf.saved_model.save to create a SavedModel. Refer to the Saving a custom model section in the Using the SavedModel format guide to learn more.

class MyModel(tf.Module):
  @tf.function
  def __call__(self, input):
    return add_two(input)

model = MyModel()

@tf.function
def serving_default(input):
  return {'output': model(input)}

signature_function = serving_default.get_concrete_function(
    tf.TensorSpec(shape=[], dtype=tf.float32))
tf.saved_model.save(
    model, 'tf2-save', signatures={
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_function})
!saved_model_cli run --dir tf2-save --tag_set serve \
 --signature_def serving_default --input_exprs input=10

Save and export a SavedModel defined with Keras

The Keras APIs for saving and exporting, Model.save and tf.keras.models.save_model, can export a SavedModel from a tf.keras.Model. Check out the Save and load Keras models guide for more details.

inp = tf.keras.Input(3)
out = add_two(inp)
model = tf.keras.Model(inputs=inp, outputs=out)

@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
def serving_default(input):
  return {'output': model(input)}

model.save('keras-model', save_format='tf', signatures={
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: serving_default})
!saved_model_cli run --dir keras-model --tag_set serve \
 --signature_def serving_default --input_exprs input=10

Loading a SavedModel

A SavedModel saved with any of the above APIs can be loaded using either TensorFlow 1 or TensorFlow 2 APIs.

A TensorFlow 1 SavedModel can generally be used for inference when loaded into TensorFlow 2, but training (generating gradients) is only possible if the SavedModel contains resource variables. You can check the dtype of the variables: if the dtype name contains "_ref" (for example, float32_ref), it is a reference variable rather than a resource variable.
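To apply the "_ref" check to every variable at once, here is a minimal stdlib-only sketch. It assumes you have already collected (name, dtype name) pairs, for example with [(v.name, v.dtype.name) for v in tf1.global_variables()] inside the graph that loaded the SavedModel; the helper is_reference_dtype and the variable names are hypothetical.

```python
def is_reference_dtype(dtype_name: str) -> bool:
    """True if a dtype name marks a TF1 reference variable, e.g. 'float32_ref'."""
    return "_ref" in dtype_name


# Hypothetical (variable name, dtype name) pairs collected from a loaded graph.
variables = [("dense/kernel:0", "float32_ref"),
             ("counter:0", "int64"),
             ("dense/bias:0", "float32_ref")]

ref_vars = [name for name, dtype in variables if is_reference_dtype(dtype)]
print(ref_vars)  # ['dense/kernel:0', 'dense/bias:0']
```

An empty result suggests the graph uses resource variables, which is the case where training after loading into TensorFlow 2 is possible.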

A TensorFlow 2 SavedModel can be loaded and executed from TensorFlow 1 as long as the SavedModel is saved with signatures.

The sections below contain code samples showing how to load the SavedModels saved in the previous sections, and call the exported signature.

TensorFlow 1: Load a SavedModel with tf.compat.v1.saved_model.load

In TensorFlow 1, you can import a SavedModel directly into the current graph and session using tf.compat.v1.saved_model.load. It returns the loaded MetaGraphDef, whose signatures give you the input and output tensor names to pass to Session.run:

def load_tf1(path, input):
  print('Loading from', path)
  with tf.Graph().as_default() as g:
    with tf1.Session() as sess:
      meta_graph = tf1.saved_model.load(sess, ["serve"], path)
      sig_def = meta_graph.signature_def[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
      input_name = sig_def.inputs['input'].name
      output_name = sig_def.outputs['output'].name
      print('  Output with input', input, ': ', 
            sess.run(output_name, feed_dict={input_name: input}))

load_tf1('saved-model-builder', 5.)
load_tf1('simple-save', 5.)
load_tf1('estimator-model', [5.])  # Estimator's input must be batched.
load_tf1('tf2-save', 5.)
load_tf1('keras-model', 5.)

TensorFlow 2: Load a model saved with tf.saved_model

In TensorFlow 2, a SavedModel is loaded into a Python object that stores the variables and functions. This works for models saved from TensorFlow 1 as well.

Check out the tf.saved_model.load API docs and Loading and using a custom model section from the Using the SavedModel format guide for details.

def load_tf2(path, input):
  print('Loading from', path)
  loaded = tf.saved_model.load(path)
  out = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY](
      tf.constant(input))['output']
  print('  Output with input', input, ': ', out)

load_tf2('saved-model-builder', 5.)
load_tf2('simple-save', 5.)
load_tf2('estimator-model', [5.])  # Estimator's input must be batched.
load_tf2('tf2-save', 5.)
load_tf2('keras-model', 5.)

Models saved with the TensorFlow 2 API also let you access the tf.functions and variables attached to the model, not only the functions exported as signatures. For example:

loaded = tf.saved_model.load('tf2-save')
print('restored __call__:', loaded.__call__)
# Pass a float32 scalar tensor to match the input the function was traced with.
print('output with input 5.', loaded(tf.constant(5.)))

TensorFlow 2: Load a model saved with Keras

The Keras loading API—tf.keras.models.load_model—allows you to reload a saved model back into a Keras Model object. Note that this only allows you to load SavedModels saved with Keras (Model.save or tf.keras.models.save_model).

Models saved with tf.saved_model.save should be loaded with tf.saved_model.load. You can load a Keras model saved with Model.save using tf.saved_model.load, but you will only get the TensorFlow graph. Refer to the tf.keras.models.load_model API docs and Save and load Keras models guide for details.

loaded_model = tf.keras.models.load_model('keras-model')
# A batch of one float example with 3 features, matching tf.keras.Input(3).
loaded_model.predict_on_batch(tf.constant([[1., 3., 4.]]))

GraphDef and MetaGraphDef

There is no straightforward way to load a raw GraphDef or MetaGraphDef into TF2. However, you can convert the TF1 code that imports the graph into a TF2 concrete function using tf.compat.v1.wrap_function.

First, save a MetaGraphDef:

# Save a simple multiplication computation:
with tf.Graph().as_default() as g:
  x = tf1.placeholder(tf.float32, shape=[], name='x')
  v = tf.Variable(3.0, name='v')
  y = tf.multiply(x, v, name='y')
  with tf1.Session() as sess:
    sess.run(v.initializer)
    print(sess.run(y, feed_dict={x: 5}))
    s = tf1.train.Saver()
    s.export_meta_graph('multiply.pb', as_text=True)
    s.save(sess, 'multiply_values.ckpt')

With the TF1 APIs, you can call tf1.train.import_meta_graph to import the graph and then restore the values:

with tf.Graph().as_default() as g:
  meta = tf1.train.import_meta_graph('multiply.pb')
  x = g.get_tensor_by_name('x:0')
  y = g.get_tensor_by_name('y:0')
  with tf1.Session() as sess:
    meta.restore(sess, 'multiply_values.ckpt')
    print(sess.run(y, feed_dict={x: 5}))

There are no TF2 APIs for loading the graph, but you can still import it into a concrete function that can be executed in eager mode:

def import_multiply():
  # Any graph-building code is allowed here.
  tf1.train.import_meta_graph('multiply.pb')

# Creates a tf.function with all the imported elements in the function graph.
wrapped_import = tf1.wrap_function(import_multiply, [])
import_graph = wrapped_import.graph
x = import_graph.get_tensor_by_name('x:0')
y = import_graph.get_tensor_by_name('y:0')

# Restore the variable values.
tf1.train.Saver(wrapped_import.variables).restore(
    sess=None, save_path='multiply_values.ckpt')

# Create a concrete function by pruning the wrap_function (similar to sess.run).
multiply_fn = wrapped_import.prune(feeds=x, fetches=y)

# Run this function
multiply_fn(tf.constant(5.))  # inputs to concrete functions must be Tensors.

Changes from TensorFlow 1 to TensorFlow 2

This section lists key saving and loading terms from TensorFlow 1, their TensorFlow 2 equivalents, and what has changed.

SavedModel

SavedModel is a format that stores a complete TensorFlow program with parameters and computation. It contains signatures used by serving platforms to run the model.

The file format itself has not changed significantly, so SavedModels can be loaded and served using either TensorFlow 1 or TensorFlow 2 APIs.
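Because the on-disk layout is shared by both versions, you can sanity-check an export from either API with plain file-system calls. A minimal sketch, assuming the standard layout (the program in saved_model.pb or saved_model.pbtxt, variable values in a variables/ directory, optional extra files in assets/); the helper name is hypothetical, and it does not validate the protobuf contents.

```python
import os
import tempfile
from typing import Dict


def describe_saved_model_dir(path: str) -> Dict[str, bool]:
    """Report which standard SavedModel entries exist under `path`."""
    def exists(name: str) -> bool:
        return os.path.exists(os.path.join(path, name))

    return {
        "saved_model_proto": exists("saved_model.pb") or exists("saved_model.pbtxt"),
        "variables_dir": os.path.isdir(os.path.join(path, "variables")),
        "assets_dir": os.path.isdir(os.path.join(path, "assets")),
    }


# Usage on a fake directory that only mimics the layout (the file is empty).
with tempfile.TemporaryDirectory() as export_dir:
    open(os.path.join(export_dir, "saved_model.pb"), "wb").close()
    os.mkdir(os.path.join(export_dir, "variables"))
    layout = describe_saved_model_dir(export_dir)

print(layout)  # {'saved_model_proto': True, 'variables_dir': True, 'assets_dir': False}
```

Pointing it at directories such as 'simple-save' or 'tf2-save' from the earlier sections should report the same kind of layout for both the TensorFlow 1 and TensorFlow 2 exports.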

Differences between TensorFlow 1 and TensorFlow 2

Aside from the API changes, the serving and inference use cases have not changed in TensorFlow 2. The main improvement is the ability to reuse and compose models loaded from a SavedModel.

In TensorFlow 2, the program is represented by objects like tf.Variable, tf.Module, or higher-level Keras models (tf.keras.Model) and layers (tf.keras.layers). There are no more global variables whose values are stored in a session, and the graph is now split across separate tf.functions. Consequently, during a model export, SavedModel saves each component and each function graph separately.

When you write a TensorFlow program with the TensorFlow Python APIs, you must build an object to manage the variables, functions, and other resources. Generally, this is accomplished by using the Keras APIs, but you can also build the object by creating or subclassing tf.Module.

Keras models (tf.keras.Model) and tf.Module automatically track variables and functions attached to them. SavedModel saves these connections between modules, variables, and functions, so that they can be restored when loading.
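The idea of tracking by attribute can be illustrated with a toy model. This is a pure-Python sketch, not TensorFlow's implementation: the Variable and Module classes below are hypothetical stand-ins, and collect_variables only mimics how an attribute path such as dense_layer.kernel can identify each variable in a tree of objects.

```python
from typing import Any, Dict


class Variable:
    """Toy stand-in for tf.Variable."""

    def __init__(self, value: float):
        self.value = value


class Module:
    """Toy container: any attribute assigned to it becomes a tracked child."""


def collect_variables(obj: Any, prefix: str = "") -> Dict[str, Variable]:
    """Walk attributes depth-first and record each Variable by its attribute path."""
    found: Dict[str, Variable] = {}
    for name, child in vars(obj).items():
        path = f"{prefix}{name}"
        if isinstance(child, Variable):
            found[path] = child
        elif isinstance(child, Module):
            found.update(collect_variables(child, prefix=f"{path}."))
    return found


model = Module()
model.dense_layer = Module()
model.dense_layer.kernel = Variable(0.5)
model.dense_layer.bias = Variable(0.0)
model.step = Variable(0.0)

print(sorted(collect_variables(model)))
# ['dense_layer.bias', 'dense_layer.kernel', 'step']
```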

Signatures

Signatures are the endpoints of a SavedModel—they tell the user how to run the model and what inputs are needed.

In TensorFlow 1, signatures are created by listing the input and output tensors. In TensorFlow 2, signatures are generated by passing in concrete functions. (Read more about TensorFlow functions in the Introduction to graphs and tf.function guide, particularly the Polymorphism: one Function, many graphs section.) In short, a concrete function is generated from a tf.function:

# Option 1: Specify an input signature.
@tf.function(input_signature=[...])
def fn(...):
  ...
  return outputs

tf.saved_model.save(model, path, signatures={
    'name': fn
})
# Option 2: Call `get_concrete_function`
@tf.function
def fn(...):
  ...
  return outputs

tf.saved_model.save(model, path, signatures={
    'name': fn.get_concrete_function(...)
})

Session.run

In TensorFlow 1, you could call Session.run with the imported graph as long as you already know the tensor names. This allows you to retrieve the restored variable values, or run parts of the model that were not exported in the signatures.

In TensorFlow 2, you can directly access a variable, such as a weights matrix (kernel):

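model = tf.Module()
model.dense_layer = tf.keras.layers.Dense(4)  # 4 output units, chosen for illustration.
model.dense_layer.build((None, 3))  # Create the kernel and bias for 3 input features.
tf.saved_model.save(model, 'my_saved_model')
loaded = tf.saved_model.load('my_saved_model')
loaded.dense_layer.kernel  # The restored weights matrix, shape (3, 4).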

or call tf.functions attached to the model object: for example, loaded.__call__.

Unlike TF1, there is no way to extract parts of a function and access intermediate values. You must export all of the needed functionality in the saved object.

TensorFlow Serving migration notes

SavedModel was originally created to work with TensorFlow Serving. This platform offers different types of prediction requests: classify, regress, and predict.

The TensorFlow 1 API allows you to create these types of signatures with the following utils:

Classification (classification_signature_def) and regression (regression_signature_def) signatures restrict the inputs and outputs: the inputs must be a tf.Example, and the outputs must be classes, scores, or prediction. Meanwhile, the predict signature (predict_signature_def) has no restrictions.
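The restriction is visible in the request formats of TensorFlow Serving's REST API. A minimal sketch built with the standard library only: a :predict request carries arbitrary input values under "instances", while :classify and :regress requests carry "examples" that the server converts into tf.Example protos. The host, port (8501 is the usual REST port), model name, and input values are assumptions, and nothing is actually sent.

```python
import json
from typing import Any, Dict, List


def predict_body(instances: List[Any],
                 signature_name: str = "serving_default") -> str:
    """Body for a :predict request in the row ("instances") format."""
    return json.dumps({"signature_name": signature_name,
                       "instances": instances})


def examples_body(examples: List[Dict[str, Any]],
                  signature_name: str = "serving_default") -> str:
    """Body for a :classify or :regress request.

    Each example maps feature names to values; the server builds a
    tf.Example from it before running the signature.
    """
    return json.dumps({"signature_name": signature_name,
                       "examples": examples})


# Hypothetical server address and model name.
base_url = "http://localhost:8501/v1/models/my_model"
predict_url = base_url + ":predict"
classify_url = base_url + ":classify"

print(predict_url, predict_body([[1.0, 3.0, 4.0]]))
print(classify_url, examples_body([{"input": 3.0}]))
```

A model exported with the TensorFlow 2 API only answers the first kind of request unless its signatures are rewritten with MethodNameUpdater.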

SavedModels exported with the TensorFlow 2 API are compatible with TensorFlow Serving, but will only contain prediction signatures. The classification and regression signatures have been removed.

If you require the use of the classification and regression signatures, you may modify the exported SavedModel using tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater.

Next steps

To learn more about SavedModels in TensorFlow 2, check out the following guides:

If you are using TensorFlow Hub, you may find these guides useful: