Hari Komunitas ML adalah 9 November! Bergabung dengan kami untuk update dari TensorFlow, JAX, dan lebih Pelajari lebih lanjut

Tur Oryx

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Apa itu Oryx?

Oryx adalah pustaka eksperimental yang memperluas JAX ke aplikasi mulai dari membangun dan melatih jaringan neural yang kompleks hingga mendekati inferensi Bayesian dalam model generatif yang mendalam. Seperti JAX yang menyediakan jit , vmap , dan grad , Oryx menyediakan satu set transformasi fungsi yang dapat disusun yang memungkinkan penulisan kode sederhana dan mengubahnya untuk membangun kompleksitas sambil tetap dapat dioperasikan sepenuhnya dengan JAX.

JAX hanya dapat dengan aman mengubah kode fungsional murni (yaitu kode tanpa efek samping). Meskipun kode murni lebih mudah untuk ditulis dan dipikirkan, kode yang "tidak murni" seringkali bisa lebih ringkas dan lebih mudah ekspresif.

Pada intinya, Oryx adalah pustaka yang memungkinkan "menambah" kode fungsional murni untuk menyelesaikan tugas-tugas seperti menentukan status atau mengeluarkan nilai antara. Tujuannya adalah menjadi lapisan setipis mungkin di atas JAX, memanfaatkan pendekatan minimalis JAX untuk komputasi numerik. Oryx secara konseptual dibagi menjadi beberapa "lapisan", masing-masing bangunan di atas satu lapisan di bawahnya.

Kode sumber untuk Oryx dapat ditemukan di GitHub .

Mendirikan

pip install -q oryx 1>/dev/null
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='whitegrid')

import jax
import jax.numpy as jnp
from jax import random
from jax import vmap
from jax import jit
from jax import grad

import oryx

tfd = oryx.distributions

state = oryx.core.state
ppl = oryx.core.ppl

inverse = oryx.core.inverse
ildj = oryx.core.ildj
plant = oryx.core.plant
reap = oryx.core.reap
sow = oryx.core.sow
unzip = oryx.core.unzip

nn = oryx.experimental.nn
mcmc = oryx.experimental.mcmc
optimizers = oryx.experimental.optimizers

Layer 0: Transformasi fungsi dasar

Pada dasarnya, Oryx mendefinisikan beberapa transformasi fungsi baru. Transformasi ini diimplementasikan menggunakan mesin penelusuran JAX dan dapat dioperasikan dengan transformasi JAX yang ada seperti jit , grad , vmap , dll.

Pembalikan fungsi otomatis

oryx.core.inverse dan oryx.core.ildj adalah transformasi fungsi yang secara programatik dapat membalikkan fungsi dan menghitung invers log-det Jacobian (ILDJ) masing-masing. Transformasi ini berguna dalam pemodelan probabilistik untuk menghitung probabilitas log menggunakan rumus perubahan variabel. Namun, ada batasan pada jenis fungsi yang kompatibel dengannya (lihat dokumentasi untuk lebih jelasnya).

def f(x):
  return jnp.exp(x) + 2.
print(inverse(f)(4.))  # ln(2)
print(ildj(f)(4.)) # -ln(2)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
0.6931472
-0.6931472

Panen

oryx.core.harvest memungkinkan penandaan nilai dalam fungsi bersama dengan kemampuan untuk mengumpulkannya, atau "menuai" mereka, dan kemampuan untuk memasukkan nilai pada tempatnya, atau "menanam" mereka. Kami menandai nilai menggunakan fungsi sow .

def f(x):
  y = sow(x + 1., name='y', tag='intermediate')
  return y ** 2
print('Reap:', reap(f, tag='intermediate')(1.))  # Pulls out 'y'
print('Plant:', plant(f, tag='intermediate')(dict(y=5.), 1.))  # Injects 5. for 'y'
Reap: {'y': DeviceArray(2., dtype=float32)}
Plant: 25.0

Buka zip

oryx.core.unzip membagi fungsi menjadi dua di sepanjang satu set nilai yang ditandai sebagai perantara, lalu mengembalikan fungsi init_f dan apply_f . init_f mengambil argumen kunci dan mengembalikan perantara. apply_f mengembalikan fungsi yang mengambil perantara dan mengembalikan output fungsi asli.

def f(key, x):
  w = sow(random.normal(key), tag='variable', name='w')
  return w * x
init_f, apply_f = unzip(f, tag='variable')(random.PRNGKey(0), 1.)

Fungsi init_f menjalankan f tetapi hanya mengembalikan variabelnya.

init_f(random.PRNGKey(0))
{'w': DeviceArray(-0.20584226, dtype=float32)}

apply_f mengambil satu set variabel sebagai input pertama dan mengeksekusi f dengan set variabel yang diberikan.

apply_f(dict(w=2.), 2.)  # Runs f with `w = 2`.
DeviceArray(4., dtype=float32)

Lapisan 1: Transformasi tingkat yang lebih tinggi

Oryx membangun transformasi fungsi inversi, panen, dan unzip tingkat rendah untuk menawarkan beberapa transformasi tingkat yang lebih tinggi untuk menulis komputasi stateful dan untuk pemrograman probabilistik.

