FFJORD

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

セットアップ

まず、このデモで使用するパッケージをインストールします。

pip install -q dm-sonnet

Imports (tf, tfp with adjoint trick, etc)

/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

Helper functions for visualization

FFJORD バイジェクター

この Colab では、元々 Grathwohl, Will, et al. arxiv リンクの論文で提案された FFJORD バイジェクターを説明します。

手短に言うと、このアプローチの背後には、既知の基礎分布データ分布間の対応を確立するという考えがあります。

この関係を確立するには、以下を行う必要があります。

  1. 基底分布が定義される空間 \(\mathcal{Y}\) とデータドメインの空間 \(\mathcal{X}\) 間のバイジェクティブマップ \(\mathcal{T}*{\theta}:\mathbf{x} \rightarrow \mathbf{y}\), \(\mathcal{T}*{\theta}^{1}:\mathbf{y} \rightarrow \mathbf{x}\) を定義します。
  2. 確率の概念を \(\mathcal{X}\) に転移するために、実行する変形を効率的に追跡します。

2 つ目の条件は、\(\mathcal{X}\) に定義された確率分布のために、以下の式で形式化されます。

\[ \log p_{\mathbf{x} }(\mathbf{x})=\log p_{\mathbf{y} }(\mathbf{y})-\log \operatorname{det}\left|\frac{\partial \mathcal{T}_{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| \]

FFJORD バイジェクターは、変換 \( \mathcal{T_{\theta} } を定義することで、これを行います。\mathbf{x} = \mathbf{z}(t_{0}) \rightarrow \mathbf{y} = \mathbf{z}(t_{1}) \quad : \quad \frac{d \mathbf{z} }{dt} = \mathbf{f}(t, \mathbf{z}, \theta) \)

この変換は、状態 \(\mathbf{z}\) の進化を説明する関数 \(\mathbf{f}\) がうまく動作する限り可逆的であり、log_det_jacobian は、次の式を組み込むことで計算することができます。

\[ \log \operatorname{det}\left|\frac{\partial \mathcal{T}*{\theta}(\mathbf{y})}{\partial \mathbf{y} }\right| = -\int*{t_{0} }^{t_{1} } \operatorname{Tr}\left(\frac{\partial \mathbf{f}(t, \mathbf{z}, \theta)}{\partial \mathbf{z}(t)}\right) d t \]

このデモでは、ガウス分布を moons データセットで定義される分布にラップするように、FFJORD バイジェクターをトレーニングします。これは以下の 3 つのステップで行います。

  • 基底分布を定義する
  • FFJORD バイジェクターを定義する
  • データセットの正確な対数尤度を最小化にする

まず、データを読み込みます。

Dataset

png

次に、基底分布をインスタンス化します。

base_loc = np.array([0.0, 0.0]).astype(np.float32)
base_sigma = np.array([0.8, 0.8]).astype(np.float32)
base_distribution = tfd.MultivariateNormalDiag(base_loc, base_sigma)

多層パーセプトロンを使用して、state_derivative_fn をモデル化します。

このデータセットでは必要なことではありませんが、通常 state_derivative_fn を時間に依存させると有利です。ここでは、t をネットワークの入力に連結しています。

class MLP_ODE(snt.Module):
  """Multi-layer NN ode_fn."""
  def __init__(self, num_hidden, num_layers, num_output, name='mlp_ode'):
    super(MLP_ODE, self).__init__(name=name)
    self._num_hidden = num_hidden
    self._num_output = num_output
    self._num_layers = num_layers
    self._modules = []
    for _ in range(self._num_layers - 1):
      self._modules.append(snt.Linear(self._num_hidden))
      self._modules.append(tf.math.tanh)
    self._modules.append(snt.Linear(self._num_output))
    self._model = snt.Sequential(self._modules)

  def __call__(self, t, inputs):
    inputs = tf.concat([tf.broadcast_to(t, inputs.shape), inputs], -1)
    return self._model(inputs)

Model and training parameters

次に、FFJORD バイジェクターのスタックを構築します。バイジェクターごとに ode_solve_fntrace_augmentation_fn があり、独自の state_derivative_fn モデルであるため、これらは様々な変換のシーケンスです。

Building bijector

これで、base_distributionstacked_ffjord バイジェクターでラップしてできた TransformedDistribution を使用することができます。

transformed_distribution = tfd.TransformedDistribution(
    distribution=base_distribution, bijector=stacked_ffjord)

ここで、トレーニングの手順を定義します。単純に、データの負の対数尤度を最小化します。

Training

Samples

基底分布と変換した分布のサンプルをプロットします。

evaluation_samples = []
base_samples, transformed_samples = get_samples()
transformed_grid = get_transformed_grid()
evaluation_samples.append((base_samples, transformed_samples, transformed_grid))
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1817: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
panel_id = 0
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
  1, 4, figsize=(16, 6))
plot_panel(
    grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray, False)
plt.tight_layout()

png

learning_rate = tf.Variable(LR, trainable=False)
optimizer = snt.optimizers.Adam(learning_rate)

for epoch in tqdm.trange(NUM_EPOCHS // 2):
  base_samples, transformed_samples = get_samples()
  transformed_grid = get_transformed_grid()
  evaluation_samples.append(
      (base_samples, transformed_samples, transformed_grid))
  for batch in moons_ds:
    _ = train_step(optimizer, batch)
0%|          | 0/40 [00:00<?, ?it/s]
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/math/ode/base.py:350: calling while_loop_v2 (from tensorflow.python.ops.control_flow_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))
100%|██████████| 40/40 [07:00<00:00, 10.52s/it]
panel_id = -1
panel_data = evaluation_samples[panel_id]
fig, axarray = plt.subplots(
  1, 4, figsize=(16, 6))
plot_panel(grid, panel_data[0], panel_data[2], panel_data[1], moons, axarray)
plt.tight_layout()

png

学習率を使ってより長い期間トレーニングすると、さらに改善します。

この例では説明されていませんが、FFJORD バイジェクターはハッチンソンの確率的トレース推定法をサポートしています。特定の Estimator は、trace_augmentation_fn を通じて提供されます。同様に、カスタム ode_solve_fn を定義することで、別のインテグレーターを使用することも可能です。