Probabilitas TensorFlow di JAX

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

TensorFlow Probabilitas (TFP) adalah library untuk penalaran probabilistik dan analisis statistik yang kini juga bekerja pada JAX ! Bagi mereka yang tidak terbiasa, JAX adalah perpustakaan untuk komputasi numerik yang dipercepat berdasarkan transformasi fungsi yang dapat dikomposisi.

TFP di JAX mendukung banyak fungsi yang paling berguna dari TFP biasa sambil mempertahankan abstraksi dan API yang sekarang nyaman digunakan oleh banyak pengguna TFP.

Mempersiapkan

TFP pada JAX tidak tergantung pada TensorFlow; mari hapus instalan TensorFlow dari Colab ini sepenuhnya.

pip uninstall tensorflow -y -q

Kami dapat menginstal TFP di JAX dengan build TFP malam terbaru.

pip install -Uq tfp-nightly[jax] > /dev/null

Mari impor beberapa pustaka Python yang berguna.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn import datasets
sns.set(style='white')
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

Mari juga mengimpor beberapa fungsi dasar JAX.

import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import random
from jax import value_and_grad
from jax import vmap

Mengimpor TFP di JAX

Untuk menggunakan TFP pada JAX, hanya mengimpor jax "substrat" dan menggunakannya sebagai Anda biasanya akan tfp :

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

Demo: Regresi logistik Bayesian

Untuk mendemonstrasikan apa yang dapat kami lakukan dengan backend JAX, kami akan menerapkan regresi logistik Bayesian yang diterapkan pada set data Iris klasik.

Pertama, mari impor dataset Iris dan ekstrak beberapa metadata.

iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)

Kita bisa menentukan model menggunakan tfd.JointDistributionCoroutine . Kami akan menempatkan prior standar normal pada kedua bobot dan jangka bias yang kemudian menulis target_log_prob fungsi yang pin label sampel untuk data.

Root = tfd.JointDistributionCoroutine.Root
def model():
  w = yield Root(tfd.Sample(tfd.Normal(0., 1.),
                            sample_shape=(num_features, num_classes)))
  b = yield Root(
      tfd.Sample(tfd.Normal(0., 1.), sample_shape=(num_classes,)))
  logits = jnp.dot(features, w) + b
  yield tfd.Independent(tfd.Categorical(logits=logits),
                        reinterpreted_batch_ndims=1)


dist = tfd.JointDistributionCoroutine(model)
def target_log_prob(*params):
  return dist.log_prob(params + (labels,))

Kami sampel dari dist untuk menghasilkan keadaan awal untuk MCMC. Kami kemudian dapat mendefinisikan fungsi yang mengambil kunci acak dan keadaan awal, dan menghasilkan 500 sampel dari No-U-Turn-Sampler (NUTS). Perhatikan bahwa kita dapat menggunakan JAX transformasi seperti jit untuk mengkompilasi sampler KACANG kami menggunakan XLA.

init_key, sample_key = random.split(random.PRNGKey(0))
init_params = tuple(dist.sample(seed=init_key)[:-1])

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      num_burnin_steps=500,
      seed=key)

states, log_probs = run_chain(sample_key, init_params)
plt.figure()
plt.plot(log_probs)
plt.ylabel('Target Log Prob')
plt.xlabel('Iterations of NUTS')
plt.show()

png

Mari kita gunakan sampel kita untuk melakukan Bayesian model averaging (BMA) dengan merata-ratakan probabilitas yang diprediksi dari setiap rangkaian bobot.

Pertama mari kita tulis sebuah fungsi yang untuk sekumpulan parameter tertentu akan menghasilkan probabilitas untuk setiap kelas. Kita dapat menggunakan dist.sample_distributions untuk mendapatkan distribusi akhir dalam model.

def classifier_probs(params):
  dists, _ = dist.sample_distributions(seed=random.PRNGKey(0),
                                       value=params + (None,))
  return dists[-1].distribution.probs_parameter()