Fungsi stateful ( core.state )

Kami sering tertarik untuk mengekspresikan komputasi stateful di mana kami menginisialisasi sekumpulan parameter dan mengekspresikan komputasi dalam istilah parameter. Di oryx.core.state , Oryx menyediakan transformasi init yang mengubah fungsi menjadi fungsi yang menginisialisasi Module , wadah untuk status.

Module s menyerupai Pytorch dan TensorFlow Module s kecuali bahwa mereka yang berubah.

def make_dense(dim_out):
  def forward(x, init_key=None):
    w_key, b_key = random.split(init_key)
    dim_in = x.shape[0]
    w = state.variable(random.normal(w_key, (dim_in, dim_out)), name='w')
    b = state.variable(random.normal(w_key, (dim_out,)), name='b')
    return jnp.dot(x, w) + b
  return forward

layer = state.init(make_dense(5))(random.PRNGKey(0), jnp.zeros(2))
print('layer:', layer)
print('layer.w:', layer.w)
print('layer.b:', layer.b)
layer: FunctionModule(dict_keys(['w', 'b']))
layer.w: [[-2.6105583   0.03385283  1.0863334  -1.4802988   0.48895672]
 [ 1.062516    0.5417484   0.0170228   0.2722685   0.30522448]]
layer.b: [0.59902626 0.2172144  2.4202902  0.03266738 1.2164948 ]

Module didaftarkan sebagai pytree JAX dan dapat digunakan sebagai input untuk fungsi transformasi JAX. Oryx menyediakan fungsi call praktis yang menjalankan Module .

vmap(state.call, in_axes=(None, 0))(layer, jnp.ones((5, 2)))
DeviceArray([[-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ]], dtype=float32)

state API juga memungkinkan penulisan update stateful (seperti menjalankan rata-rata) menggunakan fungsi assign . Module dihasilkan memiliki fungsi update dengan tanda tangan input yang sama dengan __call__ Module tetapi membuat salinan baru Module dengan status yang diperbarui.

def counter(x, init_key=None):
  count = state.variable(0., key=init_key, name='count')
  count = state.assign(count + 1., name='count')
  return x + count
layer = state.init(counter)(random.PRNGKey(0), 0.)
print(layer.count)
updated_layer = layer.update(0.)
print(updated_layer.count) # Count has advanced!
print(updated_layer.call(1.))
0.0
1.0
3.0

Pemrograman probabilistik

Di oryx.core.ppl , Oryx menyediakan seperangkat alat yang dibangun di atas harvest dan inverse yang bertujuan untuk membuat penulisan dan transformasi program probabilistik menjadi intuitif dan mudah.

Di Oryx, program probabilistik adalah fungsi JAX yang mengambil sumber keacakan sebagai argumen pertamanya dan mengembalikan sampel dari distribusi, yaitu f :: Key -> Sample . Untuk menulis program ini, Oryx membungkus distribusi Probabilitas TensorFlow dan menyediakan fungsi sederhana random_variable yang mengubah distribusi menjadi program probabilistik.

def sample(key):
  return ppl.random_variable(tfd.Normal(0., 1.))(key)
sample(random.PRNGKey(0))
DeviceArray(-0.20584235, dtype=float32)

Apa yang dapat kita lakukan dengan program probabilistik? Hal yang paling sederhana adalah mengambil program probabilistik (yaitu fungsi pengambilan sampel) dan mengubahnya menjadi program yang menyediakan log-density sampel.

ppl.log_prob(sample)(1.)
DeviceArray(-1.4189385, dtype=float32)

Fungsi probabilitas log baru kompatibel dengan transformasi JAX lainnya seperti vmap dan grad .

grad(lambda s: vmap(ppl.log_prob(sample))(s).sum())(jnp.arange(10.))
DeviceArray([-0., -1., -2., -3., -4., -5., -6., -7., -8., -9.], dtype=float32)

Menggunakan transformasi ildj , kita dapat menghitung log_prob dari program yang mengubah sampel secara invertib.

def sample(key):
  x = ppl.random_variable(tfd.Normal(0., 1.))(key)
  return jnp.exp(x / 2.) + 2.
_, ax = plt.subplots(2)
ax[0].hist(jit(vmap(sample))(random.split(random.PRNGKey(0), 1000)),
    bins='auto')
x = jnp.linspace(0, 8, 100)
ax[1].plot(x, jnp.exp(jit(vmap(ppl.log_prob(sample)))(x)))
plt.show()

png

Kita dapat menandai nilai antara dalam program probabilistik dengan nama dan mendapatkan pengambilan sampel bersama dan fungsi log-prob gabungan.

def sample(key):
  z_key, x_key = random.split(key)
  z = ppl.random_variable(tfd.Normal(0., 1.), name='z')(z_key)
  x = ppl.random_variable(tfd.Normal(z, 1.), name='x')(x_key)
  return x
ppl.joint_sample(sample)(random.PRNGKey(0))
{'x': DeviceArray(-1.1076484, dtype=float32),
 'z': DeviceArray(0.14389044, dtype=float32)}

