![]() |
![]() |
![]() |
![]() |
TensorFlow 2 の Eager execution はデフォルトで有効になっています。ユーザーインターフェースは直感的で柔軟性に優れていますが(一度限りの演算の実行ははるかに簡単で高速に行われます)、パフォーマンスとデプロイ能力に影響がでることがあります。
プログラムからグラフを作成するには、tf.function
を使用できます。変換ツールで Python コードから Python に依存しないデータフローグラフを作成するため、パフォーマンスと移植性に優れたモデルを作成できます。また、SavedModel
を使用する際に必要となります。
このチュートリアルでは tf.function
と AutoGraph の基本的な特徴についてひととおり確認します。
主に次の内容と推奨事項について説明しています。
- Eager モードでデバッグしてから、
@tf.function
でデコレートする。 - オブジェクトミューテーションまたはリストの追加といった Python 側の効果に依存しないこと。
tf.function
は TensorFlow 演算子と最も相性が良く、NumPy と Python 呼び出しは定数に変換される。
セットアップ
# Update TensorFlow, as this notebook requires version 2.9 or later
!pip install -q -U tensorflow>=2.9.0
import tensorflow as tf
2022-12-14 20:47:44.526294: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:47:44.526430: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:47:44.526441: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
発生する可能性のあるエラーの種類を示すヘルパー関数を定義します。
import traceback
import contextlib
# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
try:
yield
except error_class as e:
print('Caught expected exception \n {}:'.format(error_class))
traceback.print_exc(limit=2)
except Exception as e:
raise e
else:
raise Exception('Expected {} to be raised but no error was raised!'.format(
error_class))
基礎
使い方
定義する Function
(@tf.function
デコレーターを適用するなどして)は、コアの TensorFlow 演算とまったく変わりません。Eager での実行や勾配の計算などを行えます。
@tf.function # The decorator converts `add` into a `Function`.
def add(a, b):
return a + b
add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy= array([[2., 2.], [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
Function
をほかの Function
内で使用できます。
@tf.function
def dense_layer(x, w, b):
return add(tf.matmul(x, w), b)
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy= array([[3., 3.], [3., 3.], [3., 3.]], dtype=float32)>
Function
は、特に小さな演算が多数含まれるグラフでは、Eager コードよりも高速に実行されることがありますが、高価な演算がいくつか含まれるグラフ(畳み込みなど)では、速度の差はあまり見られません。
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
return conv_layer(image)
image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.005778620000000956 Function conv: 0.005597081999894726 Note how there's not much difference in performance for convolutions
トレーシング
このセクションは、Function
の内部動作や実装の詳細を説明します。将来的に変更する可能性がありますが、いつなぜトレーシングが発生するのかを理解しておけば、tf.function
を効果的に使用しやすくなります。
「トレーシング」とは?
Function
は TensorFlow Graph でプログラムを実行しますが、tf.Graph
は、Eager TensorFlow プログラムにユーザーが記述するすべてのものを表現することはできません。たとえば、Python はポリモーフィズムをサポートしていますが、tf.Graph
では、その入力に特定のデータ型と次元が必要です。またはコマンドラインの引数を読み取る、エラーを発生させる、より複雑な Python オブジェクトを扱うといったサイドタスクを実施しようとしても、どれも tf.Graph
で実行することはできません。
Function
はコードを 2 つの段階に分けることで、このギャップの橋渡しの役割を果たします。
「トレーシング」と呼ばれる第 1 段階において、
Function
は新しいtf.Graph
を作成します。Python コードは通常通り実行しますが、すべての TensorFlow 演算(2 つのテンソルを加算するなど)は 据え置きとなります。これらはtf.Graph
にとらわれるため、実行しません。第 2 段階では、最初の段階で据え置きとなったすべての演算を含む
tf.Graph
が実行されます。この段階は、トレーシングの段階よりもはるかに高速に行われます。
Function
は、その入力によっては必ずしも最初の段階で呼び出されたときに実行するわけではありません。この判定がどのように行われるのかについては、以下の「トレーシングの規則」をご覧ください。最初の段階を省略して 2 番目の段階のみを実行できれば、TensorFlow の高いパフォーマンスが発揮されます。
Function
がトレーシングしないと判断した場合、トレーシング段階の直後に 第 2 段階が始まるため、Function
を呼び出すと、tf.Graph
の作成と実行が行われます。後の方で、get_concrete_function
を使ってトレーシング段階のみを実行する方法を説明します。
型の異なる引数を Function
に渡すと、両方の段階が実行されます。
@tf.function
def double(a):
print("Tracing with", a)
return a + a
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32) Tracing with Tensor("a:0", shape=(), dtype=float32) tf.Tensor(2.2, shape=(), dtype=float32) Tracing with Tensor("a:0", shape=(), dtype=string) tf.Tensor(b'aa', shape=(), dtype=string)
同じ型の引数で Function
を繰り返し呼び出すと、生成されるグラフはまったく同じになるため、TensorFlow はトレーシング段階を省略して前にトレーシングしたグラフを再利用することに注意してください。
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)
すべての利用可能なトレースを確認するには、pretty_printed_concrete_signatures()
を使用できます。
print(double.pretty_printed_concrete_signatures())
double(a) Args: a: int32 Tensor, shape=() Returns: int32 Tensor, shape=() double(a) Args: a: float32 Tensor, shape=() Returns: float32 Tensor, shape=() double(a) Args: a: string Tensor, shape=() Returns: string Tensor, shape=()
ここまで、tf.function
が TensorFlow のグラフトレーシングロジックにキャッシュされた動的ディスパッチレイヤーを作成するのを見てきました。用語についてより具体的に説明すると、次のように言えます。
tf.Graph
は、言語に依存しない、生の移植可能な TensorFlow 計算の表現です。ConcreteFunction
はtf.Graph
をラップします。Function
はConcreteFunction
のキャッシュを管理し、入力に適したものを選択します。tf.function
は Python 関数をラップし、Function
オブジェクトを返します。- トレーシングは
tf.Graph
を作成し、それをConcreteFunction
(またはトレース)をラップします。
トレーシングの規則
Function
が呼び出されると、各引数の tf.types.experimental.TraceType
を使用して呼び出し引数を既存の ConcreteFunction
に一致させます。一致する ConcreteFunction
が見つかった場合、呼び出しはそれにディスパッチされます。一致するものが見つからない場合、新しい ConcreteFunction
がトレースされます。
複数の一致が見つかった場合は、最も具体的なシグネチャが選択されます。マッチングは、たとえば C++ や Java での通常の関数呼び出しと同じように、サブタイプ化によって行われます。例えば、TensorShape([1, 2])
は TensorShape([None, None])
のサブタイプ化であるため、TensorShape([1, 2])
を使用した tf.function への呼び出しは、TensorShape([None, None])
で生成された ConcreteFunction
にディスパッチできます。しかし、TensorShape([1, None])
を持つ ConcreteFunction
も存在する場合は、より具体的であるため優先されます。
TraceType
は、次のように入力引数から決定されます。
Tensor
の場合、型はTensor
のdtype
とshape
によってパラメータ化されます。階数付けされた形状は、階数付けされていない形状のサブタイプです。固定次元は未知次元のサブタイプですVariable
の場合、型はTensor
に似ていますが、変数の一意のリソース ID も含まれています。これは、制御の依存関係を正しく設定するために必要です。- Python プリミティブ値の場合、型は値自体に対応します。たとえば、値
3
のTraceType
は、int
ではなくLiteralTraceType<3>
です。 list
やtuple
などの Python の順序付きコンテナの場合、型はそれらの要素の型によってパラメータ化されます。たとえば、[1, 2]
の型はListTraceType<LiteralTraceType<1>, LiteralTraceType<2>>
であり、[2, 1]
の型はListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>
であり、異なります。dict
などの Python マッピングの場合、型も同じキーからのマッピングですが、実際の値ではなく値の型へのマッピングです。たとえば、{1: 2, 3: 4}
の型はMappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3, LiteralTraceType<4>>>>
です。ただし、順序付きコンテナとは異なり、{1: 2, 3: 4}
と{3: 4, 1: 2}
の型は同等です。__tf_tracing_type__
メソッドを実装する Python オブジェクトの場合、型はそのメソッドが返すものです- 他のすべての Python オブジェクトの場合、型はオブジェクトの Python の等価性とハッシュを使用して照合するジェネリック
TraceType
です。(注意: オブジェクトへの weakref に依存しているため、オブジェクトがスコープ内にある/削除されていない場合にのみ機能します。)
注意: TraceType
は Function
入力パラメータに基づいているため、グローバル変数と自由変数を変更するだけでは、新しいトレースは作成されません。Python のグローバル変数と自由変数を扱う際の推奨される方法については、こちらのセクションをご覧ください。
リトレーシングの制御
リトレーシングは、Function
が 2 つ以上のトレースを作成する際に発生します。これは、TensorFlow が一連の入力ごとに正しいグラフを生成する上で役立ちますが、トレーシングは高価な演算です!Function
が呼び出しごとに新しいグラフをリトレーシングすると、コードの実行は tf.function
を使用しない場合よりも遅くなってしまいます。
トレーシングの動作を制御するには、次のテクニックを使用できます。
固定の input_signature
を tf.function
に渡す
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
print("Tracing with", x)
return tf.where(x % 2 == 0, x // 2, 3 * x + 1)
print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([[1, 2], [3, 4]]))
# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([4 1], shape=(2,), dtype=int32) Caught expected exception <class 'ValueError'>: Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_72089/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_72089/1851403433.py", line 9, in <module> next_collatz(tf.constant([[1, 2], [3, 4]])) ValueError: Python inputs incompatible with input_signature: inputs: ( tf.Tensor( [[1 2] [3 4]], shape=(2, 2), dtype=int32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=None)). Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_72089/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_72089/1851403433.py", line 13, in <module> next_collatz(tf.constant([1.0, 2.0])) ValueError: Python inputs incompatible with input_signature: inputs: ( tf.Tensor([1. 2.], shape=(2,), dtype=float32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
柔軟性のために未知の次元を使用する
TensorFlow は形状に基づいてテンソルを一致させるため、ワイルドカードとして None
次元を使用することで、Function
が可変サイズの入力にトレースを再利用できるようになります。可変サイズの入力は、長さの異なるシーケンスがある場合や、バッチごとに画像のサイズが異なる場合に発生します(例として、Transformer と Deep Dream チュートリアルをご覧ください)。
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
print('Tracing with', x)
return x
# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([1 2 3], shape=(3,), dtype=int32) tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
Python リテラルの代わりにテンソルを渡す
通常、Python 引数は、num_layers=10
または training=True
または nonlinearity='relu'
などのように、ハイパーパラメータとグラフ構造の制御に使用されます。そのため、Python 引数が変わると、当然グラフをリトレースする必要が出てきます。
しかし、Python 引数がグラフ構造の制御に使用されていない場合もあります。こういった場合、Python の値の変化によってリトレーシングがトリガーされますが、これは不要です。この、AutoGraph が動的にアンロールするトレーニングループを例に見てみましょう。トレースが何度も行われますが、生成されたグラフはまったく同じであるため、リトレーシングは不要と言えます。
def train_one_step():
pass
@tf.function
def train(num_steps):
print("Tracing with num_steps = ", num_steps)
tf.print("Executing with num_steps = ", num_steps)
for _ in tf.range(num_steps):
train_one_step()
print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)
print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments. Tracing with num_steps = 10 Executing with num_steps = 10 Tracing with num_steps = 20 Executing with num_steps = 20 Traces are reused for Tensor arguments. Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32) Executing with num_steps = 10 Executing with num_steps = 20
リトレーシングを強制する必要がある場合は、新しい Function
を作成します。トレースは絶対に、各 Function
オブジェクト間で共有されることはありません。
def f():
print('Tracing!')
tf.print('Executing')
tf.function(f)()
tf.function(f)()
Tracing! Executing Tracing! Executing
トレースプロトコルを使用する
可能であれば、代わりに Python 型を tf.experimental.ExtensionType
に変換することをお勧めします。さらに、ExtensionType
の TraceType
は、それに関連付けられた tf.TypeSpec
です。したがって、必要に応じて、デフォルトの tf.TypeSpec
を単純にオーバーライドして、ExtensionType
の Tracing Protocol
を制御できます。詳細については、拡張型ガイドの ExtensionType の TypeSpec のカスタマイズセクションを参照してください。
それ以外の場合は、Function
が特定の Python 型に関していつ再トレースする必要があるかを直接制御するために、Tracing Protocol
を自分で実装できます。
@tf.function
def get_mixed_flavor(fruit_a, fruit_b):
return fruit_a.flavor + fruit_b.flavor
class Fruit:
flavor = tf.constant([0, 0])
class Apple(Fruit):
flavor = tf.constant([1, 2])
class Mango(Fruit):
flavor = tf.constant([3, 4])
# As described in the above rules, a generic TraceType for `Apple` and `Mango`
# is generated (and a corresponding ConcreteFunction is traced) but it fails to
# match the second function call since the first pair of Apple() and Mango()
# have gone out out of scope by then and deleted.
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again
# However, each subclass of the `Fruit` class has a fixed flavor, and you
# can reuse an existing traced concrete function if it was the same
# subclass. Avoiding such unnecessary tracing of concrete functions
# can have significant performance benefits.
class FruitTraceType(tf.types.experimental.TraceType):
def __init__(self, fruit_type):
self.fruit_type = fruit_type
def is_subtype_of(self, other):
return (type(other) is FruitTraceType and
self.fruit_type is other.fruit_type)
def most_specific_common_supertype(self, others):
return self if all(self == other for other in others) else None
def __eq__(self, other):
return type(other) is FruitTraceType and self.fruit_type == other.fruit_type
def __hash__(self):
return hash(self.fruit_type)
class FruitWithTraceType:
def __tf_tracing_type__(self, context):
return FruitTraceType(type(self))
class AppleWithTraceType(FruitWithTraceType):
flavor = tf.constant([1, 2])
class MangoWithTraceType(FruitWithTraceType):
flavor = tf.constant([3, 4])
# Now if you try calling it again:
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Traces a new concrete function
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType()) # Re-uses the traced concrete function
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 6], dtype=int32)>
具象関数の取得
関数がトレースされるたびに新しい具象関数が作成されますが、get_concrete_function
を使うことで、具象関数を直接取得できます。
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace Executing traced function tf.Tensor(b'aa', shape=(), dtype=string) tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)
ConcreteFunction
を出力すると、入力引数(型付き)とその出力型の概要が表示されます。
print(double_strings)
ConcreteFunction double(a) Args: a: string Tensor, shape=() Returns: string Tensor, shape=()
また、具象関数のシグネチャを直接取得することもできます。
print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {}) Tensor("Identity:0", shape=(), dtype=string)
互換性のない型で具象トレースを使用すると、エラーが発生します。
with assert_raises(tf.errors.InvalidArgumentError):
double_strings(tf.constant(1))
Caught expected exception <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_72089/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_72089/3196284684.py", line 2, in <module> double_strings(tf.constant(1)) tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_166 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_166]
Python 引数は、具象関数の入力シグネチャで特別に扱われていることに気づいたかもしれません。TensorFlow 2.3 より前では、Python 引数は単に具象関数のシグネチャから削除されていましたが、TensorFlow 2.3 からはシグネチャに残されたまま、トレーシング中に値セットを取るように制約されています。
@tf.function
def pow(a, b):
return a ** b
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2) Args: a: float32 Tensor, shape=<unknown> Returns: float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100
with assert_raises(TypeError):
square(tf.constant(10.0), b=3)
Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py", line 1487, in _call_impl return self._call_with_flat_signature(args, kwargs, File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py", line 1532, in _call_with_flat_signature raise TypeError(f"{self._flat_signature_summary()} got unexpected " TypeError: pow(a) got unexpected keyword arguments: b. During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_72089/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_72089/2310937119.py", line 4, in <module> square(tf.constant(10.0), b=3) TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.
グラフの取得
それぞれの具象関数は、tf.Graph
を囲む呼び出し可能なラッパーです。通常、実際の tf.Graph
オブジェクトを取得する必要はないにしろ、具象関数から簡単に取得することが可能です。
graph = double_strings.graph
for node in graph.as_graph_def().node:
print(f'{node.input} -> {node.name}')
[] -> a ['a', 'a'] -> add ['add'] -> Identity
デバッグ
一般的に、コードのデバックは、tf.function
内で行うよりも、Eager モードで行う方が簡単です。Eager モードでは、tf.function
でデコレートする前に、コードがエラーなく実行することを確認しておく必要があります。デバッグプロセスを支援する目的で、tf.config.run_functions_eagerly(True)
を呼び出すと、tf.function
をグローバルに無効にして、有効にし直すことができます。
tf.function
内でのみ出現する問題を追跡する場合、次のようなヒントがあります。
- 従来のシンプルな Python
print
呼び出しは、トレーシング中にのみ実行されるため、関数が(リ)トレーシングされるときに追跡しやすくなります。 tf.print
呼び出しは毎回実行するため、実行中の中間値の追跡に役立ちます。tf.debugging.enable_check_numerics
は、NaN と Inf がいつ作成されるかを簡単に追跡できます。pdb
(Python デバッガ)は、トレーシング中に何が起きているのかを理解する上で役立ちます。(注意:pdb
が示すのは、AutoGraph 変換ソースコードです。)
AutoGraph 変換
AutoGraph は、tf.function
内でデフォルトで利用できるようになっているライブラリで、Python の Eager コードのサブセットとグラフ対応の TensorFlow 演算に変換します。これには、if
、for
、while
などの制御フローが含まれます。
tf.cond
や tf.while_loop
などの TensorFlow 演算は機能し続けますが、制御フローは、Python で記述された場合の方が書きやすく理解しやすいことがほとんどです。
# A simple loop
@tf.function
def f(x):
while tf.reduce_sum(x) > 1:
tf.print(x)
x = tf.tanh(x)
return x
f(tf.random.uniform([5]))
[0.580614805 0.490377903 0.517188191 0.0688399076 0.130100965] [0.523112118 0.454516321 0.475526899 0.0687313601 0.129371852] [0.480098397 0.425604343 0.442654103 0.0686233342 0.128654867] [0.446322411 0.40164122 0.415842 0.0685158074 0.127949685] [0.418871284 0.381352365 0.393421769 0.0684087873 0.127255991] [0.395979106 0.36388135 0.374306321 0.0683022663 0.126573473] [0.376503289 0.348628223 0.357752621 0.0681962371 0.125901818] [0.359666914 0.335158378 0.34323293 0.0680907071 0.125240773] [0.344920605 0.323148131 0.330360562 0.0679856688 0.124590032] [0.331863195 0.31235069 0.318844706 0.0678811148 0.123949341] [0.320193917 0.302574098 0.308461905 0.0677770451 0.123318449] [0.30968225 0.293666482 0.299037188 0.0676734447 0.1226971] [0.300148 0.285505921 0.290431261 0.0675703213 0.122085057] [0.291448027 0.277993202 0.282531679 0.0674676672 0.121482089] [0.283467025 0.271046698 0.275246531 0.0673654824 0.12088798] <tf.Tensor: shape=(5,), dtype=float32, numpy= array([0.27611086, 0.26459852, 0.2684999 , 0.06726376, 0.12030251], dtype=float32)>
興味があれば、AutoGraph が生成するコードを検査できます。
print(tf.autograph.to_code(f.python_function))
def tf__f(x): with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope: do_return = False retval_ = ag__.UndefinedReturnValue() def get_state(): return (x,) def set_state(vars_): nonlocal x (x,) = vars_ def loop_body(): nonlocal x ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope) x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope) def loop_test(): return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1 ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {}) try: do_return = True retval_ = ag__.ld(x) except: do_return = False raise return fscope.ret(retval_, do_return)
条件文
AutoGraph は if <condition>
文を相当する tf.cond
呼び出しに変換します。この置換は、<condition>
がテンソルである場合に行われます。テンソルでない場合は、if
文は Python の条件文として実行されます。
Python 条件文はトレーシング中に実行するため、条件文のブランチが 1 つだけグラフに追加されます。AutoGraph を使用しない場合、データに依存する制御フローが存在すると、トレーシングされたこのグラフは別のブランチを取ることができません。
tf.cond
は、条件文の両方のブランチをトレーシングし、実行時に動的に 1 つのブランチを選択してグラフに追加します。トレーシングには意図しない副作用がある場合があります。詳細は、AutoGraph のトレーシング効果をご覧ください。
@tf.function
def fizzbuzz(n):
for i in tf.range(1, n + 1):
print('Tracing for loop')
if i % 15 == 0:
print('Tracing fizzbuzz branch')
tf.print('fizzbuzz')
elif i % 3 == 0:
print('Tracing fizz branch')
tf.print('fizz')
elif i % 5 == 0:
print('Tracing buzz branch')
tf.print('buzz')
else:
print('Tracing default branch')
tf.print(i)
fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop Tracing fizzbuzz branch Tracing fizz branch Tracing buzz branch Tracing default branch 1 2 fizz 4 buzz 1 2 fizz 4 buzz fizz 7 8 fizz buzz 11 fizz 13 14 fizzbuzz 16 17 fizz 19 buzz
AutoGraph 変換の if 文におけるその他の制約事項については、リファレンスドキュメントをご覧ください。
ループ
AutoGraph は、一部の for
文と while
文を相当する tf.while_loop
などの TensorFlow のループ演算に変換します。変換されない場合、for
または while
ループは Python ループとして実行されます。
この置き換えは、次の場合に行われます。
for x in y
:y
がテンソルである場合、tf.while_loop
に変換されます。y
がtf.data.Dataset
である特別なケースでは、tf.data.Dataset
演算の組み合わせが生成されます。while <condition>
:<condition>
がテンソルである場合、tf.while_loop
に変換されます。
Python ループは、トレーシング中に実行され、ループのいてレーションごとに、tf.Graph
に追加の演算が追加されます。
TensorFlow ループはループの本体をトレーシングし、実行時に実行する反復回数を動的に選択します。ループ本体は、生成された tf.Graph
に一度だけ出現します。
AutoGraph 変換の for
文と while
文におけるその他の制約事項については、リファレンスドキュメントをご覧ください。
Python データのループ
一般的な落とし穴は、tf.function
内で Python/NumPy データをループする際にあります。このループは、トレーシングプロセス中に実行し、ループのイテレーションごとにモデルのコピーを tf.Graph
に追加してしまいます。
トレーニングループ全体を tf.function
にラップしたいのであれば、データを tf.data.Dataset
としてラップし、AutoGraph にトレーニングループを動的に展開させるようにするのが最も安全な方法です。
def measure_graph_size(f, *args):
g = f.get_concrete_function(*args).graph
print("{}({}) contains {} nodes in its graph".format(
f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))
@tf.function
def train(dataset):
loss = tf.constant(0)
for x, y in dataset:
loss += tf.abs(y - x) # Some dummy computation.
return loss
small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph train(<FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph train(<FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph
Python/Numpy データをデータセットにラップする際は、tf.data.Dataset.from_generator
と tf.data.Dataset.from_tensors
の違いに注意してください。前者は、データを Python に維持し、tf.py_function
経由で取得するため、パフォーマンスに問題がありますが、後者は、データのコピーをグラフ内の大型の tf.constant()
ノードとしてバンドル化するため、メモリに問題が現れます。
データを消費するには、TFRecordDataset
や CsvDataset
などを介してファイルからデータを読み取るのが最も効果的な方法です。そうすれば、Python を使わずに、TensorFlow 自体でデータの非同期読み込みとプリフェッチを管理できるようになります。詳細は、「tf.data
: TensorFlow 入力パイプラインを構築する」ガイドをご覧ください。
ループでの値の累積
ループの反復ごとに値を累積していくのは一般的なパターンです。通常は、Python のリストに追加したり、Python ディクショナリにエントリを追加したりして行われますが、これらは Python の副作用であるため、動的に展開されるループでは期待どおりに動作しません。動的に展開されるループの結果を累積する場合は、tf.TensorArray
を使用してください。
batch_size = 2
seq_len = 3
feature_size = 4
def rnn_step(inp, state):
return inp + state
@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
max_seq_len = input_data.shape[0]
states = tf.TensorArray(tf.float32, size=max_seq_len)
state = initial_state
for i in tf.range(max_seq_len):
state = rnn_step(input_data[i], state)
states = states.write(i, state)
return tf.transpose(states.stack(), [1, 0, 2])
dynamic_rnn(rnn_step,
tf.random.uniform([batch_size, seq_len, feature_size]),
tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy= array([[[0.16477478, 0.8342267 , 0.05172467, 0.97317135], [1.0312288 , 0.93103075, 0.64041865, 1.3943478 ], [1.2090211 , 1.1547271 , 1.6130002 , 1.6184435 ]], [[0.2088815 , 0.52283025, 0.66112113, 0.7292708 ], [1.1666163 , 0.89018416, 1.2647486 , 1.3063059 ], [1.491866 , 1.0169158 , 1.6302503 , 1.4235928 ]]], dtype=float32)>
制限事項
TensorFlow の Function
には、設計上、いくつかの制限事項があり、Python 関数を Function
に変換する際には、注意が必要です。
Python の副作用の実行
Function
内での出力、リストへのアペンド、グローバル変数のミューテーションといった副作用は、2 回実行されたり、まったく実行しなかったりといったように、予測のつかない動作をすることがあります。また、入力セットで Function
を初めて呼び出した場合にのみ実行し、以降では、Python コードを実行せずに、トレーシング済みの tf.Graph
が再実行されてしまうこともあります。
基本的に、ロジックでは Python の副作用に依存しないようにし、トレースをデバッグするためだけに使用することをお勧めします。呼び出しごとに TensorFlow ランタイムが確実にコードを実行できるようにするには、tf.data
、tf.print
、tf.summary
、tf.Variable.assign
、tf.TensorArray
などの TensorFlow API を使用するのが最善の方法です。
@tf.function
def f(x):
print("Traced with", x)
tf.print("Executed with", x)
f(1)
f(1)
f(2)
Traced with 1 Executed with 1 Executed with 1 Traced with 2 Executed with 2
Function
の呼び出しごとに Python コードを実行する場合は、tf.py_function
が脱出口です。tf.py_function
には移植性がなく、特にパフォーマンスに優れているわけでもなく、SavedModel で保存できなければ、分散型(マルチ GPU、TPU)の環境でうまく動作するわけでもありません。また、tf.py_function
はグラフに組み込む必要もあるため、すべての入力/出力をテンソルにキャストしてしまいます。
Python のグローバル変数と自由変数の変更
Python のグローバル変数と自由変数の変更は、Python の副作用としてみなされるため、トレーシング中にのみ発生します。
external_list = []
@tf.function
def side_effect(x):
print('Python side effect')
external_list.append(x)
side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect
場合によっては気づきにくい予期しない動作が発生することがあります。以下の例では、counter
は変数のインクリメントを保護することを目的としています。ただし、これは Python 整数であり、TensorFlow オブジェクトではないため、その値は最初のトレース中にキャプチャされます。tf.function
を使用すると、assign_add
が下のグラフに無条件に記録されます。したがって、v
は、tf.function
が呼び出されるたびに 1 ずつ増加します。この問題は、tf.function
デコレータを使用して Grpah モードの Tensorflow コードを Tensorflow 2 に移行しようとする場合、Python の副作用 (例では counter
) を使用して、実行する演算を決定すると (例では、assign_add
)によく発生します。通常、ユーザーは、疑わしい数値結果を確認したり、予想よりもパフォーマンスが大幅に低下した場合に、このことに気付きます(たとえば、保護された演算に非常にコストがかかる場合)。
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# A python side-effect
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 2, 3
1 2 3
このような動作を回避し、期待される動作を実現するためには、tf.init_scope
を使用して演算を関数グラフの外に移動します。これにより、変数のインクリメントがトレース時間中に 1 回だけ実行されるようになります。init_scope
には、制御フローのクリアや勾配テープなどの他の副作用があることに注意してください。init_scope
を使用すると非常に複雑になり、現実的に管理できない場合があります。
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# Lifts ops out of function-building graphs
with tf.init_scope():
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 1, 1
1 1 1
まとめると、経験則として、Function
の外側で機能する整数またはリストのようなコンテナなどの Python オブジェクトのミューテーションは避けてください。代わりに、引数と TF オブジェクトを使用しましょう。たとえば、「ループでの値の累積」セクションには、リストのような演算を実装する方法の一例が示されています。
一部のケースでは、tf.Variable
である場合に状態をキャプチャして操作することができます。Keras モデルの重みは、このようにして、同じ ConcreteFunction
への呼び出しの繰り返しで更新されています。
Python イテレータとジェネレータの使用
ジェネレータやイテレータなどの多くの Python 機能は、Python ランタイムに依存して状態を追跡しています。一般的に、これらのコンストラクトは Eager モードでも期待どおりに動作しますが、Python の副作用の例であるため、トレーシング中にしか発生しません。
@tf.function
def buggy_consume_next(iterator):
tf.print("Value:", next(iterator))
iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1 Value: 1 Value: 1
TensorFlow にリストコントラクト用の特別な tf.TensorArray
があるように、イテレーション用にも特別な tf.data.Iterator
があります。概要は、AutoGraph 変換をご覧ください。また、tf.data
API を使って、ジェネレータのパターンを実装できます。
@tf.function
def good_consume_next(iterator):
# This is ok, iterator is a tf.data.Iterator
tf.print("Value:", next(iterator))
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1 Value: 2 Value: 3
tf.function のすべての出力は値を返す必要がある
tf.Variable
を除いて、tf.function はすべての出力を返す必要があります。戻り値を使用せずに関数からテンソルに直接アクセスしようとすると、「リーク」が発生します。
たとえば、以下の関数は、Python グローバル x
を介してテンソル a
を「リーク」します。
x = None
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return a + 2
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
3 'Tensor' object has no attribute 'numpy'
リークされた値も返される場合でもリークします。
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return x # Good - uses local tensor
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
@tf.function
def captures_leaked_tensor(b):
b += x # Bad - `x` is leaked from `leaky_function`
return b
with assert_raises(TypeError):
captures_leaked_tensor(tf.constant(2))
2 'Tensor' object has no attribute 'numpy' Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_72089/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_72089/566849597.py", line 21, in <module> captures_leaked_tensor(tf.constant(2)) TypeError: <tf.Tensor 'add:0' shape=() dtype=int32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information. <tf.Tensor 'add:0' shape=() dtype=int32> was defined here: File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main return _run_code(code, main_globals, None, File "/usr/lib/python3.9/runpy.py", line 87, in _run_code exec(code, run_globals) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module> app.launch_new_instance() File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/traitlets/config/application.py", line 992, in launch_instance app.start() File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 711, in start self.io_loop.start() File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start self.asyncio_loop.run_forever() File "/usr/lib/python3.9/asyncio/base_events.py", line 601, in run_forever self._run_once() File "/usr/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once handle._run() File "/usr/lib/python3.9/asyncio/events.py", line 80, in _run self._context.run(self._callback, *self._args) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue await self.process_one() File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 499, in process_one await dispatch(*args) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell await result File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 729, in execute_request reply_content = await reply_content File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 411, in do_execute res = shell.run_cell( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 531, in run_cell return super().run_cell(*args, **kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2940, in run_cell result = self._run_cell( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2995, in _run_cell return runner(coro) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner coro.send(None) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3194, in run_cell_async has_raised = await self.run_ast_nodes(code_ast.body, cell_name, File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3373, in run_ast_nodes if await self.run_code(code, result, async_=asy): File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code exec(code_obj, self.user_global_ns, self.user_ns) File "/tmpfs/tmp/ipykernel_72089/566849597.py", line 7, in <module> correct_a = leaky_function(tf.constant(1)) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler return fn(*args, **kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 880, in __call__ result = self._call(*args, **kwds) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 928, in _call self._initialize(args, kwds, add_initializers_to=initializers) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 749, in _initialize self._variable_creation_fn # pylint: disable=protected-access File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 162, in _get_concrete_function_internal_garbage_collected concrete_function, _ = self._maybe_define_concrete_function(args, kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 157, in _maybe_define_concrete_function return self._maybe_define_function(args, kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 360, in _maybe_define_function concrete_function = self._create_concrete_function(args, kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py", line 284, in _create_concrete_function func_graph_module.func_graph_from_py_func( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1283, in func_graph_from_py_func func_outputs = python_func(*func_args, **func_kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 645, in wrapped_fn out = weak_wrapped_fn().__wrapped__(*args, **kwds) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1258, in autograph_handler return autograph.converted_call( File "/tmpfs/tmp/ipykernel_72089/566849597.py", line 4, in leaky_function x = a + 1 # Bad - leaks local tensor File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler return fn(*args, **kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1407, in binary_op_wrapper return func(x, y, name=name) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler return fn(*args, **kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 1176, in op_dispatch_handler return dispatch_target(*args, **kwargs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/math_ops.py", line 1757, in _add_dispatch return gen_math_ops.add_v2(x, y, name=name) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py", line 475, in add_v2 _, _, _op, _outputs = _op_def_library._apply_op_helper( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 795, in _apply_op_helper op = g._create_op_internal(op_type_name, inputs, dtypes=None, File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 749, in _create_op_internal return super(FuncGraph, self)._create_op_internal( # pylint: disable=protected-access File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 3798, in _create_op_internal ret = Operation( The tensor <tf.Tensor 'add:0' shape=() dtype=int32> cannot be accessed from here, because it was defined in FuncGraph(name=leaky_function, id=139638247160848), which is out of scope.
通常、このようなリークは、Python ステートメントまたはデータ構造を使用するときに発生します。アクセスできないテンソルがリークするだけでなく、このようなステートメントは Python の副作用としてカウントされ、すべての関数呼び出しで実行されないことがあるため、間違っている可能性があります。
また、一般的に外部 Python コレクションまたはオブジェクトの変更によりローカルテンソルがリークすることもあります。
class MyClass:
def __init__(self):
self.field = None
external_list = []
external_object = MyClass()
def leaky_function():
a = tf.constant(1)
external_list.append(a) # Bad - leaks tensor
external_object.field = a # Bad - leaks tensor
再帰的な tf.function はサポートされていない
再帰的な Function
はサポートされていないので、無限ループを引き起こす可能性があります。以下に例を示します。
@tf.function
def recursive_fn(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1
with assert_raises(Exception):
recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Caught expected exception <class 'Exception'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_72089/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 9, in <module> recursive_fn(tf.constant(5)) # Bad - maximum recursion error. tensorflow.python.autograph.impl.api.StagingError: in user code: File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmpfs/tmp/ipykernel_72089/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/usr/lib/python3.9/abc.py", line 119, in __instancecheck__ return _abc_instancecheck(cls, instance) File "/usr/lib/python3.9/abc.py", line 123, in __subclasscheck__ return _abc_subclasscheck(cls, subclass) RecursionError: maximum recursion depth exceeded while calling a Python object
再帰的な Function
が正しく動作しているように見えても、Python 関数は複数回トレースされ、パフォーマンスに影響を与える可能性があります。以下に例を示します。
@tf.function
def recursive_fn(n):
if n > 0:
print('tracing')
return recursive_fn(n - 1)
else:
return 1
recursive_fn(5) # Warning - multiple tracings
tracing tracing tracing tracing tracing <tf.Tensor: shape=(), dtype=int32, numpy=1>
既知の問題
Function
が正しく評価していない場合、以下の既知の問題が該当する可能性があります。これらの問題は、今後修正される予定です。
Python のグローバル変数と自由変数への依存
Function
は、Python 引数の新しい値で呼び出された時に新しい ConcreteFunction
を作成しますが、Python クロージャ、グローバル変数、またはその Function
の非ローカル変数に対しては作成しません。Function
への呼び出しごとに値が変化する場合でも、Function
はトレーシングされたときの値をそのまま使用してしまいます。これは、通常の Python 関数の動作とは異なります。
このため、外側の名前を閉じる代わりに引数を使用する関数プログラミングの様式をお勧めします。
@tf.function
def buggy_add():
return 1 + foo
@tf.function
def recommended_add(foo):
return 1 + foo
foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add()) # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100! Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(101, shape=(), dtype=int32)
グローバル値を更新する別の方法として、それを tf.Variable
にし、代わりに Variable.assign
メソッドを使用することができます。
@tf.function
def variable_add():
return 1 + foo
foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100! Variable: tf.Tensor(101, shape=(), dtype=int32)
Python オブジェクトへの依存
Python オブジェクトを引数として tf.function
に渡す上での推奨事項には、多数の既知の問題があります。これらは今後修正される予定です。一般的に、Python のプリミティブ型または tf.nest
と互換性のある構造を引数として使用する場合や、オブジェクトの別のインスタンスを Function
に渡す場合には、一貫したトレーシングを期待できますが、同一のオブジェクトであっても属性が異なるものを渡す場合、Function
は、新しいトレースを作成しません。
class SimpleModel(tf.Module):
def __init__(self):
# These values are *not* tf.Variables.
self.bias = 0.
self.weight = 2.
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x)) # Didn't change :(
Adding bias! tf.Tensor(20.0, shape=(), dtype=float32)
同じ Function
を使用して、更新されたモデルのインスタンスを評価する場合、更新されたモデルには、元のモデルと同じキャッシュキーが含まれるため、不具合が生じます。
このため、ミュート可能なオブジェクト属性に依存しない Function
を記述するか、新しいオブジェクトを作成することをお勧めします。
この方法が困難な場合は、回避策として、オブジェクトを変更するたびに新しい Function
がリトレーシングを行うようにする方法が挙げられます。
def evaluate(model, x):
return model.weight * x + model.bias
new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
リトレーシングにはコストがかかるため、tf.Variable
をオブジェクト属性として使用することができます。こうすることで、リトレーシングを行わずに、ミュートして(変更はしません)同様の効果を得ることができます。
class BetterModel:
def __init__(self):
self.bias = tf.Variable(0.)
self.weight = tf.Variable(2.)
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5
print(evaluate(better_model, x)) # This works!
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
tf.Variables の作成
Function
は、最初の呼び出しで 1 回作成され、後続の関数呼び出しで再利用されるシングルトン tf.Variable
のみをサポートします。以下のコードスニペットは、すべての関数呼び出しで新しい tf.Variable
を作成します。これにより、ValueError
例外が発生します。
例:
@tf.function
def f(x):
v = tf.Variable(1.0)
return v
with assert_raises(ValueError):
f(1.0)
Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_72089/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_72089/3018268426.py", line 7, in <module> f(1.0) ValueError: in user code: File "/tmpfs/tmp/ipykernel_72089/3018268426.py", line 3, in f * v = tf.Variable(1.0) ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
この制限を回避するために使用される一般的なパターンは、Python None 値で開始し、値が None の場合は条件付きで tf.Variable
を作成することです。
class Count(tf.Module):
def __init__(self):
self.count = None
@tf.function
def __call__(self):
if self.count is None:
self.count = tf.Variable(0)
return self.count.assign_add(1)
c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32)
複数の Keras オプティマイザとの使用
2 つ以上の Keras オプティマイザを tf.function
で使用しようとすると、「ValueError: tf.function only supports singleton tf.Variables created on the first call.
」というエラーが発生することがあります。このエラーは、オプティマイザが初めて勾配を適用する際に、内部的に tf.Variables
を作成するために発生するものです。
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
@tf.function
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
train_step(w, x, y, opt2)
Calling `train_step` with different optimizer... Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_72089/3551158538.py", line 8, in assert_raises yield File "/tmpfs/tmp/ipykernel_72089/3167358578.py", line 18, in <module> train_step(w, x, y, opt2) ValueError: in user code: File "/tmpfs/tmp/ipykernel_72089/3167358578.py", line 9, in train_step * optimizer.apply_gradients(zip(gradients, [w])) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1140, in apply_gradients ** return super().apply_gradients(grads_and_vars, name=name) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 621, in apply_gradients self.build(trainable_variables) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/adam.py", line 139, in build self.add_variable_from_reference( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 1072, in add_variable_from_reference return super().add_variable_from_reference( File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_experimental/optimizer.py", line 496, in add_variable_from_reference variable = tf.Variable( ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
トレーニング中にオプティマイザを変更する必要がある場合は、回避策として、オプティマイザごとに新しい Function
を作成し、ConcreteFunction
を直接呼び出すようにすることができます。
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
# Not a tf.function.
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
if i % 2 == 0:
train_step_1(w, x, y) # `opt1` is not used as a parameter.
else:
train_step_2(w, x, y) # `opt2` is not used as a parameter.
複数の Keras モデルとの使用
また、別のモデルインスタンスを同一の Function
に渡す際に、「ValueError: tf.function only supports singleton tf.Variables created on the first call.
」というエラーも発生することがあります。
このエラーは、Keras モデル(入力形状が定義されていない)と Keras レイヤーが、初めて呼び出されるときに tf.Variables
を作成するために発生するものです。これらの変数をすでに呼び出された Function
内で初期化しようとしているのでしょう。このエラーを回避するには、モデルをトレーニングする前に、model.build(input_shape)
を呼び出して、すべての重みを初期化するようにしてください。
参考資料
Function
のエクスポートと読み込みの方法については、SavedModel ガイドをご覧ください。トレーシングの後に実行するグラフの最適化については、Grappler ガイドをご覧ください。データパイプラインの最適化方法とモデルのプロファイリングについては、Profiler ガイドをご覧ください。