JAXによる分散推論

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示 ノートブックをダウンロード

JAXのTensorFlowProbability(TFP)に、分散数値計算用のツールが追加されました。多数のアクセラレーターに拡張するために、ツールは「単一プログラム複数データ」パラダイム(略してSPMD)を使用してコードを記述することを中心に構築されています。

このノートブックでは、「SPMDで考える」方法を説明し、TPUポッドやGPUのクラスターなどの構成にスケーリングするための新しいTFP抽象化を紹介します。このコードを自分で実行している場合は、必ずTPUランタイムを選択してください。

まず、最新バージョンのTFP、JAX、TFをインストールします。

インストール

いくつかのJAXユーティリティとともに、いくつかの一般的なライブラリをインポートします。

セットアップとインポート

INFO:tensorflow:Enabling eager execution
INFO:tensorflow:Enabling v2 tensorshape
INFO:tensorflow:Enabling resource variables
INFO:tensorflow:Enabling tensor equality
INFO:tensorflow:Enabling control flow v2

また、いくつかの便利なTFPエイリアスを設定します。新しい抽象化は、現在で提供されtfp.experimental.distributetfp.experimental.mcmc

tfd = tfp.distributions
tfb = tfp.bijectors
tfm = tfp.mcmc
tfed = tfp.experimental.distribute
tfde = tfp.experimental.distributions
tfem = tfp.experimental.mcmc

Root = tfed.JointDistributionCoroutine.Root

ノートブックをTPUに接続するには、JAXの次のヘルパーを使用します。接続されていることを確認するために、デバイスの数を印刷します。これは8つである必要があります。

from jax.tools import colab_tpu
colab_tpu.setup_tpu()
print(f'Found {jax.device_count()} devices')
Found 8 devices

簡単に紹介jax.pmap

TPUに接続した後、私たちは、8つのデバイスへのアクセス権を持っています。ただし、JAXコードを熱心に実行すると、JAXはデフォルトで1つだけで計算を実行します。

多くのデバイス間で計算を実行する最も簡単な方法は、関数をマップすることです。各デバイスにマップの1つのインデックスを実行させます。 JAXは提供jax.pmapいくつかのデバイス間で機能をマッピング一つに機能をオン(「平行マップ」)変換を。

次の例では、サイズ8の配列を作成し(使用可能なデバイスの数に一致させるため)、その配列全体に5を追加する関数をマップします。

xs = jnp.arange(8.)
out = jax.pmap(lambda x: x + 5.)(xs)
print(type(out), out)
<class 'jax.interpreters.pxla.ShardedDeviceArray'> [ 5.  6.  7.  8.  9. 10. 11. 12.]

我々が受け取ることに留意されたいShardedDeviceArray出力アレイが物理的デバイスに分割されていることを示す、型バック。

jax.pmap意味的にマップのように動作しますが、その動作を変更するいくつかの重要なオプションがあります。デフォルトでは、 pmap関数へのすべての入力が終わっマッピングされていると仮定し、私たちがして、この動作を変更することができますin_axes引数。

xs = jnp.arange(8.)
y = 5.
# Map over the 0-axis of `xs` and don't map over `y`
out = jax.pmap(lambda x, y: x + y, in_axes=(0, None))(xs, y)
print(out)
[ 5.  6.  7.  8.  9. 10. 11. 12.]

同様に、 out_axesへの引数pmapすべてのデバイス上で値を返すか否かを判断します。設定out_axesNoneに自動的に第一デバイス上で値を返し、我々は値がすべてのデバイスで同じである確信している場合にのみ使用する必要があります。

xs = jnp.ones(8) # Value is the same on each device
out = jax.pmap(lambda x: x + 1, out_axes=None)(xs)
print(out)
2.0

私たちがやりたいことがマップされた純粋関数として簡単に表現できない場合はどうなりますか?たとえば、マッピングしている軸全体で合計を計算したい場合はどうなりますか? JAXは、デバイス間で通信する機能である「集合」を提供し、より興味深く複雑な分散プログラムの作成を可能にします。それらがどのように機能するかを正確に理解するために、SPMDを紹介します。

SPMDとは何ですか?

シングルプログラムマルチデータ(SPMD)は、単一のプログラム(つまり同じコード)がデバイス間で同時に実行される並行プログラミングモデルですが、実行中の各プログラムへの入力は異なる場合があります。

私たちのプログラムは、その入力の簡単な関数(のようなつまり、何かの場合はx + 5 )我々が行ったように、SPMDでプログラムを実行しているだけで、それを超える異なるデータをマッピングしているjax.pmap早いです。ただし、関数を「マップ」するだけでは不十分です。 JAXは、デバイス間で通信する機能である「集合」を提供します。