Kita bisa vmap(classifier_probs) atas set sampel untuk mendapatkan probabilitas kelas diprediksi untuk masing-masing sampel kami. Kami kemudian menghitung akurasi rata-rata di setiap sampel, dan akurasi dari rata-rata model Bayesian.

all_probs = jit(vmap(classifier_probs))(states)
print('Average accuracy:', jnp.mean(all_probs.argmax(axis=-1) == labels))
print('BMA accuracy:', jnp.mean(all_probs.mean(axis=0).argmax(axis=-1) == labels))
Average accuracy: 0.96952
BMA accuracy: 0.97999996

Sepertinya BMA mengurangi tingkat kesalahan kami hampir sepertiga!

Dasar-dasar

TFP pada JAX memiliki API identik dengan TF mana bukannya menerima benda TF seperti tf.Tensor s itu menerima analog JAX. Sebagai contoh, di mana pun tf.Tensor sebelumnya digunakan sebagai masukan, API sekarang mengharapkan JAX DeviceArray . Alih-alih mengembalikan tf.Tensor , metode TFP akan kembali DeviceArray s. TFP pada JAX juga bekerja dengan struktur bersarang benda JAX, seperti daftar atau kamus DeviceArray s.

Distribusi

Sebagian besar distribusi TFP didukung di JAX dengan semantik yang sangat mirip dengan rekanan TF mereka. Mereka juga terdaftar sebagai JAX Pytrees , sehingga mereka dapat menjadi input dan output dari fungsi JAX-berubah.

Distribusi dasar

The log_prob metode untuk distribusi bekerja sama.

dist = tfd.Normal(0., 1.)
print(dist.log_prob(0.))
-0.9189385

Sampling dari distribusi membutuhkan eksplisit lewat di PRNGKey (atau daftar bilangan bulat) sebagai seed argumen kata kunci. Gagal mengirimkan benih secara eksplisit akan menimbulkan kesalahan.

tfd.Normal(0., 1.).sample(seed=random.PRNGKey(0))
DeviceArray(-0.20584226, dtype=float32)

Semantik bentuk untuk distribusi tetap sama di JAX, di mana distribusi masing-masing akan memiliki event_shape dan batch_shape dan menggambar banyak sampel akan menambah tambahan sample_shape dimensi.

Misalnya, tfd.MultivariateNormalDiag dengan parameter vektor akan memiliki bentuk acara vektor dan bentuk batch yang kosong.

dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros(5),
    scale_diag=jnp.ones(5)
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: (5,)
Batch shape: ()

Di sisi lain, tfd.Normal diparameterisasi dengan vektor akan memiliki skalar acara bentuk dan vektor bets bentuk.

dist = tfd.Normal(
    loc=jnp.ones(5),
    scale=jnp.ones(5),
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: ()
Batch shape: (5,)

Semantik mengambil log_prob sampel bekerja sama dalam JAX juga.

dist =  tfd.Normal(jnp.zeros(5), jnp.ones(5))
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)

dist =  tfd.Independent(tfd.Normal(jnp.zeros(5), jnp.ones(5)), 1)
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)
(10, 2, 5)
(10, 2)

Karena JAX DeviceArray s yang kompatibel dengan perpustakaan seperti NumPy dan Matplotlib, kita bisa memberi makan sampel langsung ke fungsi merencanakan.

