モジュール、レイヤー、モデルの概要

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

TensorFlow で機械学習を実行するには、モデルを定義、保存、復元する必要があります。

モデルは、抽象的に次のように定義できます。

  • テンソルで何かを計算する関数(フォワードパス
  • トレーニングに応じて更新できる何らかの変数

このガイドでは、Keras 内で TensorFlow モデルがどのように定義されているか、そして、TensorFlow が変数とモデルを収集する方法、および、モデルを保存および復元する方法を説明します。

注意:今すぐ Keras を使用するのであれば、一連の Keras ガイドをご覧ください。

セットアップ

import tensorflow as tf
from datetime import datetime

%load_ext tensorboard
2022-12-14 20:34:59.326463: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 20:34:59.326565: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 20:34:59.326574: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

TensorFlow におけるモデルとレイヤーの定義

ほとんどのモデルはレイヤーで構成されています。レイヤーは、再利用およびトレーニング可能な変数を持つ既知の数学的構造を持つ関数です。TensorFlow では、Keras や Sonnet といった、レイヤーとモデルの高位実装の多くは、同じ基本クラスの tf.Module に基づいて構築されています。

スカラーテンソルで動作する非常に単純な tf.Module の例を次に示します。

class SimpleModule(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    self.a_variable = tf.Variable(5.0, name="train_me")
    self.non_trainable_variable = tf.Variable(5.0, trainable=False, name="do_not_train_me")
  def __call__(self, x):
    return self.a_variable * x + self.non_trainable_variable

simple_module = SimpleModule(name="simple")

simple_module(tf.constant(5.0))
<tf.Tensor: shape=(), dtype=float32, numpy=30.0>

モジュールと(その延長としての)レイヤーは、「オブジェクト」のディープラーニング用語です。これらには、内部状態と、その状態を使用するメソッドがあります。

__ call__Python コーラブルのように動作する以外何も特別なことではないため、任意の関数を使用してモデルを呼び出すことができます。

ファインチューニング中のレイヤーと変数を凍結するなど、様々な理由で、変数をトレーニング対象とするかどうかを設定することができます。

注意: tf.Moduletf.keras.layers.Layertf.keras.Model の基本クラスであるため、ここに説明されているすべての内容は Keras にも当てはまります。過去の互換性の理由から、Keras レイヤーはモジュールから変数を収集しないため、モデルはモジュールのみ、または Keras レイヤーのみを使用する必要があります。ただし、以下に示す変数の検査方法はどちらの場合も同じです。

tf.Module をサブクラス化することにより、このオブジェクトのプロパティに割り当てられた tf.Variable または tf.Module インスタンスが自動的に収集されます。これにより、変数の保存や読み込みのほか、tf.Module のコレクションを作成することができます。

# All trainable variables
print("trainable variables:", simple_module.trainable_variables)
# Every variable
print("all variables:", simple_module.variables)
trainable variables: (<tf.Variable 'train_me:0' shape=() dtype=float32, numpy=5.0>,)
all variables: (<tf.Variable 'train_me:0' shape=() dtype=float32, numpy=5.0>, <tf.Variable 'do_not_train_me:0' shape=() dtype=float32, numpy=5.0>)

これは、モジュールで構成された 2 層線形レイヤーモデルの例です。

最初の高密度(線形)レイヤーは以下のとおりです。

class Dense(tf.Module):
  def __init__(self, in_features, out_features, name=None):
    super().__init__(name=name)
    self.w = tf.Variable(
      tf.random.normal([in_features, out_features]), name='w')
    self.b = tf.Variable(tf.zeros([out_features]), name='b')
  def __call__(self, x):
    y = tf.matmul(x, self.w) + self.b
    return tf.nn.relu(y)

2 つのレイヤーインスタンスを作成して適用する完全なモデルは以下のとおりです。

class SequentialModule(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)

    self.dense_1 = Dense(in_features=3, out_features=3)
    self.dense_2 = Dense(in_features=3, out_features=2)

  def __call__(self, x):
    x = self.dense_1(x)
    return self.dense_2(x)

# You have made a model!
my_model = SequentialModule(name="the_model")

# Call it, with random results
print("Model results:", my_model(tf.constant([[2.0, 2.0, 2.0]])))
Model results: tf.Tensor([[0. 0.]], shape=(1, 2), dtype=float32)

tf.Module インスタンスは、それに割り当てられた tf.Variable または tf.Module インスタンスを再帰的に自動収集します。これにより、単一のモデルインスタンスで tf.Module のコレクションを管理し、モデル全体を保存して読み込むことができます。

print("Submodules:", my_model.submodules)
Submodules: (<__main__.Dense object at 0x7f322f119100>, <__main__.Dense object at 0x7f322f119cd0>)
for var in my_model.variables:
  print(var, "\n")
<tf.Variable 'b:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)> 

<tf.Variable 'w:0' shape=(3, 3) dtype=float32, numpy=
array([[-2.175897  , -0.32256535, -0.44930542],
       [ 0.69492286,  0.7950312 ,  0.08173124],
       [ 0.7834185 ,  1.1647267 , -0.10429258]], dtype=float32)> 

<tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)> 

<tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=
array([[ 0.34717873, -1.4900515 ],
       [-0.8630259 , -0.47714272],
       [-1.2313994 ,  0.2546531 ]], dtype=float32)>

変数の作成を延期する

ここで、レイヤーへの入力サイズと出力サイズの両方を定義する必要があることに気付いたかもしれません。これは、w 変数が既知の形状を持ち、割り当てることができるようにするためです。

モジュールが特定の入力形状で最初に呼び出されるまで変数の作成を延期することにより、入力サイズを事前に指定する必要がありません。

class FlexibleDenseModule(tf.Module):
  # Note: No need for `in_features`
  def __init__(self, out_features, name=None):
    super().__init__(name=name)
    self.is_built = False
    self.out_features = out_features

  def __call__(self, x):
    # Create variables on first call.
    if not self.is_built:
      self.w = tf.Variable(
        tf.random.normal([x.shape[-1], self.out_features]), name='w')
      self.b = tf.Variable(tf.zeros([self.out_features]), name='b')
      self.is_built = True

    y = tf.matmul(x, self.w) + self.b
    return tf.nn.relu(y)
# Used in a module
class MySequentialModule(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)

    self.dense_1 = FlexibleDenseModule(out_features=3)
    self.dense_2 = FlexibleDenseModule(out_features=2)

  def __call__(self, x):
    x = self.dense_1(x)
    return self.dense_2(x)

my_model = MySequentialModule(name="the_model")
print("Model results:", my_model(tf.constant([[2.0, 2.0, 2.0]])))
Model results: tf.Tensor([[4.012148 0.      ]], shape=(1, 2), dtype=float32)

この柔軟性のため、多くの場合、TensorFlow レイヤーは、出力の形状(tf.keras.layers.Dense)などを指定するだけで済みます。入出力サイズの両方を指定する必要はありません。

重みを保存する

tf.ModuleチェックポイントSavedModel の両方として保存できます。

チェックポイントは単なる重み(モジュールとそのサブモジュール内の変数のセットの値)です。

chkp_path = "my_checkpoint"
checkpoint = tf.train.Checkpoint(model=my_model)
checkpoint.write(chkp_path)
'my_checkpoint'

チェックポイントは、データ自体とメタデータのインデックスファイルの 2 種類のファイルで構成されます。インデックスファイルは、実際に保存されているものとチェックポイントの番号を追跡し、チェックポイントデータには変数値とその属性ルックアップパスが含まれています。

ls my_checkpoint*
my_checkpoint.data-00000-of-00001  my_checkpoint.index

チェックポイントの内部を調べると、変数のコレクション全体が保存されており、変数を含む Python オブジェクト別に並べ替えられていることを確認できます。

tf.train.list_variables(chkp_path)
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('model/dense_1/b/.ATTRIBUTES/VARIABLE_VALUE', [3]),
 ('model/dense_1/w/.ATTRIBUTES/VARIABLE_VALUE', [3, 3]),
 ('model/dense_2/b/.ATTRIBUTES/VARIABLE_VALUE', [2]),
 ('model/dense_2/w/.ATTRIBUTES/VARIABLE_VALUE', [3, 2])]