たとえば、すべてのデバイスの数量の合計を取得したい場合があります。我々はそれを行う前に、我々は軸でオーバー我々はしているマッピングに名前を割り当てる必要がありpmap 。私たちは、その後、使用lax.psum我々が名付け我々は合計するしている軸を識別確保し、デバイス間で合計を実行する(「並列加算」)機能を。

def f(x):
  out = lax.psum(x, axis_name='i')
  return out
xs = jnp.arange(8.) # Length of array matches number of devices
jax.pmap(f, axis_name='i')(xs)
ShardedDeviceArray([28., 28., 28., 28., 28., 28., 28., 28.], dtype=float32)

psum集団凝集の値x各デバイス上およびMAP IEを横切ってその値を同期out28.各デバイスに。単純な「マップ」を実行しなくなりましたが、SPMDプログラムを実行しています。このプログラムでは、集合を使用する方法は限られていますが、各デバイスの計算が他のデバイスの同じ計算と相互作用できるようになりました。このシナリオでは、使用することができますout_axes = Noneため、 psum値を同期します。

def f(x):
  out = lax.psum(x, axis_name='i')
  return out
jax.pmap(f, axis_name='i', out_axes=None)(jnp.arange(8.))
ShardedDeviceArray(28., dtype=float32)

SPMDを使用すると、任意のTPU構成のすべてのデバイスで同時に実行される1つのプログラムを作成できます。 8つのTPUコアで機械学習を行うために使用されるのと同じコードを、数百から数千のコアを持つTPUポッドで使用できます。詳細なチュートリアルについてはjax.pmapとSPMD、あなたはを参照することができJAX 101チュートリアル

大規模なMCMC

このノートブックでは、ベイズ推定にマルコフ連鎖モンテカルロ(MCMC)法を使用することに焦点を当てています。 MCMCに多くのデバイスを利用する方法はいくつかありますが、このノートブックでは、次の2つに焦点を当てます。

  1. 異なるデバイスで独立したマルコフ連鎖を実行します。このケースはかなり単純で、バニラTFPで行うことができます。
  2. デバイス間でデータセットをシャーディングします。このケースはもう少し複雑で、最近追加されたTFP機構が必要です。

独立したチェーン

MCMCを使用して問題についてベイズ推定を行い、複数のデバイス間で複数のチェーンを並列に実行したいとします(たとえば、各デバイスに2つ)。これは、デバイス間で「マッピング」できるプログラム、つまり集合を必要としないプログラムであることがわかりました。各プログラムが(同じマルコフ連鎖を実行するのではなく)異なるマルコフ連鎖を実行することを確認するために、ランダムシードの異なる値を各デバイスに渡します。

2次元ガウス分布からサンプリングするトイプロブレムで試してみましょう。 TFPの既存のMCMC機能をそのまま使用できます。一般に、マップされた関数内にほとんどのロジックを配置して、すべてのデバイスで実行されているものと最初のデバイスだけで実行されているものをより明確に区別しようとします。

def run(seed):
  target_log_prob = tfd.Sample(tfd.Normal(0., 1.), 2).log_prob

  initial_state = jnp.zeros([2, 2]) # 2 chains
  kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-1, 10)
  def trace_fn(state, pkr):
    return target_log_prob(state)

  states, log_prob = tfm.sample_chain(
    num_results=1000,
    num_burnin_steps=1000,
    kernel=kernel,
    current_state=initial_state,
    trace_fn=trace_fn,
    seed=seed
  )
  return states, log_prob

それ自体で、 run機能はステートレスランダムシードを取り込み(どのステートレスランダム性の仕事を見て、あなたが読むことができるJAXの上TFPノートを参照またはJAX 101チュートリアル)。マッピングrun異なる種の上には、いくつかの独立したマルコフ連鎖を実行しているになります。

states, log_probs = jax.pmap(run)(random.split(random.PRNGKey(0), 8))
print(states.shape, log_probs.shape)
# states is (8 devices, 1000 samples, 2 chains, 2 dimensions)
# log_prob is (8 devices, 1000 samples, 2 chains)
(8, 1000, 2, 2) (8, 1000, 2)

各デバイスに対応する追加の軸があることに注意してください。寸法を並べ替えて平らにし、16本のチェーンの軸を取得できます。

states = states.transpose([0, 2, 1, 3]).reshape([-1, 1000, 2])
log_probs = log_probs.transpose([0, 2, 1]).reshape([-1, 1000])
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].plot(log_probs.T, alpha=0.4)
ax[1].scatter(*states.reshape([-1, 2]).T, alpha=0.1)
plt.show()

png

多くのデバイスに依存しないチェーンを実行しているときのように、それは簡単なようだpmap使用する機能の上に-ing tfp.mcmc 、確実に我々は、各デバイスへのランダムシードに異なる値を渡します。