sns.distplot(tfd.Normal(0., 1.).sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

Distribution metode yang kompatibel dengan transformasi JAX.

sns.distplot(jit(vmap(lambda key: tfd.Normal(0., 1.).sample(seed=key)))(
    random.split(random.PRNGKey(0), 2000)))
plt.show()

png

x = jnp.linspace(-5., 5., 100)
plt.plot(x, jit(vmap(grad(tfd.Normal(0., 1.).prob)))(x))
plt.show()

png

Karena distribusi TFP terdaftar sebagai JAX node pytree, kita bisa menulis fungsi dengan distribusi sebagai input atau output dan mengubah mereka menggunakan jit , tetapi mereka belum didukung sebagai argumen untuk vmap fungsi -ed.

@jit
def random_distribution(key):
  loc_key, scale_key = random.split(key)
  loc, log_scale = random.normal(loc_key), random.normal(scale_key)
  return tfd.Normal(loc, jnp.exp(log_scale))
random_dist = random_distribution(random.PRNGKey(0))
print(random_dist.mean(), random_dist.variance())
0.14389051 0.081832744

Distribusi yang diubah

Distribusi berubah yaitu distribusi yang sampel dilewatkan melalui Bijector juga bekerja di luar kotak (bijectors bekerja terlalu! Lihat di bawah).

dist = tfd.TransformedDistribution(
    tfd.Normal(0., 1.),
    tfb.Sigmoid()
)
sns.distplot(dist.sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

Distribusi bersama

TFP menawarkan JointDistribution s untuk memungkinkan menggabungkan distribusi komponen ke dalam distribusi tunggal atas beberapa variabel acak. Saat ini, TFP penawaran tiga varian inti ( JointDistributionSequential , JointDistributionNamed , dan JointDistributionCoroutine ) yang semuanya didukung dalam JAX. The AutoBatched varian juga semua didukung.

dist = tfd.JointDistributionSequential([
  tfd.Normal(0., 1.),
  lambda x: tfd.Normal(x, 1e-1)
])
plt.scatter(*dist.sample(1000, seed=random.PRNGKey(0)), alpha=0.5)
plt.show()

png

joint = tfd.JointDistributionNamed(dict(
    e=             tfd.Exponential(rate=1.),
    n=             tfd.Normal(loc=0., scale=2.),
    m=lambda n, e: tfd.Normal(loc=n, scale=e),
    x=lambda    m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
))
joint.sample(seed=random.PRNGKey(0))
{'e': DeviceArray(3.376818, dtype=float32),
 'm': DeviceArray(2.5449684, dtype=float32),
 'n': DeviceArray(-0.6027825, dtype=float32),
 'x': DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
Root = tfd.JointDistributionCoroutine.Root
def model():
  e = yield Root(tfd.Exponential(rate=1.))
  n = yield Root(tfd.Normal(loc=0, scale=2.))
  m = yield tfd.Normal(loc=n, scale=e)
  x = yield tfd.Sample(tfd.Bernoulli(logits=m), 12)

joint = tfd.JointDistributionCoroutine(model)

joint.sample(seed=random.PRNGKey(0))
StructTuple(var0=DeviceArray(0.17315261, dtype=float32), var1=DeviceArray(-3.290489, dtype=float32), var2=DeviceArray(-3.1949058, dtype=float32), var3=DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))

distribusi lainnya

Proses Gaussian juga bekerja dalam mode JAX!

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
observation_noise_variance = 0.01
f = lambda x: jnp.sin(10*x[..., 0]) * jnp.exp(-x[..., 0]**2)
observation_index_points = random.uniform(
    k1, [50], minval=-1.,maxval= 1.)[..., jnp.newaxis]
observations = f(observation_index_points) + tfd.Normal(
    loc=0., scale=jnp.sqrt(observation_noise_variance)).sample(seed=k2)

index_points = jnp.linspace(-1., 1., 100)[..., jnp.newaxis]

kernel = tfpk.ExponentiatedQuadratic(length_scale=0.1)

gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=index_points,
    observation_index_points=observation_index_points,
    observations=observations,
    observation_noise_variance=observation_noise_variance)

samples = gprm.sample(10, seed=k3)
for i in range(10):
  plt.plot(index_points, samples[i], alpha=0.5)
plt.plot(observation_index_points, observations, marker='o', linestyle='')
plt.show()

png

Model Markov Tersembunyi juga didukung.

initial_distribution = tfd.Categorical(probs=[0.8, 0.2])
transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                 [0.2, 0.8]])

observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

model = tfd.HiddenMarkovModel(
    initial_distribution=initial_distribution,
    transition_distribution=transition_distribution,
    observation_distribution=observation_distribution,
    num_steps=7)

