![]() | ![]() | ![]() | ![]() |
自動微分ガイドには、勾配の計算に必要なすべてのものが含まれています。このガイドでは、tf.GradientTape
より深く一般的ではない機能に焦点を当てています。
セットアップ
import tensorflow as tf
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['figure.figsize'] = (8, 6)
勾配記録の制御
自動微分ガイドでは、勾配計算を構築する際に、テープで監視される変数とテンソルを制御する方法を説明しました。
テープには、録音を操作する方法もあります。
GradientTape.stop_recording()
記録を停止したい場合は、 GradientTape.stop_recording()
を使用して記録を一時的に中断できます。
これは、モデルの途中で複雑な操作を区別したくない場合に、オーバーヘッドを削減するのに役立つ場合があります。これには、メトリックまたは中間結果の計算が含まれる場合があります。
x = tf.Variable(2.0)
y = tf.Variable(3.0)
with tf.GradientTape() as t:
x_sq = x * x
with t.stop_recording():
y_sq = y * y
z = x_sq + y_sq
grad = t.gradient(z, {'x': x, 'y': y})
print('dz/dx:', grad['x']) # 2*x => 4
print('dz/dy:', grad['y'])
dz/dx: tf.Tensor(4.0, shape=(), dtype=float32) dz/dy: None
完全にやり直したい場合は、 reset()
使用してください。通常、グラデーションテープブロックを終了して再起動するだけの方が読みやすくなりますが、テープブロックを終了するのが困難または不可能な場合は、 reset
を使用できます。
x = tf.Variable(2.0)
y = tf.Variable(3.0)
reset = True
with tf.GradientTape() as t:
y_sq = y * y
if reset:
# Throw out all the tape recorded so far
t.reset()
z = x * x + y_sq
grad = t.gradient(z, {'x': x, 'y': y})
print('dz/dx:', grad['x']) # 2*x => 4
print('dz/dy:', grad['y'])
dz/dx: tf.Tensor(4.0, shape=(), dtype=float32) dz/dy: None
グラデーションを停止します
上記のグローバルテープコントロールとは対照的に、 tf.stop_gradient
関数ははるかに正確です。テープ自体にアクセスすることなく、特定のパスに沿ってグラデーションが流れるのを防ぐために使用できます。
x = tf.Variable(2.0)
y = tf.Variable(3.0)
with tf.GradientTape() as t:
y_sq = y**2
z = x**2 + tf.stop_gradient(y_sq)
grad = t.gradient(z, {'x': x, 'y': y})
print('dz/dx:', grad['x']) # 2*x => 4
print('dz/dy:', grad['y'])
dz/dx: tf.Tensor(4.0, shape=(), dtype=float32) dz/dy: None
カスタムグラデーション
場合によっては、デフォルトを使用するのではなく、グラデーションの計算方法を正確に制御したい場合があります。これらの状況は次のとおりです。
- あなたが書いている新しいopのための定義されたグラデーションはありません。
- デフォルトの計算は数値的に不安定です。
- フォワードパスから高価な計算をキャッシュしたいとします。
- :あなたは(使用例の値を変更したい
tf.clip_by_value
、tf.math.round
勾配を変更せずに)。
新しい操作を作成するには、 tf.RegisterGradient
を使用して独自のtf.RegisterGradient
を設定できます。詳細については、そのページを参照してください。 (グラデーションレジストリはグローバルであるため、注意して変更してください。)
後者の3つのケースでは、 tf.custom_gradient
を使用できます。
これは、tf.clip_by_norm
を中間グラデーションに適用する例です。
# Establish an identity operation, but clip during the gradient pass
@tf.custom_gradient
def clip_gradients(y):
def backward(dy):
return tf.clip_by_norm(dy, 0.5)
return y, backward
v = tf.Variable(2.0)
with tf.GradientTape() as t:
output = clip_gradients(v * v)
print(t.gradient(output, v)) # calls "backward", which clips 4 to 2
tf.Tensor(2.0, shape=(), dtype=float32)
詳細については、 tf.custom_gradient
デコレータを参照してください。
複数のテープ
複数のテープがシームレスに相互作用します。たとえば、ここでは、各テープが異なるテンソルのセットを監視しています。
x0 = tf.constant(0.0)
x1 = tf.constant(0.0)
with tf.GradientTape() as tape0, tf.GradientTape() as tape1:
tape0.watch(x0)
tape1.watch(x1)
y0 = tf.math.sin(x0)
y1 = tf.nn.sigmoid(x1)
y = y0 + y1
ys = tf.reduce_sum(y)
tape0.gradient(ys, x0).numpy() # cos(x) => 1.0
1.0
tape1.gradient(ys, x1).numpy() # sigmoid(x1)*(1-sigmoid(x1)) => 0.25
0.25
高次の勾配
GradientTape
コンテキストマネージャー内の操作は、自動区別のために記録されます。そのコンテキストで勾配が計算される場合、勾配計算も記録されます。その結果、まったく同じAPIが高次のグラデーションでも機能します。例えば:
x = tf.Variable(1.0) # Create a Tensorflow variable initialized to 1.0
with tf.GradientTape() as t2:
with tf.GradientTape() as t1:
y = x * x * x
# Compute the gradient inside the outer `t2` context manager
# which means the gradient computation is differentiable as well.
dy_dx = t1.gradient(y, x)
d2y_dx2 = t2.gradient(dy_dx, x)
print('dy_dx:', dy_dx.numpy()) # 3 * x**2 => 3.0
print('d2y_dx2:', d2y_dx2.numpy()) # 6 * x => 6.0
dy_dx: 3.0 d2y_dx2: 6.0
これによりスカラー関数の2次導関数が得られますが、 GradientTape.gradient
はスカラーの勾配のみを計算するため、このパターンは一般化されてヘッセ行列を生成しません。ヘッセ行列を作成するには、ヤコビアンセクションのヘッセ行列の例を参照してください。
「 GradientTape.gradient
へのネストされた呼び出し」は、 GradientTape.gradient
からスカラーを計算する場合に適したパターンです。次の例のように、結果のスカラーは2番目のグラデーション計算のソースとして機能します。
例:入力勾配の正則化
多くのモデルは「敵対的な例」の影響を受けやすくなっています。このテクニックのコレクションは、モデルの入力を変更して、モデルの出力を混乱させます。最も単純な実装は、入力に対する出力の勾配に沿って1つのステップを実行します。 「入力勾配」。
敵対的な例に対するロバスト性を高めるための1つの手法は、入力勾配の正則化です。これは、入力勾配の大きさを最小化しようとします。入力勾配が小さい場合、出力の変化も小さいはずです。
以下は、入力勾配正則化の単純な実装です。実装は次のとおりです。
- 内側のテープを使用して、入力に対する出力の勾配を計算します。
- その入力勾配の大きさを計算します。
- モデルに関してその大きさの勾配を計算します。
x = tf.random.normal([7, 5])
layer = tf.keras.layers.Dense(10, activation=tf.nn.relu)
with tf.GradientTape() as t2:
# The inner tape only takes the gradient with respect to the input,
# not the variables.
with tf.GradientTape(watch_accessed_variables=False) as t1:
t1.watch(x)
y = layer(x)
out = tf.reduce_sum(layer(x)**2)
# 1. Calculate the input gradient.
g1 = t1.gradient(out, x)
# 2. Calculate the magnitude of the input gradient.
g1_mag = tf.norm(g1)
# 3. Calculate the gradient of the magnitude with respect to the model.
dg1_mag = t2.gradient(g1_mag, layer.trainable_variables)
[var.shape for var in dg1_mag]
[TensorShape([5, 10]), TensorShape([10])]
ヤコビアン
これまでのすべての例では、いくつかのソーステンソルに関してスカラーターゲットの勾配を取りました。
ヤコビ行列は、ベクトル値関数の勾配を表します。各行には、ベクトルの要素の1つの勾配が含まれています。
GradientTape.jacobian
メソッドを使用すると、ヤコビ行列を効率的に計算できます。
ご了承ください:
-
gradient
:sources
引数は、テンソルまたはテンソルのコンテナーにすることができます。 -
gradient
とは異なり:target
テンソルは単一テンソルでなければなりません。
スカラーソース
最初の例として、これはスカラーソースに関するベクトルターゲットのヤコビアンです。
x = tf.linspace(-10.0, 10.0, 200+1)
delta = tf.Variable(0.0)
with tf.GradientTape() as tape:
y = tf.nn.sigmoid(x+delta)
dy_dx = tape.jacobian(y, delta)
スカラーに関してヤコビアンを取ると、結果はターゲットの形状になり、ソースに対する各要素の勾配を示します。
print(y.shape)
print(dy_dx.shape)
(201,) (201,)
plt.plot(x.numpy(), y, label='y')
plt.plot(x.numpy(), dy_dx, label='dy/dx')
plt.legend()
_ = plt.xlabel('x')
テンソルソース
入力がスカラーであるかテンソルであるかにかかわらず、 GradientTape.jacobian
は、ターゲットの各要素に対するソースの各要素の勾配を効率的に計算します。
たとえば、このレイヤーの出力は(10, 7)
形状になります。
x = tf.random.normal([7, 5])
layer = tf.keras.layers.Dense(10, activation=tf.nn.relu)
with tf.GradientTape(persistent=True) as tape:
y = layer(x)
y.shape
TensorShape([7, 10])
そして、レイヤーのカーネルの形状は(5, 10)
5、10 (5, 10)
です:
layer.kernel.shape
TensorShape([5, 10])
カーネルに関する出力のヤコビアンの形状は、連結された2つの形状です。
j = tape.jacobian(y, layer.kernel)
j.shape
TensorShape([7, 10, 5, 10])
ターゲットの寸法を合計すると、 GradientTape.gradient
によって計算された合計の勾配が残ります。
g = tape.gradient(y, layer.kernel)
print('g.shape:', g.shape)
j_sum = tf.reduce_sum(j, axis=[0, 1])
delta = tf.reduce_max(abs(g - j_sum)).numpy()
assert delta < 1e-3
print('delta:', delta)
g.shape: (5, 10) delta: 4.7683716e-07
例:ヘシアン
tf.GradientTape
は、ヘッセ行列を構築するための明示的なメソッドを提供していませんが、 GradientTape.jacobian
メソッドを使用して構築することは可能です。
x = tf.random.normal([7, 5])
layer1 = tf.keras.layers.Dense(8, activation=tf.nn.relu)
layer2 = tf.keras.layers.Dense(6, activation=tf.nn.relu)
with tf.GradientTape() as t2:
with tf.GradientTape() as t1:
x = layer1(x)
x = layer2(x)
loss = tf.reduce_mean(x**2)
g = t1.gradient(loss, layer1.kernel)
h = t2.jacobian(g, layer1.kernel)
print(f'layer.kernel.shape: {layer1.kernel.shape}')
print(f'h.shape: {h.shape}')
layer.kernel.shape: (5, 8) h.shape: (5, 8, 5, 8)
このヘッセ行列をニュートン法のステップに使用するには、最初にその軸を行列に平坦化し、勾配をベクトルに平坦化します。
n_params = tf.reduce_prod(layer1.kernel.shape)
g_vec = tf.reshape(g, [n_params, 1])
h_mat = tf.reshape(h, [n_params, n_params])
ヘッセ行列は対称である必要があります。
def imshow_zero_center(image, **kwargs):
lim = tf.reduce_max(abs(image))
plt.imshow(image, vmin=-lim, vmax=lim, cmap='seismic', **kwargs)
plt.colorbar()
imshow_zero_center(h_mat)
ニュートン法の更新手順を以下に示します。
eps = 1e-3
eye_eps = tf.eye(h_mat.shape[0])*eps
# X(k+1) = X(k) - (∇²f(X(k)))^-1 @ ∇f(X(k))
# h_mat = ∇²f(X(k))
# g_vec = ∇f(X(k))
update = tf.linalg.solve(h_mat + eye_eps, g_vec)
# Reshape the update and apply it to the variable.
_ = layer1.kernel.assign_sub(tf.reshape(update, layer1.kernel.shape))
これは単一のtf.Variable
場合は比較的簡単ですが、これを自明でないモデルに適用するには、複数の変数にわたって完全なヘッセ行列を生成するために注意深い連結とスライスが必要になります。
バッチヤコビアン
場合によっては、ソースのスタックに関して、ターゲットのスタックのそれぞれのヤコビアンを取得する必要があります。ここで、各ターゲットとソースのペアのヤコビアンは独立しています。
たとえば、ここでは、入力x
形状(batch, ins)
と出力y
形状(batch, outs)
です。
x = tf.random.normal([7, 5])
layer1 = tf.keras.layers.Dense(8, activation=tf.nn.elu)
layer2 = tf.keras.layers.Dense(6, activation=tf.nn.elu)
with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
tape.watch(x)
y = layer1(x)
y = layer2(y)
y.shape
TensorShape([7, 6])
x
に関するy
の完全なヤコビアンは、 (batch, ins, batch, outs)
のみが必要な場合でも(batch, ins, batch, outs)
形状を持ち(batch, ins, outs)
。
j = tape.jacobian(y, x)
j.shape
TensorShape([7, 6, 7, 5])
スタック内の各アイテムの勾配が独立している場合、このテンソルのすべての(batch, batch)
スライスは対角行列です。
imshow_zero_center(j[:, 0, :, 0])
_ = plt.title('A (batch, batch) slice')
def plot_as_patches(j):
# Reorder axes so the diagonals will each form a contiguous patch.
j = tf.transpose(j, [1, 0, 3, 2])
# Pad in between each patch.
lim = tf.reduce_max(abs(j))
j = tf.pad(j, [[0, 0], [1, 1], [0, 0], [1, 1]],
constant_values=-lim)
# Reshape to form a single image.
s = j.shape
j = tf.reshape(j, [s[0]*s[1], s[2]*s[3]])
imshow_zero_center(j, extent=[-0.5, s[2]-0.5, s[0]-0.5, -0.5])
plot_as_patches(j)
_ = plt.title('All (batch, batch) slices are diagonal')
目的の結果を得るには、重複するbatch
ディメンションを合計するか、 tf.einsum
を使用して対角線を選択します。
j_sum = tf.reduce_sum(j, axis=2)
print(j_sum.shape)
j_select = tf.einsum('bxby->bxy', j)
print(j_select.shape)
(7, 6, 5) (7, 6, 5)
そもそも余分な次元なしで計算を行う方がはるかに効率的です。 GradientTape.batch_jacobian
メソッドはまさにそれを行います。
jb = tape.batch_jacobian(y, x)
jb.shape
WARNING:tensorflow:5 out of the last 5 calls to <function pfor.<locals>.f at 0x7f9a400e8620> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details. TensorShape([7, 6, 5])
error = tf.reduce_max(abs(jb - j_sum))
assert error < 1e-3
print(error.numpy())
0.0
x = tf.random.normal([7, 5])
layer1 = tf.keras.layers.Dense(8, activation=tf.nn.elu)
bn = tf.keras.layers.BatchNormalization()
layer2 = tf.keras.layers.Dense(6, activation=tf.nn.elu)
with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
tape.watch(x)
y = layer1(x)
y = bn(y, training=True)
y = layer2(y)
j = tape.jacobian(y, x)
print(f'j.shape: {j.shape}')
WARNING:tensorflow:6 out of the last 6 calls to <function pfor.<locals>.f at 0x7f9a401090d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details. j.shape: (7, 6, 7, 5)
plot_as_patches(j)
_ = plt.title('These slices are not diagonal')
_ = plt.xlabel("Don't use `batch_jacobian`")
この場合、 batch_jacobian
引き続き実行され、期待される形状の何かを返しますが、その内容の意味は不明です。
jb = tape.batch_jacobian(y, x)
print(f'jb.shape: {jb.shape}')
WARNING:tensorflow:7 out of the last 7 calls to <function pfor.<locals>.f at 0x7f9a4c0637b8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details. jb.shape: (7, 6, 5)