シャーディングデータ

MCMCを実行する場合、ターゲット分布は多くの場合、データセットの条件付けによって取得された事後分布であり、正規化されていない対数密度の計算には、観測された各データの尤度の合計が含まれます。

データセットが非常に大きい場合、1つのデバイスで1つのチェーンを実行することさえ非常にコストがかかる可能性があります。ただし、複数のデバイスにアクセスできる場合は、データセットをデバイス間で分割して、利用可能なコンピューティングをより有効に活用できます。

我々はシャードのデータセットとMCMCを行うしたい場合は、我々はそうでない場合は、各デバイスが自分の間違ったターゲットとMCMCことになるだろう、私たちは、各デバイス上で計算する非正規化ログ密度はすなわち、すべてのデータを超える密度の合計を表し、確認する必要があります分布。このため、TFPは今、新しいツール(すなわちありtfp.experimental.distributetfp.experimental.mcmc確率を記録し、彼らとMCMCをやって、「シャード」コンピューティングを有効)。

シャードディストリビューション

TFP抽象コアは現在かけらログprobabiliitiesを計算するため提供されSharded入力として分布をとり、SPMDコンテキストで実行されたときに特定の特性を有する新しい分布を返すメタ分布。 Shardedでの生活tfp.experimental.distribute

直感的に、 Shardedデバイス間で「分割」されている確率変数のセットに配布対応。各デバイスで、それらは異なるサンプルを生成し、個別に異なる対数密度を持つことができます。あるいは、 Shardedプレートのサイズはデバイスの数であるグラフィカルモデル用語で「プレート」に分配相当します。

サンプリングSharded分布を

私たちはからサンプリングした場合はNormalプログラムで配布されてpmap 、各デバイスで同じシードを使用して-ed、我々は、各デバイスで同じサンプルを取得します。次の関数は、デバイス間で同期される単一の確率変数をサンプリングするものと考えることができます。

# `pmap` expects at least one value to be mapped over, so we provide a dummy one
def f(seed, _):
  return tfd.Normal(0., 1.).sample(seed=seed)
jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236,
                    -0.20584236, -0.20584236, -0.20584236, -0.20584236],                   dtype=float32)

私たちはラップした場合tfd.Normal(0., 1.)tfed.Sharded 、我々は論理的になりました(各デバイスに1つずつ)8つの異なるランダムな変数を持っているので、同じシードを渡したにもかかわらず、それぞれに異なるサンプルを作成します。

def f(seed, _):
  return tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i').sample(seed=seed)
jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([ 1.2152631 ,  0.7818249 ,  0.32549605,  0.6828047 ,
                     1.3973192 , -0.57830244,  0.37862757,  2.7706041 ],                   dtype=float32)

単一のデバイスでのこの分布の同等の表現は、8つの独立した正規サンプルです。サンプルの値が(異なるであろうにもかかわらずtfed.Shardedわずかに異なる擬似乱数生成を行い)、それらは両方とも同じ分布を表します。

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count())
dist.sample(seed=random.PRNGKey(0))
DeviceArray([ 0.08086783, -0.38624594, -0.3756545 ,  1.668957  ,
             -1.2758069 ,  2.1192007 , -0.85821325,  1.1305912 ],            dtype=float32)

対数密度取るSharded分布を

SPMDコンテキストで正規分布からサンプルの対数密度を計算するとどうなるか見てみましょう。

def f(seed, _):
  dist = tfd.Normal(0., 1.)
  x = dist.sample(seed=seed)
  return x, dist.log_prob(x)
jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
(ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236,
                     -0.20584236, -0.20584236, -0.20584236, -0.20584236],                   dtype=float32),
 ShardedDeviceArray([-0.94012403, -0.94012403, -0.94012403, -0.94012403,
                     -0.94012403, -0.94012403, -0.94012403, -0.94012403],                   dtype=float32))

各サンプルは各デバイスで同じであるため、各デバイスでも同じ密度を計算します。直感的には、ここでは、単一の正規分布変数にのみ分布しています。

Sharded分布、我々は計算するときに、8つの確率変数の上に分布を持っているlog_probサンプルのを、私たちは個々のログ濃度のそれぞれの上に、デバイス間で、合計します。 (この合計log_prob値は、上記で計算されたシングルトンlog_probよりも大きいことに気付くかもしれません。)

def f(seed, _):
  dist = tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i')
  x = dist.sample(seed=seed)
  return x, dist.log_prob(x)
sample, log_prob = jax.pmap(f, in_axes=(None, 0), axis_name='i')(
    random.PRNGKey(0), jnp.arange(8.))
