مساعدة في حماية الحاجز المرجاني العظيم مع TensorFlow على Kaggle تاريخ التحدي

جولة في المها

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

ما هو المها؟

Oryx هي مكتبة تجريبية توسع JAX إلى تطبيقات تتراوح من بناء وتدريب الشبكات العصبية المعقدة لتقريب الاستدلال البايزي في النماذج التوليدية العميقة. مثل JAX يوفر jit و vmap و grad ، يوفر Oryx مجموعة من تحويلات الوظائف vmap التي تتيح كتابة vmap برمجية بسيطة وتحويلها لبناء التعقيد مع الحفاظ على قابلية التشغيل البيني تمامًا مع JAX.

يمكن لـ JAX فقط تحويل الشفرة الوظيفية الخالصة بأمان (أي رمز بدون آثار جانبية). في حين أن الشفرة الخالصة يمكن أن تكون أسهل في الكتابة والتعليل ، إلا أن الشفرة "غير النقية" غالبًا ما تكون أكثر إيجازًا وأكثر تعبيرًا بسهولة.

تعتبر Oryx في جوهرها مكتبة تتيح "زيادة" الشفرة الوظيفية الخالصة لإنجاز مهام مثل تحديد الحالة أو سحب القيم الوسيطة. هدفها هو أن تكون رقيقة من طبقة أعلى JAX قدر الإمكان ، والاستفادة من نهج JAX البسيط في الحوسبة الرقمية. ينقسم المها من الناحية المفاهيمية إلى عدة "طبقات" ، كل مبنى على الطبقة الموجودة تحته.

يمكن العثور على الكود المصدري لـ Oryx على GitHub .

يثبت

pip install -q oryx 1>/dev/null
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='whitegrid')

import jax
import jax.numpy as jnp
from jax import random
from jax import vmap
from jax import jit
from jax import grad

import oryx

tfd = oryx.distributions

state = oryx.core.state
ppl = oryx.core.ppl

inverse = oryx.core.inverse
ildj = oryx.core.ildj
plant = oryx.core.plant
reap = oryx.core.reap
sow = oryx.core.sow
unzip = oryx.core.unzip

nn = oryx.experimental.nn
mcmc = oryx.experimental.mcmc
optimizers = oryx.experimental.optimizers

الطبقة 0: تحويلات الوظيفة الأساسية

يحدد Oryx في قاعدته العديد من تحويلات الوظائف الجديدة. يتم تنفيذ هذه التحولات باستخدام آلية التتبع الخاصة بـ JAX ويمكن تشغيلها مع تحويلات JAX الحالية مثل jit ، و grad ، و vmap ، وما إلى ذلك.

انعكاس تلقائي للوظيفة

oryx.core.inverse و oryx.core.ildj عبارة عن تحويلات وظيفية يمكنها عكس دالة برمجيًا وتحسب عكس سجل Jacobian (ILDJ) على التوالي. هذه التحويلات مفيدة في النمذجة الاحتمالية لحساب احتمالات السجل باستخدام صيغة تغيير المتغير. ومع ذلك ، هناك قيود على أنواع الوظائف التي تتوافق معها (انظر الوثائق لمزيد من التفاصيل).

def f(x):
  return jnp.exp(x) + 2.
print(inverse(f)(4.))  # ln(2)
print(ildj(f)(4.)) # -ln(2)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
0.6931472
-0.6931472

محصول

يتيح oryx.core.harvest وضع علامات على القيم في الوظائف إلى جانب القدرة على جمعها ، أو "جنيها" ، والقدرة على حقن القيم في مكانها ، أو "غرسها". نقوم بتمييز القيم باستخدام وظيفة sow .

def f(x):
  y = sow(x + 1., name='y', tag='intermediate')
  return y ** 2
print('Reap:', reap(f, tag='intermediate')(1.))  # Pulls out 'y'
print('Plant:', plant(f, tag='intermediate')(dict(y=5.), 1.))  # Injects 5. for 'y'
Reap: {'y': DeviceArray(2., dtype=float32)}
Plant: 25.0

