روز جامعه ML 9 نوامبر است! برای به روز رسانی از TensorFlow، JAX به ما بپیوندید، و بیشتر بیشتر بدانید

تور اوریکس

مشاهده در TensorFlow.org در Google Colab اجرا کنید مشاهده منبع در GitHub دانلود دفترچه یادداشت

اوریکس چیست؟

Oryx یک کتابخانه تجربی است که JAX را به برنامه های مختلفی از ساخت و آموزش شبکه های عصبی پیچیده تا استنباط تقریبی بیزی در مدل های ژنتیکی عمیق گسترش می دهد. همانطور که JAX jit ، vmap و grad فراهم می کند ، Oryx مجموعه ای از تبدیل عملکردهای ترکیبی را فراهم می کند که نوشتن کد ساده و تبدیل آن برای ایجاد پیچیدگی را در حالی که کاملا با JAX قابل همکاری است ، امکان پذیر می کند.

JAX فقط می تواند کد خالص و کاربردی (یعنی کد بدون عوارض جانبی) را با خیال راحت تبدیل کند. در حالی که نوشتن و استدلال کردن کد خالص آسانتر است ، اما کد "ناخالص" اغلب خلاصه تر و راحت تر بیان می شود.

در هسته اصلی خود ، Oryx کتابخانه ای است که "تقویت" کد عملکردی خالص را برای انجام کارهایی مانند تعریف حالت یا بیرون کشیدن مقادیر میانی امکان پذیر می کند. هدف آن نازک بودن یک لایه در بالای JAX است ، و از رویکرد حداقلی JAX در محاسبات عددی استفاده می کند. اوریکس از نظر مفهومی به چندین "لایه" تقسیم شده است که هر کدام از آنها در لایه زیرین آن بنا شده اند.

کد منبع Oryx را می توان در GitHub یافت.

برپایی

pip install -q oryx 1>/dev/null
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='whitegrid')

import jax
import jax.numpy as jnp
from jax import random
from jax import vmap
from jax import jit
from jax import grad

import oryx

tfd = oryx.distributions

state = oryx.core.state
ppl = oryx.core.ppl

inverse = oryx.core.inverse
ildj = oryx.core.ildj
plant = oryx.core.plant
reap = oryx.core.reap
sow = oryx.core.sow
unzip = oryx.core.unzip

nn = oryx.experimental.nn
mcmc = oryx.experimental.mcmc
optimizers = oryx.experimental.optimizers

لایه 0: تبدیلات عملکرد پایه

اوریکس در پایه خود چندین تغییر کارکرد جدید را تعریف می کند. این تبدیل ها با استفاده از ماشین آلات ردیابی JAX اجرا می شوند و با تبدیلات JAX موجود مانند jit ، grad ، vmap و غیره قابل همکاری هستند.

وارون کردن عملکرد خودکار

oryx.core.inverse و oryx.core.ildj عملکردی هستند که می توانند به طور برنامه ای یک عملکرد را معکوس کرده و به ترتیب Jacobian log-det آن را (ILDJ) محاسبه کنند. این تحولات در مدل سازی احتمالی برای محاسبه احتمالات ورود به سیستم با استفاده از فرمول تغییر متغیر مفید هستند. محدودیت هایی در انواع عملکردهایی که با آنها سازگار است وجود دارد (برای اطلاعات بیشتر به اسناد مراجعه کنید).

def f(x):
  return jnp.exp(x) + 2.
print(inverse(f)(4.))  # ln(2)
print(ildj(f)(4.)) # -ln(2)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
0.6931472
-0.6931472

محصول

oryx.core.harvest . ما مقادیر را با استفاده از تابع sow تگ می کنیم.

def f(x):
  y = sow(x + 1., name='y', tag='intermediate')
  return y ** 2
print('Reap:', reap(f, tag='intermediate')(1.))  # Pulls out 'y'
print('Plant:', plant(f, tag='intermediate')(dict(y=5.), 1.))  # Injects 5. for 'y'
Reap: {'y': DeviceArray(2., dtype=float32)}
Plant: 25.0

از حالت فشرده خارج کنید

