SavedModel ワークフローを移行する

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

モデルを TensorFlow 1 のグラフとセッションから tf.keras.Modeltf.functiontf.Module などの TensorFlow 2 API に移行したら、モデルの保存と読み込みコードを移行できます。このノートブックは、TensorFlow 1 と TensorFlow 2 の SavedModel 形式で保存および読み込む方法の例を提供します。TensorFlow 1 から TensorFlow 2 への移行に関連する API の変更の概要を次に示します。

TensorFlow 1 TensorFlow 2 への移行
保存 tf.compat.v1.saved_model.Builder
tf.compat.v1.saved_model.simple_save
tf.saved_model.save
Keras: tf.keras.models.save_model
読み込み tf.compat.v1.saved_model.load tf.saved_model.load
Keras: tf.keras.models.load_model
シグネチャ: 実行に使用できる
入力と出力
テンソルのセット
*.signature_def ユーティリティを使用して生成
(例: tf.compat.v1.saved_model.predict_signature_def
tf.function を作成し、tf.saved_model.savesignatures 引数を
使用してエクスポートします。
分類
および回帰
:
特別な種類のシグネチャ
tf.compat.v1.saved_model.classification_signature_def
tf.compat.v1.saved_model.regression_signature_def、および特定の Estimator エクスポートで生成。

これら 2 つのシグネチャタイプは TensorFlow 2 から削除されました。
サービス提供ライブラリがこれらのメソッド名を必要とする場合は、
tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater

マッピングのより詳細な説明については、以下の TensorFlow 1 から TensorFlow 2 への変更セクションをご覧ください。

セットアップ

以下の例は、TensorFlow 1 および TensorFlow 2 API を使用して、同じダミーの TensorFlow モデル(以下で add_two として定義)を SavedModel 形式にエクスポートおよびロードする方法を示しています。インポートとユーティリティ関数を設定することから始めます。

import tensorflow as tf
import tensorflow.compat.v1 as tf1
import shutil

def remove_dir(path):
  try:
    shutil.rmtree(path)
  except:
    pass

def add_two(input):
  return input + 2
2024-01-11 17:53:44.666422: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 17:53:44.666468: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 17:53:44.668030: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

TensorFlow 1: SavedModel を保存してエクスポートする

TensorFlow 1 では、tf.compat.v1.saved_model.Buildertf.compat.v1.saved_model.simple_save、および tf.estimator.Estimator.export_saved_model API を使用して、TensorFlow グラフとセッションを構築、保存、およびエクスポートします。

1. SavedModelBuilder を使用して、グラフを SavedModel として保存する

remove_dir("saved-model-builder")

with tf.Graph().as_default() as g:
  with tf1.Session() as sess:
    input = tf1.placeholder(tf.float32, shape=[])
    output = add_two(input)
    print("add two output: ", sess.run(output, {input: 3.}))

    # Save with SavedModelBuilder
    builder = tf1.saved_model.Builder('saved-model-builder')
    sig_def = tf1.saved_model.predict_signature_def(
        inputs={'input': input},
        outputs={'output': output})
    builder.add_meta_graph_and_variables(
        sess, tags=["serve"], signature_def_map={
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
    })
    builder.save()
add two output:  5.0
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:225: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: saved-model-builder/saved_model.pb
!saved_model_cli run --dir saved-model-builder --tag_set serve \
 --signature_def serving_default --input_exprs input=10
2024-01-11 17:53:49.367763: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 17:53:49.367815: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 17:53:49.369319: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/tools/saved_model_cli.py:705: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
W0111 17:53:52.603642 140459272128320 deprecation.py:50] From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/tools/saved_model_cli.py:705: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
I0111 17:53:52.605584 140459272128320 saver.py:1635] Saver not created because there are no variables in the graph to restore
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
I0111 17:53:52.605740 140459272128320 loader_impl.py:451] The specified SavedModel has no variables; no checkpoints were restored.
Result for output key output:
12.0

2. 提供する SavedModel を構築する

remove_dir("simple-save")