print(model.mean())
print(model.log_prob(jnp.zeros(7)))
print(model.sample(seed=random.PRNGKey(0)))
[3.       6.       7.5      8.249999 8.625001 8.812501 8.90625 ]
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/substrates/jax/distributions/hidden_markov_model.py:483: UserWarning: HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug in which the transition model was applied prior to the initial step. This bug has been fixed. You may observe a slight change in behavior.
  'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
-19.855635
[ 1.3641367  0.505798   1.3626463  3.6541772  2.272286  15.10309
 22.794212 ]

Beberapa distribusi seperti PixelCNN tidak didukung belum karena ketergantungan yang ketat pada TensorFlow atau XLA tidak kompatibel.

Bijektor

Sebagian besar bijector TFP didukung di JAX hari ini!

tfb.Exp().inverse(1.)
DeviceArray(0., dtype=float32)
bij = tfb.Shift(1.)(tfb.Scale(3.))
print(bij.forward(jnp.ones(5)))
print(bij.inverse(jnp.ones(5)))
[4. 4. 4. 4. 4.]
[0. 0. 0. 0. 0.]
b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
print(b.forward(x=[0., 0., 0.]))
print(b.inverse(y=[[1., 0], [.5, 2]]))
[[1. 0.]
 [0. 1.]]
[0.6931472 0.5       0.       ]
b = tfb.Chain([tfb.Exp(), tfb.Softplus()])
# or:
# b = tfb.Exp()(tfb.Softplus())
print(b.forward(-jnp.ones(5)))
[1.3678794 1.3678794 1.3678794 1.3678794 1.3678794]

Bijectors yang kompatibel dengan transformasi JAX seperti jit , grad dan vmap .

jit(vmap(tfb.Exp().inverse))(jnp.arange(4.))
DeviceArray([     -inf, 0.       , 0.6931472, 1.0986123], dtype=float32)
x = jnp.linspace(0., 1., 100)
plt.plot(x, jit(grad(lambda x: vmap(tfb.Sigmoid().inverse)(x).sum()))(x))
plt.show()

png

Beberapa bijectors, seperti RealNVP dan FFJORD belum didukung.

MCMC

Kami telah porting tfp.mcmc ke JAX juga, sehingga kami dapat menjalankan algoritma seperti Hamiltonian Monte Carlo (HMC) dan No-U-Turn-Sampler (KACANG) di JAX.

target_log_prob = tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)).log_prob

Tidak seperti TFP pada TF, kita dituntut untuk lulus PRNGKey ke sample_chain menggunakan seed argumen kata kunci.

def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-1)
  return tfp.mcmc.sample_chain(1000,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      seed=key)
states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros(2))
plt.figure()
plt.scatter(*states.T, alpha=0.5)
plt.figure()
plt.plot(log_probs)
plt.show()

png

png

Untuk menjalankan beberapa rantai, kita baik dapat melewati batch negara ke sample_chain atau penggunaan vmap (meskipun kita belum dieksplorasi perbedaan kinerja antara dua pendekatan).

states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros([10, 2]))
plt.figure()
for i in range(10):
  plt.scatter(*states[:, i].T, alpha=0.5)
plt.figure()
for i in range(10):
  plt.plot(log_probs[:, i], alpha=0.5)
plt.show()

png

png

Pengoptimal

TFP di JAX mendukung beberapa pengoptimal penting seperti BFGS dan L-BFGS. Mari kita siapkan fungsi kerugian kuadratik berskala sederhana.

minimum = jnp.array([1.0, 1.0])  # The center of the quadratic bowl.
scales = jnp.array([2.0, 3.0])  # The scales along the two axes.

# The objective function and the gradient.
def quadratic_loss(x):
  return jnp.sum(scales * jnp.square(x - minimum))

start = jnp.array([0.6, 0.8])  # Starting point for the search.

BFGS dapat menemukan kerugian minimum ini.

optim_results = tfp.optimizer.bfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

Begitu juga dengan L-BFGS.

optim_results = tfp.optimizer.lbfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

Untuk vmap L-BFGS, mari kita set terserah fungsi yang mengoptimalkan kerugian bagi titik awal tunggal.