print('Sample:', sample)
print('Log Prob:', log_prob)
Sample: [ 1.2152631   0.7818249   0.32549605  0.6828047   1.3973192  -0.57830244
  0.37862757  2.7706041 ]
Log Prob: [-13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205
 -13.7349205 -13.7349205]

同等の「シャーディングされていない」分布は、同じログ密度を生成します。

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count())
dist.log_prob(sample)
DeviceArray(-13.7349205, dtype=float32)

A Sharded分布が異なる値を生成するsample各デバイス上で、しかし同じ値取得log_prob各デバイス上に。ここで何が起こっているのですか? Sharded分布はないpsum確保するために内部log_prob値は、デバイス間で同期しています。なぜこの動作が必要なのですか?私たちは、各デバイスで同じMCMCチェーンを実行している場合、私たちは希望target_log_prob計算でいくつかのランダムな変数はデバイス間でシャードされている場合でも、各デバイス間で同じになるように。

さらに、 Shardedデバイス間で勾配が遷移関数の一部として対数密度関数の勾配を取るHMCなどのアルゴリズムは、適切なサンプルを生成することを確実にするために、正しいこと分布を確実にします。

シャードJointDistribution

私たちは、複数のモデルで作成することができますSharded使用してランダムな変数をJointDistribution S(JDS)。残念ながら、 Shardedのディストリビューションは、安全にバニラを使用することはできませんtfd.JointDistribution Sが、 tfp.experimental.distribute輸出は次のように動作しますJDS「パッチを適用」 Sharded分布を。

def f(seed, _):
  dist = tfed.JointDistributionSequential([
    tfd.Normal(0., 1.),
    tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i'),
  ])
  x = dist.sample(seed=seed)
  return x, dist.log_prob(x)
jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
([ShardedDeviceArray([1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525,
                      1.6121525, 1.6121525, 1.6121525], dtype=float32),
  ShardedDeviceArray([ 0.8690128 , -0.83167845,  1.2209264 ,  0.88412696,
                       0.76478404, -0.66208494, -0.0129658 ,  0.7391483 ],                   dtype=float32)],
 ShardedDeviceArray([-12.214451, -12.214451, -12.214451, -12.214451,
                     -12.214451, -12.214451, -12.214451, -12.214451],                   dtype=float32))

これらのシャードJDSは、両方持つことができますShardedコンポーネントとして、バニラTFP分布を。シャーディングされていない分布の場合、各デバイスで同じサンプルを取得し、シャーディングされた分布の場合、異なるサンプルを取得します。 log_prob各デバイスには十分と同期されます。

MCMC Sharded分布

どのように考えるかSharded MCMCの文脈における分布?私たちのように表現することができる生成モデルがある場合はJointDistribution 、私たちは、全体で「シャード」にそのモデルのいくつかの軸を選ぶことができます。通常、モデル内の1つの確率変数は観測データに対応し、デバイス間でシャーディングする大きなデータセットがある場合は、データポイントに関連付けられている変数もシャーディングする必要があります。また、シャーディングしている観測値と1対1の「ローカル」確率変数がある可能性があるため、これらの確率変数をさらにシャーディングする必要があります。

私たちは、の使用例の上に行くよShardedこのセクションではTFP MCMCに分布。私たちは、単純ベイズロジスティック回帰の例で始まる、とのためにいくつかのユースケースを実証することを目標に、マトリックス分解の例で結論うdistributeライブラリ。

例:MNISTのベイズロジスティック回帰

大規模なデータセットに対してベイズロジスティック回帰を実行したいと思います。モデルには、回帰の重みよりも前の$ p(\ theta)$があり、すべてのデータ$ {x_i、y_i} _ {i = 1} ^で合計される尤度$ p(y_i | \ theta、x_i)$があります。合計ジョイントログ密度を取得するためのN $。データをシャーディングする場合、モデルで観測された確率変数$ x_i $と$ y_i $をシャーディングします。

MNIST分類には、次のベイズロジスティック回帰モデルを使用します。

$$ \begin{align*} w &\sim \mathcal{N}(0, 1) \\ b &\sim \mathcal{N}(0, 1) \\ y_i | w, b, x_i &\sim \textrm{Categorical}(w^T x_i + b) \end{align*} $$

TensorFlowデータセットを使用してMNISTをロードしましょう。

mnist = tfds.as_numpy(tfds.load('mnist', batch_size=-1))
raw_train_images, train_labels = mnist['train']['image'], mnist['train']['label']
train_images = raw_train_images.reshape([raw_train_images.shape[0], -1]) / 255.

raw_test_images, test_labels = mnist['test']['image'], mnist['test']['label']
test_images = raw_test_images.reshape([raw_test_images.shape[0], -1]) / 255.
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