分散(マルチマシン)トレーニング中にシャーディングされる可能性があるため、番号が付けられています(「00000-of-00001」など)。ただし、この例の場合、シャードは 1 つしかありません。

モデルを再度読み込むと、Python オブジェクトの値が上書きされます。

new_model = MySequentialModule()
new_checkpoint = tf.train.Checkpoint(model=new_model)
new_checkpoint.restore("my_checkpoint")

# Should be the same result as above
new_model(tf.constant([[2.0, 2.0, 2.0]]))
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[4.012148, 0.      ]], dtype=float32)>

注意: チェックポイントは長いトレーニングワークフローでは重要であり、tf.checkpoint.CheckpointManager はヘルパークラスとして、チェックポイント管理を大幅に簡単にすることができます。詳細については、トレーニングチェックポイントガイドをご覧ください。

関数の保存

TensorFlow は、TensorFlow ServingTensorFlow Lite で見たように、元の Python オブジェクトなしでモデルを実行できます。また、TensorFlow Hub からトレーニング済みのモデルをダウンロードした場合でも同じです。

TensorFlow は、Pythonで説明されている計算の実行方法を認識する必要がありますが、元のコードは必要ありません。認識させるには、グラフを作成することができます。これについてはグラフと関数の入門ガイドをご覧ください。