oryx.core.unzip یک تابع را به دو oryx.core.unzip تقسیم می کند در امتداد مجموعه ای از مقادیر که به عنوان واسطه برچسب گذاری شده اند ، سپس توابع init_f و apply_f . init_f یک آرگومان کلیدی می گیرد و واسطه ها را برمی گرداند. apply_f برمی گرداند که واسطه ها را می گیرد و خروجی عملکرد اصلی را برمی گرداند.

def f(key, x):
  w = sow(random.normal(key), tag='variable', name='w')
  return w * x
init_f, apply_f = unzip(f, tag='variable')(random.PRNGKey(0), 1.)

تابع init_f f اجرا می کند اما فقط متغیرهای خود را برمی گرداند.

init_f(random.PRNGKey(0))
{'w': DeviceArray(-0.20584226, dtype=float32)}

apply_f مجموعه ای از متغیرها را به عنوان اولین ورودی خود می گیرد و f با مجموعه متغیرهای داده شده اجرا می کند.

apply_f(dict(w=2.), 2.)  # Runs f with `w = 2`.
DeviceArray(4., dtype=float32)

لایه 1: تحولات سطح بالاتر

Oryx با ایجاد سطح پایین عملکرد معکوس ، برداشت و بازکردن از حالت فشرده ، چندین تغییر در سطح بالاتر را برای نوشتن محاسبات مناسب و برنامه نویسی احتمالی ارائه می دهد.

توابع دولتی ( core.state )

ما اغلب علاقه مند به بیان محاسبات متغیر در جایی هستیم که مجموعه ای از پارامترها را مقدار دهی اولیه می کنیم و محاسباتی را از نظر پارامترها بیان می کنیم. در oryx.core.state ، اوریکس فراهم می کند init تحول است که یک تابع را به یکی که مقداردهی اولیه می تبدیل Module ، یک ظرف برای دولت است.

Module ها شبیه Pytorch و TensorFlow Module با این تفاوت که غیرقابل تغییر هستند.

def make_dense(dim_out):
  def forward(x, init_key=None):
    w_key, b_key = random.split(init_key)
    dim_in = x.shape[0]
    w = state.variable(random.normal(w_key, (dim_in, dim_out)), name='w')
    b = state.variable(random.normal(w_key, (dim_out,)), name='b')
    return jnp.dot(x, w) + b
  return forward

layer = state.init(make_dense(5))(random.PRNGKey(0), jnp.zeros(2))
print('layer:', layer)
print('layer.w:', layer.w)
print('layer.b:', layer.b)
layer: FunctionModule(dict_keys(['w', 'b']))
layer.w: [[-2.6105583   0.03385283  1.0863334  -1.4802988   0.48895672]
 [ 1.062516    0.5417484   0.0170228   0.2722685   0.30522448]]
layer.b: [0.59902626 0.2172144  2.4202902  0.03266738 1.2164948 ]

Module ها به عنوان pytrees JAX ثبت شده اند و می توانند به عنوان ورودی به توابع تبدیل شده JAX استفاده شوند. Oryx یک عملکرد call مناسب را فراهم می کند که Module را اجرا می کند.

vmap(state.call, in_axes=(None, 0))(layer, jnp.ones((5, 2)))
DeviceArray([[-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ]], dtype=float32)

API state همچنین نوشتن به روزرسانی های مناسب (مانند میانگین های در حال اجرا) را با استفاده از تابع assign کند. Module حاصل دارای یک عملکرد update با امضای ورودی است که همان __call__ Module است اما نسخه جدیدی از Module با وضعیت به روز شده ایجاد می کند.

def counter(x, init_key=None):
  count = state.variable(0., key=init_key, name='count')
  count = state.assign(count + 1., name='count')
  return x + count
layer = state.init(counter)(random.PRNGKey(0), 0.)
print(layer.count)
updated_layer = layer.update(0.)
print(updated_layer.count) # Count has advanced!
print(updated_layer.call(1.))
0.0
1.0
3.0

برنامه نویسی احتمالی

در oryx.core.ppl ، Oryx مجموعه ای از ابزارهای ساخته شده در بالای harvest و inverse که هدف آنها نوشتن و تبدیل برنامه های احتمالی را بصری و آسان می کند.