Oryx juga memiliki fungsi joint_log_prob yang menyusun log_prob dengan joint_sample .

ppl.joint_log_prob(sample)(dict(x=0., z=0.))
DeviceArray(-1.837877, dtype=float32)

Untuk mempelajari lebih lanjut, lihat dokumentasi .

Lapisan 2: Perpustakaan mini

Membangun lebih jauh di atas lapisan yang menangani pemrograman keadaan dan probabilistik, Oryx menyediakan perpustakaan mini eksperimental yang disesuaikan untuk aplikasi tertentu seperti pembelajaran mendalam dan inferensi Bayesian.

Jaringan saraf

Di oryx.experimental.nn , Oryx menyediakan sekumpulan Layer jaringan neural umum yang cocok dengan API state . Lapisan ini dibuat untuk contoh tunggal (bukan kumpulan) tetapi mengganti perilaku kumpulan untuk menangani pola seperti menjalankan rata-rata dalam normalisasi batch. Mereka juga memungkinkan melewatkan argumen kata kunci seperti training=True/False ke dalam modul.

Layer diinisialisasi dari Template seperti nn.Dense(200) menggunakan state.init .

layer = state.init(nn.Dense(200))(random.PRNGKey(0), jnp.zeros(50))
print(layer, layer.params.kernel.shape, layer.params.bias.shape)
Dense(200) (50, 200) (200,)

Layer memiliki metode call yang menjalankan forward pass-nya.

layer.call(jnp.ones(50)).shape
(200,)

Oryx juga menyediakan kombinator Serial .

mlp_template = nn.Serial([
  nn.Dense(200), nn.Relu(),
  nn.Dense(200), nn.Relu(),
  nn.Dense(10), nn.Softmax()
])
# OR
mlp_template = (
    nn.Dense(200) >> nn.Relu()
    >> nn.Dense(200) >> nn.Relu()
    >> nn.Dense(10) >> nn.Softmax())
mlp = state.init(mlp_template)(random.PRNGKey(0), jnp.ones(784))
mlp(jnp.ones(784))
DeviceArray([0.16362445, 0.21150257, 0.14715882, 0.10425295, 0.05952952,
             0.07531884, 0.08368199, 0.0376978 , 0.0159679 , 0.10126514],            dtype=float32)

Kita dapat menyisipkan fungsi dan kombinator untuk membuat "bahasa meta" jaringan saraf fleksibel.

def resnet(template):
  def forward(x, init_key=None):
    layer = state.init(template, name='layer')(init_key, x)
    return x + layer(x)
  return forward

big_resnet_template = nn.Serial([
  nn.Dense(50)
  >> resnet(nn.Dense(50) >> nn.Relu())
  >> resnet(nn.Dense(50) >> nn.Relu())
  >> nn.Dense(10)
])
network = state.init(big_resnet_template)(random.PRNGKey(0), jnp.ones(784))
network(jnp.ones(784))
DeviceArray([-0.03828401,  0.9046303 ,  1.6083915 , -0.17005858,
              3.889552  ,  1.7427744 , -1.0567027 ,  3.0192878 ,
              0.28983995,  1.7103616 ], dtype=float32)

Pengoptimal

Di oryx.experimental.optimizers , Oryx menyediakan satu set pengoptimal urutan pertama, dibuat menggunakan API state . Desainnya didasarkan pada pustaka optix JAX, tempat pengoptimal mempertahankan status tentang sekumpulan pembaruan gradien. Versi Oryx mengelola status menggunakan API state .

network_key, opt_key = random.split(random.PRNGKey(0))
def autoencoder_loss(network, x):
  return jnp.square(network.call(x) - x).mean()
network = state.init(nn.Dense(200) >> nn.Relu() >> nn.Dense(2))(network_key, jnp.zeros(2))
opt = state.init(optimizers.adam(1e-4))(opt_key, network, network)
g = grad(autoencoder_loss)(network, jnp.zeros(2))

g, opt = opt.call_and_update(network, g)
network = optimizers.optix.apply_updates(network, g)

Jaringan Markov Monte Carlo

Di oryx.experimental.mcmc , Oryx menyediakan satu set kernel Markov Chain Monte Carlo (MCMC). MCMC adalah pendekatan untuk memperkirakan inferensi Bayesian di mana kami mengambil sampel dari rantai Markov yang distribusi stasionernya adalah distribusi posterior yang diminati.

Pustaka MCMC Oryx dibangun di atas API state dan ppl .

def model(key):
  return jnp.exp(ppl.random_variable(tfd.MultivariateNormalDiag(
      jnp.zeros(2), jnp.ones(2)))(key))

Jalan acak Metropolis

samples = jit(mcmc.sample_chain(mcmc.metropolis(
    ppl.log_prob(model),
    mcmc.random_walk()), 1000))(random.PRNGKey(0), jnp.ones(2))
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)
plt.show()

png

Hamiltonian Monte Carlo

samples = jit(mcmc.sample_chain(mcmc.hmc(
    ppl.log_prob(model)), 1000))(random.PRNGKey(0), jnp.ones(2))
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)
plt.show()

png