60000のトレーニング画像がありますが、利用可能な8つのコアを利用して、8つの方法で分割しましょう。私たちは、この便利な使いますshardユーティリティ関数を。

def shard_value(x):
  x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
  return jax.pmap(lambda x: x)(x) # pmap will physically place values on devices

shard = functools.partial(jax.tree_map, shard_value)
sharded_train_images, sharded_train_labels = shard((train_images, train_labels))
print(sharded_train_images.shape, sharded_train_labels.shape)
(8, 7500, 784) (8, 7500)

先に進む前に、TPUの精度とHMCへの影響について簡単に説明しましょう。 TPUは低い用いた行列乗算を実行bfloat16速度を精度。 bfloat16行列乗算は、多くの場合、多くの深い学習用途には十分ですが、HMCで使用する場合、我々は経験的に低精度が拒否を引き起こし、発散軌道につながることを発見しました。追加の計算をいくらか犠牲にして、より高精度の行列乗算を使用できます。

私たちのMATMUL精度を高めるために、我々は使用することができますjax.default_matmul_precisionでデコレータを"tensorfloat32"の精度(より高い精度のために私たちが使用できる"float32"精度)。

レッツ・今私たちの定義run (各デバイスで同じになります)ランダムシードとMNISTのシャードにかかる機能を、。この関数は前述のモデルを実装し、TFPのバニラMCMC機能を使用して単一のチェーンを実行します。私たちは必ず飾るために作るでしょうrunjax.default_matmul_precision下の特定の例では、私たちは同様に使用することができますが、必ず行列の乗算はより高精度で実行されていることを確認するためにデコレータjnp.dot(images, w, precision=lax.Precision.HIGH)

# We can use `out_axes=None` in the `pmap` because the results will be the same
# on every device. 
@functools.partial(jax.pmap, axis_name='data', in_axes=(None, 0), out_axes=None)
@jax.default_matmul_precision('tensorfloat32')
def run(seed, data):
  images, labels = data # a sharded dataset
  num_examples, dim = images.shape
  num_classes = 10

  def model_fn():
    w = yield Root(tfd.Sample(tfd.Normal(0., 1.), [dim, num_classes]))
    b = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_classes]))
    logits = jnp.dot(images, w) + b
    yield tfed.Sharded(tfd.Independent(tfd.Categorical(logits=logits), 1),
                       shard_axis_name='data')
  model = tfed.JointDistributionCoroutine(model_fn)

  init_seed, sample_seed = random.split(seed)

  initial_state = model.sample(seed=init_seed)[:-1] # throw away `y`

  def target_log_prob(*state):
    return model.log_prob((*state, labels))

  def accuracy(w, b):
    logits = images.dot(w) + b
    preds = logits.argmax(axis=-1)
    # We take the average accuracy across devices by using `lax.pmean`
    return lax.pmean((preds == labels).mean(), 'data')

  kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-2, 100)
  kernel = tfm.DualAveragingStepSizeAdaptation(kernel, 500)
  def trace_fn(state, pkr):
    return (
        target_log_prob(*state),
        accuracy(*state),
        pkr.new_step_size)
  states, trace = tfm.sample_chain(
    num_results=1000,
    num_burnin_steps=1000,
    current_state=initial_state,
    kernel=kernel,
    trace_fn=trace_fn,
    seed=sample_seed
  )
  return states, trace

jax.pmap JITコンパイルを含むが、コンパイルされた関数は、最初の呼び出しの後にキャッシュされます。私たちは、電話するよrunて、コンパイルをキャッシュするために出力を無視します。

%%time
output = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels))
jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 24.5 s, sys: 48.2 s, total: 1min 12s
Wall time: 1min 54s

私たちは今、電話するよrun実際の実行にかかる時間の長さを表示するに再び。

%%time
states, trace = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels))
jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 13.1 s, sys: 45.2 s, total: 58.3 s
Wall time: 1min 43s

200,000のリープフロッグステップを実行しています。各ステップは、データセット全体の勾配を計算します。計算を8コアに分割すると、約95秒で200,000エポックのトレーニングに相当する計算が可能になります。これは1秒あたり約2,100エポックです。

各サンプルの対数密度と各サンプルの精度をプロットしてみましょう。

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].plot(trace[0])
ax[0].set_title('Log Prob')
ax[1].plot(trace[1])
ax[1].set_title('Accuracy')
ax[2].plot(trace[2])
ax[2].set_title('Step Size')
plt.show()

png

サンプルをアンサンブルすると、ベイズモデルの平均を計算してパフォーマンスを向上させることができます。