def optimize_single(start):
  return tfp.optimizer.lbfgs_minimize(
      value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

all_results = jit(vmap(optimize_single))(
    random.normal(random.PRNGKey(0), (10, 2)))
assert all(all_results.converged)
for i in range(10):
  np.testing.assert_allclose(optim_results.position[i], minimum)
print("Function evaluations: %s" % all_results.num_objective_evaluations)
Function evaluations: [6 6 9 6 6 8 6 8 5 9]

Peringatan

Ada beberapa perbedaan mendasar antara TF dan JAX, beberapa perilaku TFP akan berbeda antara kedua media dan tidak semua fungsi didukung. Sebagai contoh,

  • TFP pada JAX tidak mendukung hal seperti tf.Variable karena tidak ada seperti itu ada di JAX. Ini juga berarti utilitas seperti tfp.util.TransformedVariable tidak didukung baik.
  • tfp.layers tidak didukung di backend belum, karena ketergantungannya pada Keras dan tf.Variable s.
  • tfp.math.minimize tidak bekerja di TFP pada JAX karena ketergantungannya pada tf.Variable .
  • Dengan TFP di JAX, bentuk tensor selalu merupakan nilai integer konkret dan tidak pernah diketahui/dinamis seperti pada TFP di TF.
  • Pseudorandomness ditangani secara berbeda di TF dan JAX (lihat lampiran).
  • Perpustakaan di tfp.experimental tidak dijamin ada di substrat JAX.
  • Aturan promosi Dtype berbeda antara TF dan JAX. TFP di JAX mencoba untuk menghormati semantik dtype TF secara internal, untuk konsistensi.
  • Bijector belum terdaftar sebagai pytrees JAX.

Untuk melihat daftar lengkap dari apa yang didukung dalam TFP pada JAX, silakan lihat dokumentasi API .

Kesimpulan

Kami telah mem-porting banyak fitur TFP ke JAX dan sangat antusias untuk melihat apa yang akan dibuat semua orang. Beberapa fungsi belum didukung; jika kita telah melewatkan sesuatu yang penting untuk Anda (atau jika Anda menemukan bug!) silakan hubungi kami - Anda dapat mengirim email tfprobability@tensorflow.org atau mengajukan masalah pada repo Github kami .

Lampiran: pseudorandomness di JAX

Jumlah pseudorandom generasi (PRNG) Model JAX adalah stateless. Tidak seperti model stateful, tidak ada state global yang dapat berubah yang berevolusi setelah setiap undian acak. Dalam model JAX, kita mulai dengan kunci PRNG, yang bertindak seperti sepasang 32-bit bilangan bulat. Kita dapat membuat tombol-tombol ini dengan menggunakan jax.random.PRNGKey .

key = random.PRNGKey(0)  # Creates a key with value [0, 0]
print(key)
[0 0]

Fungsi acak dalam JAX mengkonsumsi kunci untuk deterministik menghasilkan variate acak, yang berarti mereka tidak boleh digunakan lagi. Sebagai contoh, kita dapat menggunakan key untuk sampel nilai terdistribusi normal, tapi kita tidak harus menggunakan key lagi di tempat lain. Selanjutnya, melewati nilai yang sama dalam random.normal akan menghasilkan nilai yang sama.

print(random.normal(key))
-0.20584226

Jadi bagaimana kita bisa menggambar banyak sampel dari satu kunci? Jawabannya adalah membelah kunci. Ide dasarnya adalah bahwa kita dapat membagi PRNGKey menjadi beberapa, dan masing-masing tombol baru dapat diperlakukan sebagai sumber independen keacakan.

key1, key2 = random.split(key, num=2)
print(key1, key2)
[4146024105  967050713] [2718843009 1272950319]

Pemisahan kunci bersifat deterministik tetapi kacau, sehingga setiap kunci baru sekarang dapat digunakan untuk menggambar sampel acak yang berbeda.

print(random.normal(key1), random.normal(key2))
0.14389051 -1.2515389

Untuk rincian lebih lanjut tentang model kunci membelah deterministik JAX, lihat panduan ini .