فك الضغط

oryx.core.unzip بتقسيم الوظيفة إلى قسمين على طول مجموعة من القيم الموسومة كوسائط ، ثم تعيد الدالتين init_f و apply_f . يأخذ init_f الوسيطة الرئيسية ويعيد الوسطاء. تُرجع apply_f دالة تأخذ apply_f وتعيد ناتج الوظيفة الأصلية.

def f(key, x):
  w = sow(random.normal(key), tag='variable', name='w')
  return w * x
init_f, apply_f = unzip(f, tag='variable')(random.PRNGKey(0), 1.)

تعمل الدالة init_f تشغيل f ولكنها تُرجع فقط المتغيرات الخاصة بها.

init_f(random.PRNGKey(0))
{'w': DeviceArray(-0.20584226, dtype=float32)}

يأخذ apply_f مجموعة من المتغيرات apply_f الأول وينفذ f مع مجموعة معينة من المتغيرات.

apply_f(dict(w=2.), 2.)  # Runs f with `w = 2`.
DeviceArray(4., dtype=float32)

الطبقة 1: تحولات المستوى الأعلى

يبني Oryx على تحويلات دالة المعكوس والحصاد وفك الضغط ذات المستوى المنخفض لتقديم العديد من التحولات عالية المستوى لكتابة الحسابات ذات الحالة وللبرمجة الاحتمالية.

الوظائف ذات الحالة ( core.state )

غالبًا ما نهتم بالتعبير عن الحسابات ذات الحالة حيث نقوم بتهيئة مجموعة من المعلمات ونعبر عن عملية حسابية من حيث المعلمات. في oryx.core.state ، يوفر Oryx تحويل init الذي يحول دالة إلى أخرى تقوم بتهيئة Module ، وهي حاوية للحالة.

تشبه Module Pytorch و TensorFlow Module إلا أنها غير قابلة للتغيير.

def make_dense(dim_out):
  def forward(x, init_key=None):
    w_key, b_key = random.split(init_key)
    dim_in = x.shape[0]
    w = state.variable(random.normal(w_key, (dim_in, dim_out)), name='w')
    b = state.variable(random.normal(w_key, (dim_out,)), name='b')
    return jnp.dot(x, w) + b
  return forward

layer = state.init(make_dense(5))(random.PRNGKey(0), jnp.zeros(2))
print('layer:', layer)
print('layer.w:', layer.w)
print('layer.b:', layer.b)
layer: FunctionModule(dict_keys(['w', 'b']))
layer.w: [[-2.6105583   0.03385283  1.0863334  -1.4802988   0.48895672]
 [ 1.062516    0.5417484   0.0170228   0.2722685   0.30522448]]
layer.b: [0.59902626 0.2172144  2.4202902  0.03266738 1.2164948 ]

يتم تسجيل Module كأشجار JAX ويمكن استخدامها كمدخلات لوظائف JAX المحولة. يوفر Oryx وظيفة call ملائمة تقوم بتنفيذ Module .

vmap(state.call, in_axes=(None, 0))(layer, jnp.ones((5, 2)))
DeviceArray([[-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ]], dtype=float32)

في state تمكن API أيضا كتابة التحديثات جليل (مثل المتوسطات تشغيل) باستخدام assign وظيفة. Module الناتجة لها وظيفة update مع توقيع إدخال __call__ في Module ولكنه ينشئ نسخة جديدة من Module مع حالة محدثة.

def counter(x, init_key=None):
  count = state.variable(0., key=init_key, name='count')
  count = state.assign(count + 1., name='count')
  return x + count
layer = state.init(counter)(random.PRNGKey(0), 0.)
print(layer.count)
updated_layer = layer.update(0.)
print(updated_layer.count) # Count has advanced!
print(updated_layer.call(1.))
0.0
1.0
3.0

البرمجة الاحتمالية

في oryx.core.ppl ، يوفر Oryx مجموعة من الأدوات المبنية على قمة harvest inverse والتي تهدف إلى جعل كتابة وتحويل البرامج الاحتمالية أمرًا بديهيًا وسهلاً.