@functools.partial(jax.pmap, axis_name='data', in_axes=(0, None), out_axes=None)
def bayesian_model_average(data, states):
  images, labels = data
  logits = jax.vmap(lambda w, b: images.dot(w) + b)(*states)
  probs = jax.nn.softmax(logits, axis=-1)
  bma_accuracy = (probs.mean(axis=0).argmax(axis=-1) == labels).mean()
  avg_accuracy = (probs.argmax(axis=-1) == labels).mean()
  return lax.pmean(bma_accuracy, axis_name='data'), lax.pmean(avg_accuracy, axis_name='data')

sharded_test_images, sharded_test_labels = shard((test_images, test_labels))
bma_acc, avg_acc = bayesian_model_average((sharded_test_images, sharded_test_labels), states)
print(f'Average Accuracy: {avg_acc}')
print(f'BMA Accuracy: {bma_acc}')
print(f'Accuracy Improvement: {bma_acc - avg_acc}')
Average Accuracy: 0.9188529253005981
BMA Accuracy: 0.9264000058174133
Accuracy Improvement: 0.0075470805168151855

ベイジアンモデルの平均により、精度がほぼ1%向上します。

例:MovieLensレコメンデーションシステム

それでは、ユーザーとさまざまな映画の評価のコレクションであるMovieLens推奨データセットを使用して推論を試してみましょう。具体的には、MovieLensを$ N \ times M $ウォッチマトリックス$ W $として表すことができます。ここで、$ N $はユーザー数、$ M $は映画数です。 $ N> M $を期待します。 $ W_ {ij} $のエントリは、ユーザー$ i $が映画$ j $を視聴したかどうかを示すブール値です。 MovieLensはユーザー評価を提供しますが、問題を単純化するためにそれらを無視していることに注意してください。

まず、データセットを読み込みます。評価が100万のバージョンを使用します。

movielens = tfds.as_numpy(tfds.load('movielens/1m-ratings', batch_size=-1))
GENRES = ['Action', 'Adventure', 'Animation', 'Children', 'Comedy',
          'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir',
          'Horror', 'IMAX', 'Musical', 'Mystery', 'Romance', 'Sci-Fi',
          'Thriller', 'Unknown', 'War', 'Western', '(no genres listed)']