with tf.Graph().as_default() as g:
  with tf1.Session() as sess:
    input = tf1.placeholder(tf.float32, shape=[])
    output = add_two(input)
    print("add_two output: ", sess.run(output, {input: 3.}))

    tf1.saved_model.simple_save(
        sess, 'simple-save',
        inputs={'input': input},
        outputs={'output': output})
add_two output:  5.0
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_21364/250978412.py:9: simple_save (from tensorflow.python.saved_model.simple_save) is deprecated and will be removed in a future version.
Instructions for updating:
This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: simple-save/saved_model.pb
!saved_model_cli run --dir simple-save --tag_set serve \
 --signature_def serving_default --input_exprs input=10
2024-01-11 17:53:54.162330: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 17:53:54.162378: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 17:53:54.163735: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/tools/saved_model_cli.py:705: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
W0111 17:53:57.390006 140525008353088 deprecation.py:50] From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/tools/saved_model_cli.py:705: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
I0111 17:53:57.391991 140525008353088 saver.py:1635] Saver not created because there are no variables in the graph to restore
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
I0111 17:53:57.392152 140525008353088 loader_impl.py:451] The specified SavedModel has no variables; no checkpoints were restored.
Result for output key output:
12.0

3. Estimator 推論グラフを SavedModel としてエクスポートする

Estimator model_fn(以下で定義)の定義では、tf.estimator.EstimatorSpecexport_outputs を返すことにより、モデルでシグネチャを定義できます。出力にはさまざまなタイプがあります。

  • tf.estimator.export.ClassificationOutput
  • tf.estimator.export.RegressionOutput
  • tf.estimator.export.PredictOutput

これらは、それぞれ分類、回帰、および予測シグネチャタイプを生成します。

estimator が tf.estimator.Estimator.export_saved_model でエクスポートされると、これらのシグネチャはモデルとともに保存されます。

def model_fn(features, labels, mode):
  output = add_two(features['input'])
  step = tf1.train.get_global_step()
  return tf.estimator.EstimatorSpec(
      mode,
      predictions=output,
      train_op=step.assign_add(1),
      loss=tf.constant(0.),
      export_outputs={
          tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: \
          tf.estimator.export.PredictOutput({'output': output})})
est = tf.estimator.Estimator(model_fn, 'estimator-checkpoints')

# Train for one step to create a checkpoint.
def train_fn():
  return tf.data.Dataset.from_tensors({'input': 3.})
est.train(train_fn, steps=1)

# This utility function `build_raw_serving_input_receiver_fn` takes in raw
# tensor features and builds an "input serving receiver function", which
# creates placeholder inputs to the model.
serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
    {'input': tf.constant(3.)})  # Pass in a dummy input batch.
estimator_path = est.export_saved_model('exported-estimator', serving_input_fn)

# Estimator's export_saved_model creates a time stamped directory. Move this
# to a set path so it can be inspected with `saved_model_cli` in the cell below.
!rm -rf estimator-model
import shutil
shutil.move(estimator_path, 'estimator-model')
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_21364/2386241072.py:12: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1844: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'estimator-checkpoints', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:385: StopAtStepHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_21364/2386241072.py:11: PredictOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_21364/2386241072.py:4: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into estimator-checkpoints/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:loss = 0.0, step = 1
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1...
INFO:tensorflow:Saving checkpoints for 1 into estimator-checkpoints/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1...
INFO:tensorflow:Loss for final step: 0.0.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_21364/2386241072.py:22: build_raw_serving_input_receiver_fn (from tensorflow_estimator.python.estimator.export.export) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/export/export.py:375: ServingInputReceiver.__new__ (from tensorflow_estimator.python.estimator.export.export) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/saved_model/model_utils/export_utils.py:83: get_tensor_from_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.
INFO:tensorflow:Signatures INCLUDED in export for Classify: None
INFO:tensorflow:Signatures INCLUDED in export for Regress: None
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from estimator-checkpoints/model.ckpt-1
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: exported-estimator/temp-1704995638/saved_model.pb
'estimator-model'
!saved_model_cli run --dir estimator-model --tag_set serve \
 --signature_def serving_default --input_exprs input=[10]