في Oryx ، البرنامج الاحتمالي هو دالة JAX التي تأخذ مصدر العشوائية كوسيطة أولى وتعيد عينة من التوزيع ، أي f :: Key -> Sample . من أجل كتابة هذه البرامج ، يلف Oryx توزيعات TensorFlow الاحتمالية ويوفر وظيفة بسيطة random_variable يحول التوزيع إلى برنامج احتمالي.

def sample(key):
  return ppl.random_variable(tfd.Normal(0., 1.))(key)
sample(random.PRNGKey(0))
DeviceArray(-0.20584235, dtype=float32)

ماذا يمكننا أن نفعل بالبرامج الاحتمالية؟ سيكون أبسط شيء هو أخذ برنامج احتمالي (أي وظيفة أخذ العينات) وتحويله إلى برنامج يوفر الكثافة اللوغاريتمية للعينة.

ppl.log_prob(sample)(1.)
DeviceArray(-1.4189385, dtype=float32)

تتوافق وظيفة احتمالية السجل الجديدة مع تحويلات JAX الأخرى مثل vmap و grad .

grad(lambda s: vmap(ppl.log_prob(sample))(s).sum())(jnp.arange(10.))
DeviceArray([-0., -1., -2., -3., -4., -5., -6., -7., -8., -9.], dtype=float32)

باستخدام تحويل ildj ، يمكننا حساب log_prob من البرامج التي تحول العينات بشكل عكسي.

def sample(key):
  x = ppl.random_variable(tfd.Normal(0., 1.))(key)
  return jnp.exp(x / 2.) + 2.
_, ax = plt.subplots(2)
ax[0].hist(jit(vmap(sample))(random.split(random.PRNGKey(0), 1000)),
    bins='auto')
x = jnp.linspace(0, 8, 100)
ax[1].plot(x, jnp.exp(jit(vmap(ppl.log_prob(sample)))(x)))
plt.show()

بي إن جي

يمكننا تمييز القيم الوسيطة في برنامج احتمالي بالأسماء والحصول على وظائف مشتركة لأخذ العينات وسجل احتمالي مشترك.

def sample(key):
  z_key, x_key = random.split(key)
  z = ppl.random_variable(tfd.Normal(0., 1.), name='z')(z_key)
  x = ppl.random_variable(tfd.Normal(z, 1.), name='x')(x_key)
  return x
ppl.joint_sample(sample)(random.PRNGKey(0))
{'x': DeviceArray(-1.1076484, dtype=float32),
 'z': DeviceArray(0.14389044, dtype=float32)}

يحتوي Oryx أيضًا على وظيفة joint_log_prob التي تؤلف log_prob باستخدام joint_sample .

ppl.joint_log_prob(sample)(dict(x=0., z=0.))
DeviceArray(-1.837877, dtype=float32)

لمعرفة المزيد ، راجع الوثائق .

الطبقة 2: المكتبات المصغرة

بناءً على الطبقات التي تتعامل مع برمجة الحالة والاحتمالية ، يوفر Oryx مكتبات مصغرة تجريبية مصممة لتطبيقات محددة مثل التعلم العميق والاستدلال البايزي.

الشبكات العصبية

في oryx.experimental.nn ، يوفر Oryx مجموعة من Layer الشبكات العصبية الشائعة التي تتلاءم بدقة مع واجهة برمجة تطبيقات state . تم تصميم هذه الطبقات لأمثلة فردية (وليس مجموعات) ولكنها تتجاوز سلوكيات الدُفعات للتعامل مع أنماط مثل تشغيل المتوسطات في تسوية الدُفعات. كما أنها تتيح تمرير وسيطات الكلمات الرئيسية مثل training=True/False إلى وحدات نمطية.

تتم تهيئة Layer من Template مثل nn.Dense(200) باستخدام state.init .

layer = state.init(nn.Dense(200))(random.PRNGKey(0), jnp.zeros(50))
print(layer, layer.params.kernel.shape, layer.params.bias.shape)
Dense(200) (50, 200) (200,)