Downloading and preparing dataset movielens/1m-ratings/0.1.0 (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0...
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0.incompleteYKA3TG/movielens-train.tfrecord
HBox(children=(FloatProgress(value=0.0, max=1000209.0), HTML(value='')))
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0. Subsequent calls will reuse this data.

データセットの前処理を行って、監視マトリックス$ W $を取得します。

raw_movie_ids = movielens['train']['movie_id']
raw_user_ids = movielens['train']['user_id']
genres = movielens['train']['movie_genres']

movie_ids, movie_labels = pd.factorize(movielens['train']['movie_id'])
user_ids, user_labels = pd.factorize(movielens['train']['user_id'])

num_movies = movie_ids.max() + 1
num_users = user_ids.max() + 1

movie_titles = dict(zip(movielens['train']['movie_id'],
                        movielens['train']['movie_title']))
movie_genres = dict(zip(movielens['train']['movie_id'],
                        genres))
movie_id_to_title = [movie_titles[movie_labels[id]].decode('utf-8')
                     for id in range(num_movies)]
movie_id_to_genre = [GENRES[movie_genres[movie_labels[id]][0]] for id in range(num_movies)]

watch_matrix = np.zeros((num_users, num_movies), bool)
watch_matrix[user_ids, movie_ids] = True
print(watch_matrix.shape)
(6040, 3706)

単純な確率的行列因数分解モデルを使用して、$ W $の生成モデルを定義できます。潜在的な$ N \ times D $ユーザー行列$ U $と潜在的な$ M \ times D $映画行列$ V $を想定します。これらを乗算すると、時計行列$ W $のベルヌーイのロジットが生成されます。また、ユーザーと映画のバイアスベクトル$ u $と$ v $も含まれます。

$$ \begin{align*} U &\sim \mathcal{N}(0, 1) \quad u \sim \mathcal{N}(0, 1)\\ V &\sim \mathcal{N}(0, 1) \quad v \sim \mathcal{N}(0, 1)\\ W_{ij} &\sim \textrm{Bernoulli}\left(\sigma\left(\left(UV^T\right)_{ij} + u_i + v_j\right)\right) \end{align*} $$

これはかなり大きなマトリックスです。 6040のユーザーと3706の映画は、2,200万を超えるエントリを含むマトリックスにつながります。このモデルのシャーディングにどのようにアプローチしますか? $ N> M $(つまり、映画よりもユーザーが多い)と仮定すると、ユーザー軸全体でウォッチマトリックスをシャーディングするのが理にかなっているため、各デバイスにはサブセットに対応するウォッチマトリックスのチャンクがあります。ユーザーの。ただし、前の例とは異なり、$ U $マトリックスはユーザーごとに埋め込まれているため、シャード化する必要があります。したがって、各デバイスは$ U $のシャードと$ W $のシャードを担当します。 。一方、$ V $はシャーディングされておらず、デバイス間で同期されます。

sharded_watch_matrix = shard(watch_matrix)

私たちが書く前にrun 、のは、すぐに地元の確率変数$ U $をシャーディングと新たな課題を議論しましょう。 HMCを実行している場合は、バニラtfp.mcmc.HamiltonianMonteCarloカーネルは、チェーンの状態の各要素のための運動量をサンプリングします。以前は、シャーディングされていない確率変数のみがその状態の一部であり、運動量は各デバイスで同じでした。シャーディングされた$ U $ができたら、各デバイスで$ U $の異なる運動量をサンプリングし、$ V $の同じ運動量をサンプリングする必要があります。これを実現するために、我々は使用することができますtfp.experimental.mcmc.PreconditionedHamiltonianMonteCarloしてSharded運動量分布。並列計算をファーストクラスにし続けると、たとえば、HMCカーネルにシャードネスインジケーターを適用することで、これを単純化できます。

def make_run(*,
             axis_name,
             dim=20,
             num_chains=2,
             prior_variance=1.,
             step_size=1e-2,
             num_leapfrog_steps=100,
             num_burnin_steps=1000,
             num_results=500,
             ):
  @functools.partial(jax.pmap, in_axes=(None, 0), axis_name=axis_name)
  @jax.default_matmul_precision('tensorfloat32')
  def run(key, watch_matrix):
    num_users, num_movies = watch_matrix.shape

    Sharded = functools.partial(tfed.Sharded, shard_axis_name=axis_name)

    def prior_fn():
      user_embeddings = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users, dim]), name='user_embeddings'))
      user_bias = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users]), name='user_bias'))
      movie_embeddings = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies, dim], name='movie_embeddings'))
      movie_bias = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies], name='movie_bias'))
      return (user_embeddings, user_bias, movie_embeddings, movie_bias)
    prior = tfed.JointDistributionCoroutine(prior_fn)

    def model_fn():
      user_embeddings, user_bias, movie_embeddings, movie_bias = yield from prior_fn()
      logits = (jnp.einsum('...nd,...md->...nm', user_embeddings, movie_embeddings)

                + user_bias[..., :, None] + movie_bias[..., None, :])
      yield Sharded(tfd.Independent(tfd.Bernoulli(logits=logits), 2), name='watch')
    model = tfed.JointDistributionCoroutine(model_fn)

    init_key, sample_key = random.split(key)
    initial_state = prior.sample(seed=init_key, sample_shape=num_chains)

    def target_log_prob(*state):
      return model.log_prob((*state, watch_matrix))

    momentum_distribution = tfed.JointDistributionSequential([
      Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users, dim]), 1.), 2)),
      Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users]), 1.), 1)),
      tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies, dim]), 1.), 2),
      tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies]), 1.), 1),
    ])

    # We pass in momentum_distribution here to ensure that the momenta for 
    # user_embeddings and user_bias are also sharded
    kernel = tfem.PreconditionedHamiltonianMonteCarlo(target_log_prob, step_size,
                                                      num_leapfrog_steps,
                                                      momentum_distribution=momentum_distribution)

    num_adaptation_steps = int(0.8 * num_burnin_steps)
    kernel = tfm.DualAveragingStepSizeAdaptation(kernel, num_adaptation_steps)

    def trace_fn(state, pkr):
      return {
        'log_prob': target_log_prob(*state),
        'log_accept_ratio': pkr.inner_results.log_accept_ratio,
      }
    return tfm.sample_chain(
        num_results, initial_state,
        kernel=kernel,
        num_burnin_steps=num_burnin_steps,
        trace_fn=trace_fn,
        seed=sample_key)
  return run

コンパイルキャッシュするために一度我々は再びそれを実行しますrun

%%time
run = make_run(axis_name='data')
output = run(random.PRNGKey(0), sharded_watch_matrix)
jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 56 s, sys: 1min 24s, total: 2min 20s
Wall time: 3min 35s

次に、コンパイルのオーバーヘッドなしで再度実行します。

%%time
states, trace = run(random.PRNGKey(0), sharded_watch_matrix)
jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 28.8 s, sys: 1min 16s, total: 1min 44s
Wall time: 3min 1s

約3分で約150,000のリープフロッグステップを完了したようです。つまり、1秒あたり約83のリープフロッグステップです。サンプルの受け入れ率と対数密度をプロットしてみましょう。

fig, axs = plt.subplots(1, len(trace), figsize=(5 * len(trace), 5))
for ax, (key, val) in zip(axs, trace.items()):
  ax.plot(val[0]) # Indexing into a sharded array, each element is the same
  ax.set_title(key);