در Oryx ، یک برنامه احتمالی یک تابع JAX است که یک منبع تصادفی را به عنوان اولین آرگومان خود در نظر می گیرد و یک نمونه را از یک توزیع ، به عنوان مثال ، f :: Key -> Sample برمی گرداند. برای نوشتن این برنامه ها ، Oryx توزیع های احتمال random_variable می پیچد و یک تابع ساده random_variable که توزیع را به یک برنامه احتمالی تبدیل می کند.

def sample(key):
  return ppl.random_variable(tfd.Normal(0., 1.))(key)
sample(random.PRNGKey(0))
DeviceArray(-0.20584235, dtype=float32)

با برنامه های احتمالی چه کاری می توانیم انجام دهیم؟ ساده ترین کار این است که یک برنامه احتمالی (یعنی یک تابع نمونه برداری) بگیرید و آن را به برنامه ای تبدیل کنید که تراکم ورود به سیستم یک نمونه را فراهم کند.

ppl.log_prob(sample)(1.)
DeviceArray(-1.4189385, dtype=float32)

تابع احتمال ورود به سیستم جدید با سایر تبدیلات JAX مانند vmap و grad سازگار است.

grad(lambda s: vmap(ppl.log_prob(sample))(s).sum())(jnp.arange(10.))
DeviceArray([-0., -1., -2., -3., -4., -5., -6., -7., -8., -9.], dtype=float32)

با استفاده از تحول ildj ، می توانیم log_prob برنامه هایی را محاسبه کنیم که به طور log_prob نمونه ها را تغییر می دهند.

def sample(key):
  x = ppl.random_variable(tfd.Normal(0., 1.))(key)
  return jnp.exp(x / 2.) + 2.
_, ax = plt.subplots(2)
ax[0].hist(jit(vmap(sample))(random.split(random.PRNGKey(0), 1000)),
    bins='auto')
x = jnp.linspace(0, 8, 100)
ax[1].plot(x, jnp.exp(jit(vmap(ppl.log_prob(sample)))(x)))
plt.show()

png

ما می توانیم مقادیر میانی را در یک برنامه احتمالی با نام ها برچسب گذاری کرده و نمونه برداری مشترک و توابع log-prob مشترک را بدست آوریم.

def sample(key):
  z_key, x_key = random.split(key)
  z = ppl.random_variable(tfd.Normal(0., 1.), name='z')(z_key)
  x = ppl.random_variable(tfd.Normal(z, 1.), name='x')(x_key)
  return x
ppl.joint_sample(sample)(random.PRNGKey(0))
{'x': DeviceArray(-1.1076484, dtype=float32),
 'z': DeviceArray(0.14389044, dtype=float32)}

Oryx همچنین دارای یک تابع joint_log_prob است که log_prob با joint_sample .

ppl.joint_log_prob(sample)(dict(x=0., z=0.))
DeviceArray(-1.837877, dtype=float32)

برای کسب اطلاعات بیشتر ، به اسناد مراجعه کنید.

لایه 2: کتابخانه های کوچک

اوریکس با ایجاد بیشتر در بالای لایه هایی که برنامه نویسی حالت و احتمال را کنترل می کنند ، کتابخانه های کوچک آزمایشی متناسب با کاربردهای خاص مانند یادگیری عمیق و استنباط بیزی را فراهم می کند.

شبکه های عصبی

در oryx.experimental.nn ، Oryx مجموعه ای از Layer شبکه عصبی مشترک را ارائه می دهد که کاملاً متناسب با API state . این لایه ها برای نمونه های منفرد ساخته شده اند (نه دسته ای) اما رفتارهای دسته ای را نادیده می گیرند تا الگوهایی مانند میانگین های در حال اجرا را در حالت عادی دسته ای کنترل کنند. آنها همچنین می توانند استدلال های کلمه کلیدی مانند training=True/False را به ماژول ها منتقل کنند.

Layer ها از یک Template مانند nn.Dense(200) با استفاده از state.init .