تحتوي Layer على طريقة call تقوم بتشغيل مسارها الأمامي.

layer.call(jnp.ones(50)).shape
(200,)

يوفر Oryx أيضًا مُجمعًا Serial .

mlp_template = nn.Serial([
  nn.Dense(200), nn.Relu(),
  nn.Dense(200), nn.Relu(),
  nn.Dense(10), nn.Softmax()
])
# OR
mlp_template = (
    nn.Dense(200) >> nn.Relu()
    >> nn.Dense(200) >> nn.Relu()
    >> nn.Dense(10) >> nn.Softmax())
mlp = state.init(mlp_template)(random.PRNGKey(0), jnp.ones(784))
mlp(jnp.ones(784))
DeviceArray([0.16362445, 0.21150257, 0.14715882, 0.10425295, 0.05952952,
             0.07531884, 0.08368199, 0.0376978 , 0.0159679 , 0.10126514],            dtype=float32)

يمكننا تشذير الوظائف والجمعيات لإنشاء شبكة عصبية مرنة "لغة ميتا".

def resnet(template):
  def forward(x, init_key=None):
    layer = state.init(template, name='layer')(init_key, x)
    return x + layer(x)
  return forward

big_resnet_template = nn.Serial([
  nn.Dense(50)
  >> resnet(nn.Dense(50) >> nn.Relu())
  >> resnet(nn.Dense(50) >> nn.Relu())
  >> nn.Dense(10)
])
network = state.init(big_resnet_template)(random.PRNGKey(0), jnp.ones(784))
network(jnp.ones(784))
DeviceArray([-0.03828401,  0.9046303 ,  1.6083915 , -0.17005858,
              3.889552  ,  1.7427744 , -1.0567027 ,  3.0192878 ,
              0.28983995,  1.7103616 ], dtype=float32)

محسنون

في oryx.experimental.optimizers ، يوفر Oryx مجموعة من oryx.experimental.optimizers من الدرجة الأولى ، التي تم إنشاؤها باستخدام واجهة برمجة تطبيقات state . يعتمد optix مكتبة optix الخاصة بـ JAX ، حيث يحافظ optix على حالة حول مجموعة من تحديثات التدرج. يدير إصدار Oryx الحالة باستخدام واجهة برمجة تطبيقات state .

network_key, opt_key = random.split(random.PRNGKey(0))
def autoencoder_loss(network, x):
  return jnp.square(network.call(x) - x).mean()
network = state.init(nn.Dense(200) >> nn.Relu() >> nn.Dense(2))(network_key, jnp.zeros(2))
opt = state.init(optimizers.adam(1e-4))(opt_key, network, network)
g = grad(autoencoder_loss)(network, jnp.zeros(2))

g, opt = opt.call_and_update(network, g)
network = optimizers.optix.apply_updates(network, g)

سلسلة ماركوف مونتي كارلو

في oryx.experimental.mcmc ، توفر Oryx مجموعة من oryx.experimental.mcmc Markov Chain Monte Carlo (MCMC). MCMC هو نهج لتقريب الاستدلال البايزي حيث نأخذ عينات من سلسلة ماركوف التي يكون توزيعها الثابت هو التوزيع اللاحق للفائدة.

مكتبة Oryx's MCMC مبنية على كل من state و ppl API.

def model(key):
  return jnp.exp(ppl.random_variable(tfd.MultivariateNormalDiag(
      jnp.zeros(2), jnp.ones(2)))(key))

مسيرة عشوائية متروبوليس

samples = jit(mcmc.sample_chain(mcmc.metropolis(
    ppl.log_prob(model),
    mcmc.random_walk()), 1000))(random.PRNGKey(0), jnp.ones(2))
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)
plt.show()

بي إن جي

هاميلتونيان مونتي كارلو

samples = jit(mcmc.sample_chain(mcmc.hmc(
    ppl.log_prob(model)), 1000))(random.PRNGKey(0), jnp.ones(2))
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)
plt.show()

بي إن جي