このグラフには、関数を実装する演算が含まれています。

@tf.function デコレータを追加して、このコードをグラフとして実行する必要があることを示すことにより、上記のモデルでグラフを定義できます。

class MySequentialModule(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)

    self.dense_1 = Dense(in_features=3, out_features=3)
    self.dense_2 = Dense(in_features=3, out_features=2)

  @tf.function
  def __call__(self, x):
    x = self.dense_1(x)
    return self.dense_2(x)

# You have made a model with a graph!
my_model = MySequentialModule(name="the_model")

作成したモジュールは、前と全く同じように動作します。関数に渡される一意のシグネチャごとにグラフが作成されます。詳細については、グラフと関数の基礎ガイドをご覧ください。

print(my_model([[2.0, 2.0, 2.0]]))
print(my_model([[[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]))
tf.Tensor([[0. 0.]], shape=(1, 2), dtype=float32)
tf.Tensor(
[[[0. 0.]
  [0. 0.]]], shape=(1, 2, 2), dtype=float32)

TensorBoard のサマリー内でグラフをトレースすると、グラフを視覚化できます。

# Set up logging.
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = "logs/func/%s" % stamp
writer = tf.summary.create_file_writer(logdir)

# Create a new model to get a fresh trace
# Otherwise the summary will not see the graph.
new_model = MySequentialModule()

# Bracket the function call with
# tf.summary.trace_on() and tf.summary.trace_export().
tf.summary.trace_on(graph=True)
tf.profiler.experimental.start(logdir)
# Call only one tf.function when tracing.
z = print(new_model(tf.constant([[2.0, 2.0, 2.0]])))
with writer.as_default():
  tf.summary.trace_export(
      name="my_func_trace",
      step=0,
      profiler_outdir=logdir)
2022-12-14 20:35:04.445899: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcupti.so.11.2'; dlerror: libcupti.so.11.2: cannot open shared object file: No such file or directory
tf.Tensor([[1.5583546 0.9466069]], shape=(1, 2), dtype=float32)

TensorBoard を起動して、トレースの結果を確認します。

#docs_infra: no_execute
%tensorboard --logdir logs/func

A screenshot of the graph, in tensorboard

SavedModel の作成

トレーニングが完了したモデルを共有するには、SavedModel の使用が推奨されます。SavedModel には関数のコレクションと重みのコレクションの両方が含まれています。

次のようにして、トレーニングしたモデルを保存することができます。

tf.saved_model.save(my_model, "the_saved_model")
INFO:tensorflow:Assets written to: the_saved_model/assets
# Inspect the SavedModel in the directory
ls -l the_saved_model
total 28
drwxr-sr-x 2 kbuilder kokoro  4096 Dec 14 20:35 assets
-rw-rw-r-- 1 kbuilder kokoro    55 Dec 14 20:35 fingerprint.pb
-rw-rw-r-- 1 kbuilder kokoro 14672 Dec 14 20:35 saved_model.pb
drwxr-sr-x 2 kbuilder kokoro  4096 Dec 14 20:35 variables
# The variables/ directory contains a checkpoint of the variables
ls -l the_saved_model/variables
total 8
-rw-rw-r-- 1 kbuilder kokoro 490 Dec 14 20:35 variables.data-00000-of-00001
-rw-rw-r-- 1 kbuilder kokoro 356 Dec 14 20:35 variables.index

saved_model.pb ファイルは、関数型の tf.Graph を記述するプロトコルバッファです。

モデルとレイヤーは、それを作成したクラスのインスタンスを実際に作成しなくても、この表現から読み込めます。これは、大規模なサービスやエッジデバイスでのサービスなど、Python インタープリタがない(または使用しない)場合や、元の Python コードが利用できないか実用的でない場合に有用です。

モデルを新しいオブジェクトとして読み込みます。

new_model = tf.saved_model.load("the_saved_model")

保存したモデルを読み込んで作成された new_model は、クラスを認識しない内部の TensorFlow ユーザーオブジェクトです。SequentialModule ではありません。

isinstance(new_model, SequentialModule)
False

この新しいモデルは、すでに定義されている入力シグネチャで機能します。このように復元されたモデルにシグネチャを追加することはできません。

print(my_model([[2.0, 2.0, 2.0]]))
print(my_model([[[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]))
tf.Tensor([[0. 0.]], shape=(1, 2), dtype=float32)
tf.Tensor(
[[[0. 0.]
  [0. 0.]]], shape=(1, 2, 2), dtype=float32)

したがって、SavedModel を使用すると、tf.Module を使用して TensorFlow の重みとグラフを保存し、それらを再度読み込むことができます。

Keras モデルとレイヤー

ここまでは、Keras に触れずに説明してきましたが、tf.Module の上に独自の高位 API を構築することは可能です。

このセクションでは、Keras が tf.Module をどのように使用するかを説明します。Keras モデルの完全なユーザーガイドは、Keras ガイドをご覧ください。

Keras レイヤー

tf.keras.layers.Layer はすべての Keras レイヤーの基本クラスであり、tf.Module から継承します。

親を交換してから、__call__call に変更するだけで、モジュールを Keras レイヤーに変換できます。

class MyDense(tf.keras.layers.Layer):
  # Adding **kwargs to support base Keras layer arguments
  def __init__(self, in_features, out_features, **kwargs):
    super().__init__(**kwargs)

    # This will soon move to the build step; see below
    self.w = tf.Variable(
      tf.random.normal([in_features, out_features]), name='w')
    self.b = tf.Variable(tf.zeros([out_features]), name='b')
  def call(self, x):
    y = tf.matmul(x, self.w) + self.b
    return tf.nn.relu(y)

simple_layer = MyDense(name="simple", in_features=3, out_features=3)

Keras レイヤーには独自の __call__ があり、次のセクションで説明する手順を実行してから、call() を呼び出します。動作には違いはありません。

simple_layer([[2.0, 2.0, 2.0]])
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[0.        , 2.654318  , 0.91793084]], dtype=float32)>

build ステップ

前述のように、多くの場合都合よく、入力形状が確定するまで変数の作成を延期できます。

Keras レイヤーには追加のライフサイクルステップがあり、レイヤーをより柔軟に定義することができます。このステップは、build() 関数で定義されます。

build は 1 回だけ呼び出され、入力形状で呼び出されます。通常、変数(重み)を作成するために使用されます。

上記の MyDense レイヤーを、入力のサイズに柔軟に合わせられるように書き換えることができます。

class FlexibleDense(tf.keras.layers.Layer):
  # Note the added `**kwargs`, as Keras supports many arguments
  def __init__(self, out_features, **kwargs):
    super().__init__(**kwargs)
    self.out_features = out_features

  def build(self, input_shape):  # Create the state of the layer (weights)
    self.w = tf.Variable(
      tf.random.normal([input_shape[-1], self.out_features]), name='w')
    self.b = tf.Variable(tf.zeros([self.out_features]), name='b')

  def call(self, inputs):  # Defines the computation from inputs to outputs
    return tf.matmul(inputs, self.w) + self.b

# Create the instance of the layer
flexible_dense = FlexibleDense(out_features=3)

この時点では、モデルは構築されていないため、変数も存在しません。

flexible_dense.variables
[]

関数を呼び出すと、適切なサイズの変数が割り当てられます。

# Call it, with predictably random results
print("Model results:", flexible_dense(tf.constant([[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])))
Model results: tf.Tensor(
[[ 3.4541562   0.1854068  -6.500838  ]
 [ 5.1812344   0.27811027 -9.751257  ]], shape=(2, 3), dtype=float32)
flexible_dense.variables
[<tf.Variable 'flexible_dense/w:0' shape=(3, 3) dtype=float32, numpy=
 array([[ 0.51052654, -0.95449483, -0.72358954],
        [ 1.2890267 ,  1.3898759 , -0.31284416],
        [-0.07247515, -0.34267765, -2.2139852 ]], dtype=float32)>,
 <tf.Variable 'flexible_dense/b:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]

buildは 1 回しか呼び出されないため、入力形状がレイヤーの変数と互換性がない場合、入力は拒否されます。

try:
  print("Model results:", flexible_dense(tf.constant([[2.0, 2.0, 2.0, 2.0]])))
except tf.errors.InvalidArgumentError as e:
  print("Failed:", e)
Failed: Exception encountered when calling layer 'flexible_dense' (type FlexibleDense).

{ {function_node __wrapped__MatMul_device_/job:localhost/replica:0/task:0/device:GPU:0} } Matrix size-incompatible: In[0]: [1,4], In[1]: [3,3] [Op:MatMul]

Call arguments received by layer 'flexible_dense' (type FlexibleDense):
  • inputs=tf.Tensor(shape=(1, 4), dtype=float32)

Keras レイヤーには、次のような多くの追加機能があります。

  • オプションの損失
  • メトリクスのサポート
  • トレーニングと推論の使用を区別する、オプションの training 引数の組み込みサポート
  • Python でモデルのクローンを作成するための構成を正確に保存する get_configfrom_config メソッド

詳細は、カスタムレイヤーとモデルに関する完全ガイドをご覧ください。

Keras モデル

モデルはネストされた Keras レイヤーとして定義できます。

ただし、Keras は tf.keras.Model と呼ばれるフル機能のモデルクラスも提供します。Keras モデルは tf.keras.layers.Layer を継承しているため、 Keras レイヤーと同じ方法で使用、ネスト、保存することができます。Keras モデルには、トレーニング、評価、読み込み、保存、および複数のマシンでのトレーニングを容易にする追加機能があります。

上記の SequentialModule をほぼ同じコードで定義できます。先ほどと同じように、__call__call() に変換して、親を変更します。

class MySequentialModel(tf.keras.Model):
  def __init__(self, name=None, **kwargs):
    super().__init__(**kwargs)

    self.dense_1 = FlexibleDense(out_features=3)
    self.dense_2 = FlexibleDense(out_features=2)
  def call(self, x):
    x = self.dense_1(x)
    return self.dense_2(x)

# You have made a Keras model!
my_sequential_model = MySequentialModel(name="the_model")

# Call it on a tensor, with random results
print("Model results:", my_sequential_model(tf.constant([[2.0, 2.0, 2.0]])))
Model results: tf.Tensor([[ 0.6988172 -0.0814582]], shape=(1, 2), dtype=float32)

追跡変数やサブモジュールなど、すべて同じ機能を利用できます。

注意: 上記の「注意」を繰り返すと、Keras レイヤーまたはモデル内にネストされた生の tf.Module は、トレーニングまたは保存のために変数を収集しません。代わりに、Keras レイヤーを Keras レイヤーの内側にネストします。

my_sequential_model.variables
[<tf.Variable 'my_sequential_model/flexible_dense_1/w:0' shape=(3, 3) dtype=float32, numpy=
 array([[ 0.75197715,  0.21306331,  0.9158741 ],
        [-0.8419782 ,  0.32473782, -0.4771087 ],
        [-1.0678096 ,  0.00764688, -0.27426288]], dtype=float32)>,
 <tf.Variable 'my_sequential_model/flexible_dense_1/b:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,
 <tf.Variable 'my_sequential_model/flexible_dense_2/w:0' shape=(3, 2) dtype=float32, numpy=
 array([[-0.02459872, -0.18258274],
        [ 0.49242234, -0.660783  ],
        [ 0.31815398,  0.6583327 ]], dtype=float32)>,
 <tf.Variable 'my_sequential_model/flexible_dense_2/b:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>]
my_sequential_model.submodules
(<__main__.FlexibleDense at 0x7f32db23f1f0>,
 <__main__.FlexibleDense at 0x7f32dbc0dbe0>)

非常に Python 的なアプローチとして、tf.keras.Model をオーバーライドして TensorFlow モデルを構築することができます。ほかのフレームワークからモデルを移行する場合、これは非常に簡単な方法です。

モデルが既存のレイヤーと入力の単純な集合として構築されている場合は、モデルの再構築とアーキテクチャに関する追加機能を備えた Functional API を使用すると手間とスペースを節約できます。

以下は、Functional API を使用した同じモデルです。

inputs = tf.keras.Input(shape=[3,])

x = FlexibleDense(3)(inputs)
x = FlexibleDense(2)(x)

my_functional_model = tf.keras.Model(inputs=inputs, outputs=x)

my_functional_model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 3)]               0         
                                                                 
 flexible_dense_3 (FlexibleD  (None, 3)                12        
 ense)                                                           
                                                                 
 flexible_dense_4 (FlexibleD  (None, 2)                8         
 ense)                                                           
                                                                 
=================================================================
Total params: 20
Trainable params: 20
Non-trainable params: 0
_________________________________________________________________
my_functional_model(tf.constant([[2.0, 2.0, 2.0]]))
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[-2.1804159, -2.2475119]], dtype=float32)>

ここでの主な違いは、入力形状が関数構築プロセスの一部として事前に指定されることです。この場合、input_shape 引数を完全に指定する必要がないため、一部の次元を None のままにしておくことができます。

注意:サブクラス化されたモデルでは、input_shapeInputLayer を指定する必要はありません。これらの引数とレイヤーは無視されます。

Keras モデルの保存

Keras モデルでは tf.Moduleと同じようにチェックポイントを設定できます。

Keras モデルはモジュールであるため、tf.saved_models.save() を使用して保存することもできます。ただし、Keras モデルには便利なメソッドやその他の機能があります。

my_sequential_model.save("exname_of_file")
INFO:tensorflow:Assets written to: exname_of_file/assets

このように簡単に、読み込み直すことができます。

reconstructed_model = tf.keras.models.load_model("exname_of_file")
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.

また、Keras SavedModel は、メトリクス、損失、およびオプティマイザの状態も保存します。

再構築されたこのモデルを使用すると、同じデータで呼び出されたときと同じ結果が得られます。

reconstructed_model(tf.constant([[2.0, 2.0, 2.0]]))
<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[ 0.6988172, -0.0814582]], dtype=float32)>

機能サポートのためにカスタムレイヤーで使用できる構成メソッドなど、Keras モデルの保存とシリアル化についてのその他の詳細情報は、保存とシリアル化のガイドをご覧ください。

次のステップ

Keras の詳細については、こちらから既存の Keras ガイドをご覧ください。

tf.module 上に構築された高位 API の例には、DeepMind の Sonnet も利用できます。詳細についてはウェブサイトをご覧ください。