png

マルコフ連鎖からいくつかのサンプルが得られたので、それらを使用していくつかの予測を行いましょう。まず、各コンポーネントを抽出してみましょう。ことを忘れないでくださいuser_embeddingsuser_bias 、デバイス間で分割されているので、我々は連結する必要がShardedArrayそれらすべてを取得します。一方、 movie_embeddingsmovie_biasすべてのデバイスで同じであるので、我々はちょうど最初のシャードから値を選ぶことができます。私たちは、定期的に使用しますnumpy CPUへのTPUバックから値をコピーします。

user_embeddings = np.concatenate(np.array(states.user_embeddings, np.float32), axis=2)
user_bias = np.concatenate(np.array(states.user_bias, np.float32), axis=2)
movie_embeddings = np.array(states.movie_embeddings[0], dtype=np.float32)
movie_bias = np.array(states.movie_bias[0], dtype=np.float32)
samples = (user_embeddings, user_bias, movie_embeddings, movie_bias)
print(f'User embeddings: {user_embeddings.shape}')
print(f'User bias: {user_bias.shape}')
print(f'Movie embeddings: {movie_embeddings.shape}')
print(f'Movie bias: {movie_bias.shape}')
User embeddings: (500, 2, 6040, 20)
User bias: (500, 2, 6040)
Movie embeddings: (500, 2, 3706, 20)
Movie bias: (500, 2, 3706)

これらのサンプルでキャプチャされた不確実性を利用する単純なレコメンダーシステムを構築してみましょう。まず、視聴確率に応じて映画をランク付けする関数を作成しましょう。

@jax.jit
def recommend(sample, user_id):
  user_embeddings, user_bias, movie_embeddings, movie_bias = sample
  movie_logits = (
      jnp.einsum('d,md->m', user_embeddings[user_id], movie_embeddings)

      + user_bias[user_id] + movie_bias)
  return movie_logits.argsort()[::-1]

これで、すべてのサンプルをループし、各サンプルについて、ユーザーがまだ視聴していない上位の映画を選択する関数を作成できます。次に、サンプル全体で推奨されるすべての映画の数を確認できます。

def get_recommendations(user_id): 
  movie_ids = []
  already_watched = set(jnp.arange(num_movies)[watch_matrix[user_id] == 1])
  for i in range(500):
    for j in range(2):
      sample = jax.tree_map(lambda x: x[i, j], samples)
      ranking = recommend(sample, user_id)
      for movie_id in ranking:
        if int(movie_id) not in already_watched:
          movie_ids.append(movie_id)
          break
  return movie_ids

def plot_recommendations(movie_ids, ax=None):
  titles = collections.Counter([movie_id_to_title[i] for i in movie_ids])
  ax = ax or plt.gca()
  names, counts = zip(*sorted(titles.items(), key=lambda x: -x[1]))
  ax.bar(names, counts)
  ax.set_xticklabels(names, rotation=90)

映画を最も多く見たユーザーと最も少なく見たユーザーを比較してみましょう。

user_watch_counts = watch_matrix.sum(axis=1)
user_most = user_watch_counts.argmax()
user_least = user_watch_counts.argmin()
print(user_watch_counts[user_most], user_watch_counts[user_least])
2314 20

我々は、我々のシステムは詳細については確信がある願っていますuser_mostよりuser_least私たちは映画のソート何についての詳細は持っていることを考えると、 user_most視聴する可能性が高いです。

fig, ax = plt.subplots(1, 2, figsize=(20, 10))
most_recommendations = get_recommendations(user_most)
plot_recommendations(most_recommendations, ax=ax[0])
ax[0].set_title('Recommendation for user_most')
least_recommendations = get_recommendations(user_least)
plot_recommendations(least_recommendations, ax=ax[1])
ax[1].set_title('Recommendation for user_least');

png

私たちは、私たちの提言では、よりばらつきがあることがわかりuser_least自分の時計の設定で、当社の追加の不確実性を反映するには。

おすすめの映画のジャンルもご覧いただけます。

most_genres = collections.Counter([movie_id_to_genre[i] for i in most_recommendations])
least_genres = collections.Counter([movie_id_to_genre[i] for i in least_recommendations])
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].bar(most_genres.keys(), most_genres.values())
ax[0].set_title('Genres recommended for user_most')
ax[1].bar(least_genres.keys(), least_genres.values())
ax[1].set_title('Genres recommended for user_least');

png

user_mostたくさんの映画を見ていると、一方ミステリーや犯罪のような、よりニッチなジャンルを推奨されているuser_least多くの映画を見ていないし、より多くの主流の映画、スキューコメディとアクションを推奨されていました。