layer = state.init(nn.Dense(200))(random.PRNGKey(0), jnp.zeros(50))
print(layer, layer.params.kernel.shape, layer.params.bias.shape)
Dense(200) (50, 200) (200,)

یک Layer یک روش call که پاس رو به جلو را اجرا می کند.

layer.call(jnp.ones(50)).shape
(200,)

Oryx همچنین یک ترکیب کننده Serial ارائه می دهد.

mlp_template = nn.Serial([
  nn.Dense(200), nn.Relu(),
  nn.Dense(200), nn.Relu(),
  nn.Dense(10), nn.Softmax()
])
# OR
mlp_template = (
    nn.Dense(200) >> nn.Relu()
    >> nn.Dense(200) >> nn.Relu()
    >> nn.Dense(10) >> nn.Softmax())
mlp = state.init(mlp_template)(random.PRNGKey(0), jnp.ones(784))
mlp(jnp.ones(784))
DeviceArray([0.16362445, 0.21150257, 0.14715882, 0.10425295, 0.05952952,
             0.07531884, 0.08368199, 0.0376978 , 0.0159679 , 0.10126514],            dtype=float32)

ما می توانیم توابع و ترکیب کننده ها را برای ایجاد یک شبکه عصبی انعطاف پذیر "متا زبان" بهم بزنیم.

def resnet(template):
  def forward(x, init_key=None):
    layer = state.init(template, name='layer')(init_key, x)
    return x + layer(x)
  return forward

big_resnet_template = nn.Serial([
  nn.Dense(50)
  >> resnet(nn.Dense(50) >> nn.Relu())
  >> resnet(nn.Dense(50) >> nn.Relu())
  >> nn.Dense(10)
])
network = state.init(big_resnet_template)(random.PRNGKey(0), jnp.ones(784))
network(jnp.ones(784))
DeviceArray([-0.03828401,  0.9046303 ,  1.6083915 , -0.17005858,
              3.889552  ,  1.7427744 , -1.0567027 ,  3.0192878 ,
              0.28983995,  1.7103616 ], dtype=float32)

بهینه سازها

در oryx.experimental.optimizers ، Oryx مجموعه ای از بهینه سازهای مرتبه اول را ارائه می دهد که با استفاده از API state ساخته شده اند. طراحی آنها بر اساس کتابخانه optix JAX است ، جایی که بهینه سازها وضعیت مجموعه ای از به روزرسانی های شیب را حفظ می کنند. نسخه Oryx با استفاده از API state ، state مدیریت می کند.

network_key, opt_key = random.split(random.PRNGKey(0))
def autoencoder_loss(network, x):
  return jnp.square(network.call(x) - x).mean()
network = state.init(nn.Dense(200) >> nn.Relu() >> nn.Dense(2))(network_key, jnp.zeros(2))
opt = state.init(optimizers.adam(1e-4))(opt_key, network, network)
g = grad(autoencoder_loss)(network, jnp.zeros(2))

g, opt = opt.call_and_update(network, g)
network = optimizers.optix.apply_updates(network, g)

زنجیره مارکوف مونت کارلو

در oryx.experimental.mcmc ، Oryx مجموعه ای از هسته های Markov Chain Monte Carlo (MCMC) را ارائه می دهد. MCMC روشی برای استنباط تقریبی بیزی است که در آن نمونه هایی از یک زنجیره مارکوف را می گیریم که توزیع ثابت آن توزیع خلفی مورد علاقه است.

کتابخانه MCMC اوریکس ، هم از طریق API و هم از state ppl .

def model(key):
  return jnp.exp(ppl.random_variable(tfd.MultivariateNormalDiag(
      jnp.zeros(2), jnp.ones(2)))(key))

پیاده روی تصادفی کلانشهر

samples = jit(mcmc.sample_chain(mcmc.metropolis(
    ppl.log_prob(model),
    mcmc.random_walk()), 1000))(random.PRNGKey(0), jnp.ones(2))
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)
plt.show()

png

همیلتونی مونت کارلو

samples = jit(mcmc.sample_chain(mcmc.hmc(
    ppl.log_prob(model)), 1000))(random.PRNGKey(0), jnp.ones(2))
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)
plt.show()

png