TensorFlow 2.0 モデルを TensorFlow Lite に変換するには、モデルを具象関数 (concrete function) としてエクスポートする必要があります。 このドキュメントでは、具象関数とは何か、既存のモデルからどのように具象関数を生成するか、について概説します。
背景
TensorFlow 2.0 では、Eager Execution がデフォルトでオンになっています。 TensorFlow において Eager Execution とは、グラフを作成せずにオペレーションを即時評価する命令型プログラミング環境のことです。各オペレーションは、後で実行するために計算グラフを構築するのではなく、具体的な値を返します。 Eager Execution に関する詳細なガイドはこちらにあります。
Eager Execution で命令的に実行すると開発とデバッグがより対話的になりますが、デバイスへのデプロイはできなくなります。 tf.function
API はモデルをグラフとして保存することを可能にします。 これは TensorFlow2.0 で TensorFlow Lite を実行するために必要です。 tf.function
デコレータでラップされたオペレーションはすべてグラフとしてエクスポートできるので、 TensorFlow Lite FlatBuffer フォーマットに変換できます。
用語
この文書では次の用語を使用します。
- シグネチャ - 一連のオペレーションの入力と出力.
- 具象関数 - 単一のシグネチャを持つグラフ.
- 多相関数 - いくつかの具象関数グラフを1つの関数内にカプセル化した Python の呼び出し可能オブジェクト.
方法論
この章では、具象関数をエクスポートする方法を解説します。
関数に tf.function
デコーレータを付与する
関数に tf.function
デコレータを付与すると、その関数のオペレーションを含む 多相関数 が生成されます。 tf.function
のデコレータが付けられていないオペレーションはすべて Eager Execution で評価されます。 以下の例は tf.function
の使い方を示しています。
@tf.function
def pow(x):
return x ** 2
tf.function(lambda x : x ** 2)
保存したいオブジェクトを生成する
tf.function
は、tf.Module
オブジェクトの一部として保存することもできます。 その際、変数は tf.Module
内で一度だけ定義されるべきです。 以下の例は Checkpoint
を派生するクラスを作成するための2つの異なるアプローチを示しています。
class BasicModel(tf.Module):
def __init__(self):
self.const = None
@tf.function
def pow(self, x):
if self.const is None:
self.const = tf.Variable(2.)
return x ** self.const
root = BasicModel()
root = tf.Module()
root.const = tf.Variable(2.)
root.pow = tf.function(lambda x : x ** root.const)
具象関数をエクスポートする
具象関数は、TensorFlow Lite モデルに変換したり、SavedModel にエクスポートしたりできるグラフを定義します。 多相関数から具象関数をエクスポートするには、シグネチャを定義する必要があります。 シグネチャは次のようにして定義できます。
tf.function
にinput_signature
パラメータを定義します。tf.TensorSpec
をget_concrete_function
に渡します: 例)tf.TensorSpec(shape=[1], dtype = tf.float32)
- サンプルの入力テンソルを
get_concrete_function
に渡します: 例)tf.constant(1., shape=[1])
次の例は tf.function
の input_signature
パラメータを定義する方法を示しています。
class BasicModel(tf.Module):
def __init__(self):
self.const = None
@tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.float32)])
def pow(self, x):
if self.const is None:
self.const = tf.Variable(2.)
return x ** self.const
# tf.Module オブジェクトを生成
root = BasicModel()
# 具象関数を取得
concrete_func = root.pow.get_concrete_function()
以下の例では、サンプルの入力テンソルを get_concrete_function
に渡しています。
# tf.Module オブジェクトを生成
root = tf.Module()
root.const = tf.Variable(2.)
root.pow = tf.function(lambda x : x ** root.const)
# 具象関数を取得
input_data = tf.constant(1., shape=[1])
concrete_func = root.pow.get_concrete_function(input_data)
プログラム例
import tensorflow as tf
# tf.Module オブジェクトを初期化
root = tf.Module()
# 変数を一度だけインスタンス化する
root.var = None
# 演算が事前に計算されないように関数を定義
@tf.function
def exported_function(x):
# 変数は一度だけ定義できます。変数は関数内で定義できますが、関数外の参照を含める必要があります。
if root.var is None:
root.var = tf.Variable(tf.random.uniform([2, 2]))
root.const = tf.constant([[37.0, -23.0], [1.0, 4.0]])
root.mult = tf.matmul(root.const, root.var)
return root.mult * x
# tf.Moduleオブジェクトの一部として関数を保存
root.func = exported_function
# 具象関数を取得
concrete_func = root.func.get_concrete_function(
tf.TensorSpec([1, 1], tf.float32))
よくある質問
具象関数を SavedModel として保存するにはどうすればいいですか?
TensorFlow Lite に変換する前に TensorFlow モデルを保存したい場合は、 SavedModel として保存する必要があります。
具象関数を取得したあとで tf.saved_model.save
を呼び出すことでモデルを保存できます。上述の例の場合、以下のようにして保存できます。
tf.saved_model.save(root, export_dir, concrete_func)
SavedModel の使用方法の詳細については、SavedModel ガイド を参照してください。
SavedModel から具象関数を得るにはどうすればいいですか?
SavedModel 内の各象徴関数は、シグネチャキーによって識別できます。
デフォルトのシグネチャキーは tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
です。以下の例は、モデルから具象関数を取得する方法を示しています。
model = tf.saved_model.load(export_dir)
concrete_func = model.signatures[
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
tf.Keras
モデルから具象関数を得るにはどうしたらいいですか?
方法が2つあります:
- モデルを SavedModel として保存します。保存処理中に具象関数が生成されるので、上述の要領でモデルをロードして具象関数を取得することができます。
- 下記のようにモデルを
tf.function
でラップします。
model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(x=[-1, 0, 1, 2, 3, 4], y=[-3, -1, 1, 3, 5, 7], epochs=50)
# 具象関数を Keras モデルから取得
run_model = tf.function(lambda x : model(x))
# 具象関数を保存
concrete_func = run_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))