2024-01-11 17:53:59.588488: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 17:53:59.588536: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 17:53:59.590076: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/tools/saved_model_cli.py:705: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
W0111 17:54:02.841524 139832957400896 deprecation.py:50] From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/tools/saved_model_cli.py:705: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
INFO:tensorflow:Restoring parameters from estimator-model/variables/variables
I0111 17:54:02.848292 139832957400896 saver.py:1413] Restoring parameters from estimator-model/variables/variables
Result for output key output:
[12.]

TensorFlow 2: SavedModel を保存してエクスポートする

tf.Module で定義された SavedModel を保存してエクスポートする

TensorFlow 2 でモデルをエクスポートするには、モデルのすべての変数と関数を保持するために tf.Module または tf.keras.Model を定義する必要があります。次に、tf.saved_model.save を呼び出して、SavedModel を作成できます。詳細については、SavedModel 形式の使用カスタムモデルの保存セクションを参照してください。

class MyModel(tf.Module):
  @tf.function
  def __call__(self, input):
    return add_two(input)

model = MyModel()

@tf.function
def serving_default(input):
  return {'output': model(input)}

signature_function = serving_default.get_concrete_function(
    tf.TensorSpec(shape=[], dtype=tf.float32))
tf.saved_model.save(
    model, 'tf2-save', signatures={
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_function})
INFO:tensorflow:Assets written to: tf2-save/assets
!saved_model_cli run --dir tf2-save --tag_set serve \
 --signature_def serving_default --input_exprs input=10
2024-01-11 17:54:04.519420: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 17:54:04.519467: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 17:54:04.521057: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/tools/saved_model_cli.py:705: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
W0111 17:54:07.769636 140012752820032 deprecation.py:50] From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/tools/saved_model_cli.py:705: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
INFO:tensorflow:Restoring parameters from tf2-save/variables/variables
I0111 17:54:07.784440 140012752820032 saver.py:1413] Restoring parameters from tf2-save/variables/variables
Result for output key output:
12.0

Keras で定義された SavedModel を保存してエクスポートする

廃止: Keras オブジェクトについては、新しい高レベルの .keras 形式と tf.keras.Model.export を使用することをお勧めします。これについては、こちらのガイドで説明されています。既存のコードについては、低レベルの SavedModel 形式が引き続きサポートされます。

保存およびエクスポート用の Keras API(Mode.save または tf.keras.models.save_model)は、 SavedModel を tf.keras.Model からエクスポートできます。詳細については、Keras モデルの保存と読み込みをご覧ください。

inp = tf.keras.Input(3)
out = add_two(inp)
model = tf.keras.Model(inputs=inp, outputs=out)

@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
def serving_default(input):
  return {'output': model(input)}

model.save('keras-model', save_format='tf', signatures={
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: serving_default})
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.
INFO:tensorflow:Assets written to: keras-model/assets
INFO:tensorflow:Assets written to: keras-model/assets
!saved_model_cli run --dir keras-model --tag_set serve \
 --signature_def serving_default --input_exprs input=10
2024-01-11 17:54:09.503135: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 17:54:09.503181: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 17:54:09.504635: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/tools/saved_model_cli.py:705: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
W0111 17:54:12.746095 140284956575552 deprecation.py:50] From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/tools/saved_model_cli.py:705: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
INFO:tensorflow:Restoring parameters from keras-model/variables/variables
I0111 17:54:12.762121 140284956575552 saver.py:1413] Restoring parameters from keras-model/variables/variables
Result for output key output:
12.0

SavedModel の読み込み

上記の API のいずれかで保存された SavedModel は、TensorFlow 1 または TensorFlow 2 API のいずれかを使用して読み込むことができます。

TensorFlow 1 の SavedModel は通常、TensorFlow 2 に読み込まれたときに推論に使用できますが、トレーニング(勾配の生成)は、SavedModel にリソース変数が含まれている場合にのみ可能です。変数の dtype を確認できます。変数 dtype に「_ref」が含まれている場合、それは参照変数です。

TensorFlow 2 SavedModel は、SavedModel がシグネチャ付きで保存されている限り、TensorFlow 1 から読み込んで実行できます。

以下のセクションには、前のセクションで保存された SavedModels を読み込み、エクスポートされたシグネチャを呼び出す方法を示すコードサンプルが含まれています。

TensorFlow 1: tf.saved_model.load で SavedModel をロードする

TensorFlow 1 では、tf.saved_model.loadを使用して、SavedModel を現在のグラフとセッションに直接インポートできます。テンソルの入力名と出力名で Session.run を呼び出すことができます。

def load_tf1(path, input):
  print('Loading from', path)
  with tf.Graph().as_default() as g:
    with tf1.Session() as sess:
      meta_graph = tf1.saved_model.load(sess, ["serve"], path)
      sig_def = meta_graph.signature_def[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
      input_name = sig_def.inputs['input'].name
      output_name = sig_def.outputs['output'].name
      print('  Output with input', input, ': ', 
            sess.run(output_name, feed_dict={input_name: input}))

load_tf1('saved-model-builder', 5.)
load_tf1('simple-save', 5.)
load_tf1('estimator-model', [5.])  # Estimator's input must be batched.
load_tf1('tf2-save', 5.)
load_tf1('keras-model', 5.)
Loading from saved-model-builder
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_21364/1548963983.py:5: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_21364/1548963983.py:5: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.saved_model.load` instead.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
Output with input 5.0 :  7.0
Loading from simple-save
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
Output with input 5.0 :  7.0
Loading from estimator-model
INFO:tensorflow:Restoring parameters from estimator-model/variables/variables
INFO:tensorflow:Restoring parameters from estimator-model/variables/variables
Output with input [5.0] :  [7.]
Loading from tf2-save
INFO:tensorflow:Restoring parameters from tf2-save/variables/variables
INFO:tensorflow:Restoring parameters from tf2-save/variables/variables
Output with input 5.0 :  7.0
Loading from keras-model
INFO:tensorflow:Restoring parameters from keras-model/variables/variables
INFO:tensorflow:Restoring parameters from keras-model/variables/variables
Output with input 5.0 :  7.0

TensorFlow 2: tf.saved_model で保存されたモデルを読み込む

TensorFlow 2 では、変数と関数を格納する Python オブジェクトにオブジェクトが読み込まれます。これは、TensorFlow 1 から保存されたモデルと互換性があります。

詳細については、SavedModel フォーマットの使用ガイドの tf.saved_model.load API ドキュメント、およびカスタムのモデルの読み込みと使用セクションを確認してください。

def load_tf2(path, input):
  print('Loading from', path)
  loaded = tf.saved_model.load(path)
  out = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY](
      tf.constant(input))['output']
  print('  Output with input', input, ': ', out)

load_tf2('saved-model-builder', 5.)
load_tf2('simple-save', 5.)
load_tf2('estimator-model', [5.])  # Estimator's input must be batched.
load_tf2('tf2-save', 5.)
load_tf2('keras-model', 5.)
Loading from saved-model-builder
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
Output with input 5.0 :  tf.Tensor(7.0, shape=(), dtype=float32)
Loading from simple-save
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
Output with input 5.0 :  tf.Tensor(7.0, shape=(), dtype=float32)
Loading from estimator-model
  Output with input [5.0] :  tf.Tensor([7.], shape=(1,), dtype=float32)
Loading from tf2-save
  Output with input 5.0 :  tf.Tensor(7.0, shape=(), dtype=float32)
Loading from keras-model
  Output with input 5.0 :  tf.Tensor(7.0, shape=(), dtype=float32)

TensorFlow 2 API で保存されたモデルは、(シグネチャとしてエクスポートされたものの代わりに)モデルに添付された tf.function と変数にもアクセスできます。例えば、次の例をご覧ください。

loaded = tf.saved_model.load('tf2-save')
print('restored __call__:', loaded.__call__)
print('output with input 5.', loaded(5))
restored __call__: <tensorflow.python.saved_model.function_deserialization.RestoredFunction object at 0x7ff95bb10820>
output with input 5. tf.Tensor(7.0, shape=(), dtype=float32)

TensorFlow 2: Keras で保存されたモデルを読み込む

廃止: Keras オブジェクトについては、新しい高レベルの .keras 形式と tf.keras.Model.export を使用することをお勧めします。これについては、こちらのガイドで説明されています。既存のコードについては、低レベルの SavedModel 形式が引き続きサポートされます。

The Keras loading API—tf.keras.models.load_model—allows you to reload a saved model back into a Keras Model object. Note that this only allows you to load SavedModels saved with Keras (Model.save or tf.keras.models.save_model).

tf.saved_model.save で保存されたモデルは、tf.saved_model.load でロードする必要があります。tf.saved_model.load を使用して Model.save で保存された Keras モデルを読み込めますが、TensorFlow グラフしか取得できません。詳細については、tf.keras.models.load_model API ドキュメントと Keras モデルの保存と読み込みに関するガイドを参照してください。

loaded_model = tf.keras.models.load_model('keras-model')
loaded_model.predict_on_batch(tf.constant([1, 3, 4]))
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
array([3., 5., 6.], dtype=float32)

GraphDef と MetaGraphDef

<a name="graphdef_and_metagraphdef">

未処理の GraphDef または MetaGraphDef を TF2 に読み込む簡単な方法はありません。ただし、TF1 コードを変換し、v1.wrap_function を使用して、グラフを TF2 concrete_function にインポートできます。

まず、MetaGraphDef を保存します。

# Save a simple multiplication computation:
with tf.Graph().as_default() as g:
  x = tf1.placeholder(tf.float32, shape=[], name='x')
  v = tf.Variable(3.0, name='v')
  y = tf.multiply(x, v, name='y')
  with tf1.Session() as sess:
    sess.run(v.initializer)
    print(sess.run(y, feed_dict={x: 5}))
    s = tf1.train.Saver()
    s.export_meta_graph('multiply.pb', as_text=True)
    s.save(sess, 'multiply_values.ckpt')
15.0

TF1 API を使用すると、tf1.train.import_meta_graph を使用してグラフをインポートし、値を復元できます。

with tf.Graph().as_default() as g:
  meta = tf1.train.import_meta_graph('multiply.pb')
  x = g.get_tensor_by_name('x:0')
  y = g.get_tensor_by_name('y:0')
  with tf1.Session() as sess:
    meta.restore(sess, 'multiply_values.ckpt')
    print(sess.run(y, feed_dict={x: 5}))
INFO:tensorflow:Restoring parameters from multiply_values.ckpt
INFO:tensorflow:Restoring parameters from multiply_values.ckpt
15.0

グラフを読み込むための TF2 API はありませんが、eager モードで実行できる具象関数にインポートできます。

def import_multiply():
  # Any graph-building code is allowed here.
  tf1.train.import_meta_graph('multiply.pb')

# Creates a tf.function with all the imported elements in the function graph.
wrapped_import = tf1.wrap_function(import_multiply, [])
import_graph = wrapped_import.graph
x = import_graph.get_tensor_by_name('x:0')
y = import_graph.get_tensor_by_name('y:0')

# Restore the variable values.
tf1.train.Saver(wrapped_import.variables).restore(
    sess=None, save_path='multiply_values.ckpt')

# Create a concrete function by pruning the wrap_function (similar to sess.run).
multiply_fn = wrapped_import.prune(feeds=x, fetches=y)

# Run this function
multiply_fn(tf.constant(5.))  # inputs to concrete functions must be Tensors.
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone.
INFO:tensorflow:Restoring parameters from multiply_values.ckpt
INFO:tensorflow:Restoring parameters from multiply_values.ckpt
<tf.Tensor: shape=(), dtype=float32, numpy=15.0>

TensorFlow 1 から TensorFlow 2 への変更

<a id="changes_from_tf1_to_tf2">

このセクションでは、TensorFlow 1 の重要な保存と読み込みの用語、TensorFlow 2 の同等の用語、および変更点を一覧表示します。

SavedModel

SavedModel は、パラメータと計算を含む完全な TensorFlow プログラムを保存する形式です。これには、サービスプラットフォームがモデルを実行するために使用するシグネチャが含まれています。

ファイル形式自体は大幅に変更されていないため、SavedModels は TensorFlow 1 または TensorFlow 2 API のいずれかを使用して読み込みおよび提供できます。

TensorFlow 1 と TensorFlow 2 の違い

サービス提供推論のユースケースは、API の変更を除けば、TensorFlow 2 では更新されていません。SavedModel から読み込まれたモデルを構成して再利用する機能が改善されました。

TensorFlow 2 では、プログラムは tf.Variabletf.Module、または高レベルの Keras モデル(tf.keras.Model)およびレイヤー(tf.keras.layers)などのオブジェクトによって表されます。セッションに保管された値を持つグローバル変数はなくなり、グラフは異なる tf.function に存在するようになりました。したがって、モデルのエクスポート中に、SavedModel は各コンポーネントと関数グラフを個別に保存します。

TensorFlow Python API を使用して TensorFlow プログラムを作成する場合、変数、関数、およびその他のリソースを管理するオブジェクトを作成する必要があります。通常、これは Keras API を使用して実現されますが、tf.Module を作成またはサブクラス化してオブジェクトを構築することもできます。

Keras モデル(tf.keras.Model)と tf.Module は、それらに接続された変数と関数を自動的に追跡します。SavedModel はモジュール、変数、および関数間のこれらの接続を保存し、読み込み時に復元できるようにします。

シグネチャ

シグニチャは、SavedModel のエンドポイントです。シグニチャは、モデルの実行方法と必要な入力をユーザーに伝えます。

TensorFlow 1 では、シグネチャは入力テンソルと出力テンソルをリストすることによって作成されます。TensorFlow 2 では、シグネチャは具象関数を渡すことによって生成されます。(TensorFlow 関数の詳細については、グラフと tf.function の概要ガイド、特に多態性: 1 つの関数、多数のグラフセクションを参照してください。)要するに、具象関数は tf.function から生成されます。

# Option 1: Specify an input signature.
@tf.function(input_signature=[...])
def fn(...):
  ...
  return outputs

tf.saved_model.save(model, path, signatures={
    'name': fn
})
# Option 2: Call `get_concrete_function`
@tf.function
def fn(...):
  ...
  return outputs

tf.saved_model.save(model, path, signatures={
    'name': fn.get_concrete_function(...)
})

Session.run

TensorFlow 1 では、テンソル名が分かっている限り、インポートされたグラフで Session.run を呼び出すことができました。これにより、復元された変数値を取得したり、シグネチャでエクスポートされなかったモデルの一部を実行したりできます。

TensorFlow 2 では、重み行列(kernel)などの変数に直接アクセスできます。

model = tf.Module()
model.dense_layer = tf.keras.layers.Dense(...)
tf.saved_model.save('my_saved_model')
loaded = tf.saved_model.load('my_saved_model')
loaded.dense_layer.kernel

または、モデルオブジェクトに接続された tf.function を呼び出します。例えば、loaded.__call__ です。

TF1 とは異なり、関数の一部を抽出して中間値にアクセスする方法はありません。保存されたオブジェクトで必要なすべての機能をエクスポートする必要があります

TensorFlow Serving の移行に関する注意事項

SavedModel はもともと TensorFlow Serving で動作するように作成されました。このプラットフォームは、分類、回帰、予測など、さまざまなタイプの予測リクエストを提供します。

TensorFlow 1 API を使用すると、ユーティリティで次のタイプのシグネチャを作成できます。

分類classification_signature_def)と回帰regression_signature_def)は入力と出力を制限するため、入力は tf.Example である必要があり、出力は classesscores、または prediction である必要があります。一方、 予測シグネチャpredict_signature_def)には制限がありません。

TensorFlow 2 API でエクスポートされた SavedModel は TensorFlow Serving と互換性がありますが、予測シグネチャのみが含まれます。分類および回帰シグネチャは削除されました。

分類および回帰シグネチャの使用が必要な場合は、エクスポートされた SavedModel を tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater を使用して変更できます。

次のステップ

TensorFlow 2 の SavedModel の詳細については、次のガイドをご覧ください。

  • SavedModel 形式の使用
  • Kerasモデルの保存と読み込み

TensorFlow Hub を使用している場合は、次のガイドが役立つことがあります。