الاستدلال الموزع مع JAX

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثبتحميل دفتر

لدى TensorFlow Probability (TFP) على JAX الآن أدوات للحوسبة الرقمية الموزعة. للتوسع في أعداد كبيرة من المسرّعات ، تم بناء الأدوات حول كتابة التعليمات البرمجية باستخدام نموذج "برنامج واحد متعدد البيانات" ، أو اختصارًا SPMD.

في هذا الكمبيوتر الدفتري ، سنتعرف على كيفية "التفكير في SPMD" ونقدم تجريدات TFP الجديدة لتوسيع نطاق التكوينات مثل كبسولات TPU أو مجموعات وحدات معالجة الرسومات. إذا كنت تقوم بتشغيل هذا الرمز بنفسك ، فتأكد من تحديد وقت تشغيل TPU.

سنقوم أولاً بتثبيت أحدث إصدارات TFP و JAX و TF.

التثبيتات

سنقوم باستيراد بعض المكتبات العامة ، جنبًا إلى جنب مع بعض أدوات JAX المساعدة.

الإعداد والواردات

INFO:tensorflow:Enabling eager execution
INFO:tensorflow:Enabling v2 tensorshape
INFO:tensorflow:Enabling resource variables
INFO:tensorflow:Enabling tensor equality
INFO:tensorflow:Enabling control flow v2

سنقوم أيضًا بإعداد بعض الأسماء المستعارة لـ TFP. يتم توفير تجريدية جديدة حاليا في tfp.experimental.distribute و tfp.experimental.mcmc .

tfd = tfp.distributions
tfb = tfp.bijectors
tfm = tfp.mcmc
tfed = tfp.experimental.distribute
tfde = tfp.experimental.distributions
tfem = tfp.experimental.mcmc

Root = tfed.JointDistributionCoroutine.Root

لتوصيل الكمبيوتر الدفتري بجهاز TPU ، نستخدم المساعد التالي من JAX. للتأكد من أننا متصلون ، نطبع عدد الأجهزة ، الذي يجب أن يكون ثمانية.

from jax.tools import colab_tpu
colab_tpu.setup_tpu()
print(f'Found {jax.device_count()} devices')
Found 8 devices

مقدمة سريعة ل jax.pmap

بعد الاتصال إلى TPU، لدينا إمكانية الوصول إلى ثمانية أجهزة. ومع ذلك ، عندما نقوم بتشغيل كود JAX بفارغ الصبر ، فإن JAX تكون افتراضية لتشغيل العمليات الحسابية على واحد فقط.

إن أبسط طريقة لتنفيذ عملية حسابية عبر العديد من الأجهزة هي تعيين وظيفة ، بحيث يقوم كل جهاز بتنفيذ فهرس واحد على الخريطة. يوفر JAX في jax.pmap ( "خارطة الموازية") تحول الأمر الذي يجعل وظيفة في واحدة التي تعين على وظيفة عبر العديد من الأجهزة.

في المثال التالي ، قمنا بإنشاء مصفوفة بحجم 8 (لمطابقة عدد الأجهزة المتاحة) وقمنا بتعيين وظيفة تضيف 5 عبرها.

xs = jnp.arange(8.)
out = jax.pmap(lambda x: x + 5.)(xs)
print(type(out), out)
<class 'jax.interpreters.pxla.ShardedDeviceArray'> [ 5.  6.  7.  8.  9. 10. 11. 12.]

علما بأن نتلقى ShardedDeviceArray نوع الظهر، مشيرا إلى أن مجموعة الانتاج يتم تقسيم جسديا عبر الأجهزة.

jax.pmap يعمل غويا مثل الخريطة، ولكن لديها عدد قليل من الخيارات الهامة التي تعدل سلوكها. افتراضيا، pmap يفترض يتم تعيين جميع المدخلات إلى وظيفة أكثر، ولكن يمكننا تعديل هذا السلوك مع in_axes حجة.

xs = jnp.arange(8.)
y = 5.
# Map over the 0-axis of `xs` and don't map over `y`
out = jax.pmap(lambda x, y: x + y, in_axes=(0, None))(xs, y)
print(out)
[ 5.  6.  7.  8.  9. 10. 11. 12.]

بالقياس، و out_axes حجة ل pmap يحدد ما إذا كان أو لم يكن لإرجاع القيم على كل جهاز. وضع out_axes ل None تلقائيا بإرجاع قيمة على الجهاز 1st و ينبغي ألا تستخدم إلا إذا كنا على ثقة من القيم هي نفسها على كل جهاز.

xs = jnp.ones(8) # Value is the same on each device
out = jax.pmap(lambda x: x + 1, out_axes=None)(xs)
print(out)
2.0

ماذا يحدث عندما لا يمكن التعبير عن ما نرغب في القيام به بسهولة كدالة نقية معينة؟ على سبيل المثال ، ماذا لو أردنا إجراء مجموع عبر المحور الذي نرسمه؟ تقدم JAX "مجموعات" ، وهي وظائف تتواصل عبر الأجهزة ، لتمكين كتابة برامج موزعة أكثر تشويقًا وتعقيدًا. لفهم كيفية عملها بالضبط ، سنقدم SPMD.

ما هو SPMD؟

البيانات المتعددة أحادية البرنامج (SPMD) هي نموذج برمجة متزامن يتم فيه تنفيذ برنامج واحد (أي نفس الرمز) في وقت واحد عبر الأجهزة ، ولكن يمكن أن تختلف مدخلات كل برنامج من البرامج قيد التشغيل.

إذا برنامجنا هو وظيفة بسيطة من مدخلاته (أي شيء من هذا القبيل x + 5 )، تشغيل برنامج في SPMD هو مجرد رسم ذلك البيانات المختلفة أكثر، كما فعلنا مع jax.pmap في وقت سابق. ومع ذلك ، يمكننا أن نفعل أكثر من مجرد "تعيين" وظيفة. تقدم JAX "المجموعات" ، وهي وظائف تتواصل عبر الأجهزة.

على سبيل المثال ، ربما نرغب في أخذ مجموع الكمية عبر جميع أجهزتنا. قبل ان نفعل ذلك، نحن بحاجة لتعيين اسم لمحور نحن الخرائط كنت فوق في pmap . نحن ثم استخدم lax.psum وظيفة ( "مبلغ مواز") لتنفيذ المبلغ عبر الأجهزة، وضمان نحدد اسمه محور نحن تلخيص أكثر.

def f(x):
  out = lax.psum(x, axis_name='i')
  return out
xs = jnp.arange(8.) # Length of array matches number of devices
jax.pmap(f, axis_name='i')(xs)
ShardedDeviceArray([28., 28., 28., 28., 28., 28., 28., 28.], dtype=float32)

و psum المجاميع الجماعية قيمة x على كل جهاز ومزامنة قيمته عبر الخريطة، أي out هو 28. على كل جهاز. لم نعد نقوم بتنفيذ "خريطة" بسيطة ، لكننا ننفذ برنامج SPMD حيث يمكن الآن لحسابات كل جهاز أن تتفاعل مع نفس الحساب على الأجهزة الأخرى ، وإن كان ذلك بطريقة محدودة باستخدام المجموعات. في هذا السيناريو، يمكننا استخدام out_axes = None ، لأن psum سوف مزامنة القيمة.

def f(x):
  out = lax.psum(x, axis_name='i')
  return out
jax.pmap(f, axis_name='i', out_axes=None)(jnp.arange(8.))
ShardedDeviceArray(28., dtype=float32)

يتيح لنا SPMD كتابة برنامج واحد يتم تشغيله على كل جهاز في أي تكوين TPU في وقت واحد. يمكن استخدام نفس الكود المستخدم في التعلم الآلي على 8 نوى من مادة TPU على جراب TPU الذي قد يحتوي على مئات إلى آلاف النوى! لتعليمي أكثر تفصيلا حول jax.pmap وSPMD، يمكنك الرجوع إلى و JAX 101 التعليمي .

MCMC على نطاق واسع

في هذا الكمبيوتر الدفتري ، نركز على استخدام أساليب Markov Chain Monte Carlo (MCMC) للاستدلال البايزي. قد تكون هناك طرق لاستخدام العديد من الأجهزة في MCMC ، ولكن في هذا الكمبيوتر الدفتري ، سنركز على اثنين:

  1. تشغيل سلاسل ماركوف المستقلة على أجهزة مختلفة. هذه الحالة بسيطة إلى حد ما ويمكن القيام بها باستخدام Vanilla TFP.
  2. مشاركة مجموعة بيانات عبر الأجهزة. هذه الحالة أكثر تعقيدًا قليلاً وتتطلب آلات TFP المضافة حديثًا.

سلاسل مستقلة

لنفترض أننا نرغب في إجراء استنتاج بايزي بشأن مشكلة باستخدام MCMC ونود تشغيل عدة سلاسل بالتوازي عبر عدة أجهزة (لنقل 2 على كل جهاز). تبين أن هذا برنامج يمكننا فقط "تعيينه" عبر الأجهزة ، أي برنامج لا يحتاج إلى مجموعات. للتأكد من أن كل برنامج ينفذ سلسلة Markov مختلفة (بدلاً من تشغيل نفس السلسلة) ، نقوم بتمرير قيمة مختلفة للبذور العشوائية لكل جهاز.

لنجربها في مشكلة لعبة لأخذ عينات من توزيع غاوسي ثنائي الأبعاد. يمكننا استخدام وظيفة MCMC الحالية الخاصة بـ TFP من خارج الصندوق. بشكل عام ، نحاول وضع معظم المنطق داخل الوظيفة المعينة لدينا للتمييز بشكل أكثر وضوحًا بين ما يتم تشغيله على جميع الأجهزة مقابل الأول فقط.

def run(seed):
  target_log_prob = tfd.Sample(tfd.Normal(0., 1.), 2).log_prob

  initial_state = jnp.zeros([2, 2]) # 2 chains
  kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-1, 10)
  def trace_fn(state, pkr):
    return target_log_prob(state)

  states, log_prob = tfm.sample_chain(
    num_results=1000,
    num_burnin_steps=1000,
    kernel=kernel,
    current_state=initial_state,
    trace_fn=trace_fn,
    seed=seed
  )
  return states, log_prob

في حد ذاته، و run يأخذ وظيفة في البذور عشوائي عديمي الجنسية (البدون لنرى كيف العمل العشوائية، يمكنك قراءة TFP على JAX المحمول أو رؤية البرنامج التعليمي JAX 101 ). خرائط run وأكثر من البذور المختلفة يؤدي إلى تشغيل عدة سلاسل ماركوف مستقلة.

states, log_probs = jax.pmap(run)(random.split(random.PRNGKey(0), 8))
print(states.shape, log_probs.shape)
# states is (8 devices, 1000 samples, 2 chains, 2 dimensions)
# log_prob is (8 devices, 1000 samples, 2 chains)
(8, 1000, 2, 2) (8, 1000, 2)

لاحظ كيف لدينا الآن محور إضافي يتوافق مع كل جهاز. يمكننا إعادة ترتيب الأبعاد وتسويتها للحصول على محور لـ 16 سلسلة.

states = states.transpose([0, 2, 1, 3]).reshape([-1, 1000, 2])
log_probs = log_probs.transpose([0, 2, 1]).reshape([-1, 1000])
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].plot(log_probs.T, alpha=0.4)
ax[1].scatter(*states.reshape([-1, 2]).T, alpha=0.1)
plt.show()

بي إن جي

عند تشغيل سلاسل المستقلة على العديد من الأجهزة، انها سهلة كما pmap -ing على وظيفة أن الاستخدامات tfp.mcmc ، وضمان نجتاز قيم مختلفة للالبذور عشوائي لكل جهاز.

تقاسم البيانات

عندما نقوم بعمل MCMC ، غالبًا ما يكون التوزيع المستهدف هو التوزيع اللاحق الذي يتم الحصول عليه عن طريق التكييف على مجموعة بيانات ، ويتضمن حساب كثافة السجل غير الطبيعية جمع الاحتمالات لكل بيانات تمت ملاحظتها.

مع مجموعات البيانات الكبيرة جدًا ، قد يكون تشغيل سلسلة واحدة على جهاز واحد مكلفًا للغاية. ومع ذلك ، عندما نتمكن من الوصول إلى أجهزة متعددة ، يمكننا تقسيم مجموعة البيانات عبر الأجهزة للاستفادة بشكل أفضل من الحوسبة المتوفرة لدينا.

إذا كنا نرغب في القيام MCMC مع مجموعة بيانات sharded، نحن بحاجة إلى ضمان unnormalized سجل الكثافة نحسب على كل جهاز يمثل المجموع، أي كثافة على جميع البيانات، وإلا كل جهاز سوف تفعل MCMC مع الهدف غير صحيح الخاصة توزيع. ولهذه الغاية، TFP لديها الآن أدوات جديدة (أي tfp.experimental.distribute و tfp.experimental.mcmc ) التي تمكن الحوسبة "sharded" الاحتمالات سجل والقيام MCMC معهم.

توزيعات مُقسمة

يوفر التجريد TFP الأساسية الآن لحساب probabiliities سجل sharded هو Sharded الفوقية التوزيع، والتي تأخذ توزيع كمدخل وإرجاع توزيع جديد له خصائص معينة عند تنفيذها في سياق SPMD. Sharded حياة في tfp.experimental.distribute .

حدسي، و Sharded يتوافق توزيعها على مجموعة من المتغيرات العشوائية التي تم "تقسيم" عبر الأجهزة. على كل جهاز ، سينتجون عينات مختلفة ، ويمكن أن يكون لكل منهم كثافات لوغاريتمية مختلفة. بدلا من ذلك، Sharded يتوافق توزيع إلى "لوحة" في لغة نموذج الرسومية، حيث حجم لوحة هو عدد من الأجهزة.

أخذ عينات من Sharded التوزيع

إذا كان لنا عينة من Normal التوزيع في برنامج كائن pmap افتتاحية باستخدام نفس البذور على كل جهاز، سوف نحصل على نفس العينة على كل جهاز. يمكننا التفكير في الوظيفة التالية كأخذ عينات لمتغير عشوائي واحد تتم مزامنته عبر الأجهزة.

# `pmap` expects at least one value to be mapped over, so we provide a dummy one
def f(seed, _):
  return tfd.Normal(0., 1.).sample(seed=seed)
jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236,
                    -0.20584236, -0.20584236, -0.20584236, -0.20584236],                   dtype=float32)

إذا كنا التفاف tfd.Normal(0., 1.) مع tfed.Sharded ، فإننا منطقيا لديها الآن ثمانية المتغيرات العشوائية المختلفة (واحد على كل جهاز)، وبالتالي سوف تنتج عينة مختلفة لكل واحد، على الرغم من يمر في نفس البذور .

def f(seed, _):
  return tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i').sample(seed=seed)
jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
ShardedDeviceArray([ 1.2152631 ,  0.7818249 ,  0.32549605,  0.6828047 ,
                     1.3973192 , -0.57830244,  0.37862757,  2.7706041 ],                   dtype=float32)

تمثيل مكافئ لهذا التوزيع على جهاز واحد هو مجرد 8 عينات عادية مستقلة. على الرغم من أن قيمة العينة ستكون مختلفة ( tfed.Sharded يفعل شبه عشوائي توليد عدد مختلف قليلا)، وكلاهما يمثل نفس التوزيع.

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count())
dist.sample(seed=random.PRNGKey(0))
DeviceArray([ 0.08086783, -0.38624594, -0.3756545 ,  1.668957  ,
             -1.2758069 ,  2.1192007 , -0.85821325,  1.1305912 ],            dtype=float32)

أخذ سجل الكثافة من Sharded التوزيع

دعونا نرى ما يحدث عندما نحسب كثافة اللوغاريتمات لعينة من توزيع منتظم في سياق SPMD.

def f(seed, _):
  dist = tfd.Normal(0., 1.)
  x = dist.sample(seed=seed)
  return x, dist.log_prob(x)
jax.pmap(f, in_axes=(None, 0))(random.PRNGKey(0), jnp.arange(8.))
(ShardedDeviceArray([-0.20584236, -0.20584236, -0.20584236, -0.20584236,
                     -0.20584236, -0.20584236, -0.20584236, -0.20584236],                   dtype=float32),
 ShardedDeviceArray([-0.94012403, -0.94012403, -0.94012403, -0.94012403,
                     -0.94012403, -0.94012403, -0.94012403, -0.94012403],                   dtype=float32))

كل عينة هي نفسها على كل جهاز ، لذلك نحسب نفس الكثافة على كل جهاز أيضًا. بشكل بديهي ، لدينا هنا فقط توزيع على متغير واحد يتم توزيعه بشكل طبيعي.

مع Sharded التوزيع، لدينا توزيع أكثر من 8 المتغيرات العشوائية، لذلك عندما كنا حساب log_prob لعينة، ونحن خلاصة القول، عبر الأجهزة، على كل من الكثافة سجل الفردية. (قد تلاحظ أن إجمالي قيمة log_prob هذه أكبر من قيمة log_prob المفرد المحسوبة أعلاه.)

def f(seed, _):
  dist = tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i')
  x = dist.sample(seed=seed)
  return x, dist.log_prob(x)
sample, log_prob = jax.pmap(f, in_axes=(None, 0), axis_name='i')(
    random.PRNGKey(0), jnp.arange(8.))
print('Sample:', sample)
print('Log Prob:', log_prob)
Sample: [ 1.2152631   0.7818249   0.32549605  0.6828047   1.3973192  -0.57830244
  0.37862757  2.7706041 ]
Log Prob: [-13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205 -13.7349205
 -13.7349205 -13.7349205]

ينتج التوزيع المكافئ "غير المقوى" نفس كثافة اللوغاريتمات.

dist = tfd.Sample(tfd.Normal(0., 1.), jax.device_count())
dist.log_prob(sample)
DeviceArray(-13.7349205, dtype=float32)

A Sharded توزيع ينتج قيم مختلفة من sample على كل جهاز، ولكن الحصول على نفس القيمة ل log_prob على كل جهاز. ماذا يحصل هنا؟ A Sharded توزيع يقوم psum داخليا لضمان log_prob القيم متزامنة عبر الأجهزة. لماذا نريد هذا السلوك؟ إذا كان لنا أن تقوم بتشغيل نفس السلسلة MCMC على كل جهاز، نود أن target_log_prob أن تكون هي نفسها في كل جهاز، حتى لو تم sharded بعض المتغيرات العشوائية في حساب عبر الأجهزة.

بالإضافة إلى ذلك، Sharded يضمن التوزيع التي التدرجات عبر الأجهزة هي صحيحة، لضمان أن خوارزميات مثل مؤسسة حمد الطبية، والتي تأخذ التدرجات وظيفة سجل الكثافة كجزء من وظيفة التحول، وإنتاج العينات المناسبة.

Sharded JointDistribution الصورة

يمكننا خلق نماذج متعددة مع Sharded المتغيرات العشوائية باستخدام JointDistribution الصورة (دينار). للأسف، Sharded توزيعات لا يمكن استخدامها بأمان مع الفانيليا tfd.JointDistribution الصورة، ولكن tfp.experimental.distribute الصادرات "يرقع" دينار التي سوف تتصرف مثل Sharded التوزيعات.

def f(seed, _):
  dist = tfed.JointDistributionSequential([
    tfd.Normal(0., 1.),
    tfed.Sharded(tfd.Normal(0., 1.), shard_axis_name='i'),
  ])
  x = dist.sample(seed=seed)
  return x, dist.log_prob(x)
jax.pmap(f, in_axes=(None, 0), axis_name='i')(random.PRNGKey(0), jnp.arange(8.))
([ShardedDeviceArray([1.6121525, 1.6121525, 1.6121525, 1.6121525, 1.6121525,
                      1.6121525, 1.6121525, 1.6121525], dtype=float32),
  ShardedDeviceArray([ 0.8690128 , -0.83167845,  1.2209264 ,  0.88412696,
                       0.76478404, -0.66208494, -0.0129658 ,  0.7391483 ],                   dtype=float32)],
 ShardedDeviceArray([-12.214451, -12.214451, -12.214451, -12.214451,
                     -12.214451, -12.214451, -12.214451, -12.214451],                   dtype=float32))

ويمكن لهذه دينار sharded على حد سواء Sharded والفانيليا TFP توزيعات كمكونات. بالنسبة للتوزيعات غير المجهزة ، نحصل على نفس العينة على كل جهاز ، وبالنسبة للتوزيعات المُقسمة ، نحصل على عينات مختلفة. و log_prob تتم مزامنة على كل جهاز أيضا.

MCMC مع Sharded التوزيعات

كيف نفكر Sharded التوزيعات في سياق MCMC؟ إذا كان لدينا نموذج توليدي التي يمكن أن يعبر عنه JointDistribution ، يمكننا اختيار بعض محور هذا النموذج ب "قشرة" عبر. عادةً ، سيتوافق أحد المتغيرات العشوائية في النموذج مع البيانات المرصودة ، وإذا كانت لدينا مجموعة بيانات كبيرة نرغب في تجزئتها عبر الأجهزة ، فنحن نريد أيضًا تجزئة المتغيرات المرتبطة بنقاط البيانات. قد يكون لدينا أيضًا متغيرات عشوائية "محلية" تكون فردية مع الملاحظات التي نقوم بتقسيمها ، لذلك سيتعين علينا إضافة هذه المتغيرات العشوائية إلى أجزاء أخرى.

سنذهب على أمثلة على استخدام Sharded توزيعات مع TFP MCMC في هذا القسم. سنبدأ مع أبسط النظرية الافتراضية سبيل المثال الانحدار اللوجستي، وتختتم مع مثال مصفوفة التعميل، وذلك بهدف إثبات بعض حالات الاستخدام ل distribute مكتبة.

مثال: الانحدار اللوجستي Bayesian لـ MNIST

نود إجراء الانحدار اللوجستي Bayesian على مجموعة بيانات كبيرة ؛ نموذج لديه مسبق \(p(\theta)\) على أوزان الانحدار، واحتمال \(p(y_i | \theta, x_i)\) التي تتلخص على جميع البيانات \(\{x_i, y_i\}_{i = 1}^N\) للحصول على مجموع كثافة السجل المشترك. إذا كنا شارد بياناتنا، كنا شارد المتغيرات العشوائية المرصودة \(x_i\) و \(y_i\) في نموذجنا.

نستخدم نموذج الانحدار اللوجستي Bayesian التالي لتصنيف MNIST:

\[ \begin{align*} w &\sim \mathcal{N}(0, 1) \\ b &\sim \mathcal{N}(0, 1) \\ y_i | w, b, x_i &\sim \textrm{Categorical}(w^T x_i + b) \end{align*} \]

لنقم بتحميل MNIST باستخدام مجموعات بيانات TensorFlow.

mnist = tfds.as_numpy(tfds.load('mnist', batch_size=-1))
raw_train_images, train_labels = mnist['train']['image'], mnist['train']['label']
train_images = raw_train_images.reshape([raw_train_images.shape[0], -1]) / 255.

raw_test_images, test_labels = mnist['test']['image'], mnist['test']['label']
test_images = raw_test_images.reshape([raw_test_images.shape[0], -1]) / 255.
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

لدينا 60000 صورة تدريبية ولكن دعونا نستفيد من 8 مراكز متوفرة لدينا ونقسمها 8 طرق. سنستخدم هذا يدوية shard دالة المنفعة.

def shard_value(x):
  x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
  return jax.pmap(lambda x: x)(x) # pmap will physically place values on devices

shard = functools.partial(jax.tree_map, shard_value)
sharded_train_images, sharded_train_labels = shard((train_images, train_labels))
print(sharded_train_images.shape, sharded_train_labels.shape)
(8, 7500, 784) (8, 7500)

قبل أن نواصل ، دعنا نناقش بسرعة الدقة على TPU وتأثيرها على HMC. TPUs تنفيذ ضرب المصفوفات باستخدام منخفضة bfloat16 الدقة للسرعة. bfloat16 ضرب المصفوفات وغالبا ما تكون كافية لكثير من تطبيقات التعلم عميقة، ولكن عندما تستخدم مع مؤسسة حمد الطبية، وجدنا تجريبيا دقة أقل يمكن أن يؤدي إلى تباين المسارات، مما تسبب في الرفض. يمكننا استخدام مضاعفات المصفوفات بدقة أعلى على حساب بعض الحسابات الإضافية.

لزيادة الدقة لدينا matmul، يمكننا استخدام jax.default_matmul_precision الديكور مع "tensorfloat32" الدقة (لأعلى دقة يمكن أن نستخدمها "float32" الدقة).

دعونا الآن تحديد لدينا run وظيفة، والتي سوف تأخذ في البذور العشوائي (والتي سوف تكون هي نفسها في كل جهاز) وكسرة فخارية من MNIST. ستقوم الوظيفة بتنفيذ النموذج المذكور أعلاه وسنستخدم بعد ذلك وظيفة Vanilla MCMC في TFP لتشغيل سلسلة واحدة. ونحن سوف نتأكد من أن تزيين run مع jax.default_matmul_precision الديكور للتأكد من تشغيل الضرب مصفوفة بدقة أعلى، وإن كان في مثال معين أدناه، فإننا يمكن أن تستخدم فقط كذلك jnp.dot(images, w, precision=lax.Precision.HIGH) .

# We can use `out_axes=None` in the `pmap` because the results will be the same
# on every device. 
@functools.partial(jax.pmap, axis_name='data', in_axes=(None, 0), out_axes=None)
@jax.default_matmul_precision('tensorfloat32')
def run(seed, data):
  images, labels = data # a sharded dataset
  num_examples, dim = images.shape
  num_classes = 10

  def model_fn():
    w = yield Root(tfd.Sample(tfd.Normal(0., 1.), [dim, num_classes]))
    b = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_classes]))
    logits = jnp.dot(images, w) + b
    yield tfed.Sharded(tfd.Independent(tfd.Categorical(logits=logits), 1),
                       shard_axis_name='data')
  model = tfed.JointDistributionCoroutine(model_fn)

  init_seed, sample_seed = random.split(seed)

  initial_state = model.sample(seed=init_seed)[:-1] # throw away `y`

  def target_log_prob(*state):
    return model.log_prob((*state, labels))

  def accuracy(w, b):
    logits = images.dot(w) + b
    preds = logits.argmax(axis=-1)
    # We take the average accuracy across devices by using `lax.pmean`
    return lax.pmean((preds == labels).mean(), 'data')

  kernel = tfm.HamiltonianMonteCarlo(target_log_prob, 1e-2, 100)
  kernel = tfm.DualAveragingStepSizeAdaptation(kernel, 500)
  def trace_fn(state, pkr):
    return (
        target_log_prob(*state),
        accuracy(*state),
        pkr.new_step_size)
  states, trace = tfm.sample_chain(
    num_results=1000,
    num_burnin_steps=1000,
    current_state=initial_state,
    kernel=kernel,
    trace_fn=trace_fn,
    seed=sample_seed
  )
  return states, trace

jax.pmap يتضمن ترجمة JIT ولكن تم تخزينها مؤقتا وظيفة جمعت بعد المكالمة الأولى. وسوف ندعو run وتجاهل الإخراج إلى ذاكرة التخزين المؤقت تجميع.

%%time
output = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels))
jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 24.5 s, sys: 48.2 s, total: 1min 12s
Wall time: 1min 54s

سنقوم الآن الدعوة run مرة أخرى لنرى كم من الوقت يستغرق التنفيذ الفعلي.

%%time
states, trace = run(random.PRNGKey(0), (sharded_train_images, sharded_train_labels))
jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 13.1 s, sys: 45.2 s, total: 58.3 s
Wall time: 1min 43s

نحن ننفذ 200000 خطوة قفزة ، كل منها يحسب تدرجًا على مجموعة البيانات بأكملها. يتيح لنا تقسيم الحساب على 8 مراكز حساب ما يعادل 200000 حقبة تدريب في حوالي 95 ثانية ، أي حوالي 2100 حقبة في الثانية!

دعنا نرسم كثافة اللوغاريتمات لكل عينة ودقة كل عينة:

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].plot(trace[0])
ax[0].set_title('Log Prob')
ax[1].plot(trace[1])
ax[1].set_title('Accuracy')
ax[2].plot(trace[2])
ax[2].set_title('Step Size')
plt.show()

بي إن جي

إذا قمنا بتجميع العينات ، فيمكننا حساب متوسط ​​نموذج بايزي لتحسين أدائنا.

@functools.partial(jax.pmap, axis_name='data', in_axes=(0, None), out_axes=None)
def bayesian_model_average(data, states):
  images, labels = data
  logits = jax.vmap(lambda w, b: images.dot(w) + b)(*states)
  probs = jax.nn.softmax(logits, axis=-1)
  bma_accuracy = (probs.mean(axis=0).argmax(axis=-1) == labels).mean()
  avg_accuracy = (probs.argmax(axis=-1) == labels).mean()
  return lax.pmean(bma_accuracy, axis_name='data'), lax.pmean(avg_accuracy, axis_name='data')

sharded_test_images, sharded_test_labels = shard((test_images, test_labels))
bma_acc, avg_acc = bayesian_model_average((sharded_test_images, sharded_test_labels), states)
print(f'Average Accuracy: {avg_acc}')
print(f'BMA Accuracy: {bma_acc}')
print(f'Accuracy Improvement: {bma_acc - avg_acc}')
Average Accuracy: 0.9188529253005981
BMA Accuracy: 0.9264000058174133
Accuracy Improvement: 0.0075470805168151855

يزيد متوسط ​​نموذج بايزي من دقتنا بنسبة 1٪ تقريبًا!

مثال: نظام توصية MovieLens

دعنا الآن نحاول الاستدلال باستخدام مجموعة بيانات توصيات MovieLens ، وهي مجموعة من المستخدمين وتقييماتهم للأفلام المختلفة. على وجه التحديد، ونحن يمكن أن تمثل MovieLens باعتبارها \(N \times M\) ساعة مصفوفة \(W\) حيث \(N\) هو عدد المستخدمين و \(M\) هو عدد من الأفلام. نتوقع \(N > M\). مداخل \(W_{ij}\) هي منطقية تشير أم لا المستعمل \(i\) شاهدت فيلم \(j\). لاحظ أن برنامج MovieLens يوفر تقييمات للمستخدمين ، لكننا نتجاهلها لتبسيط المشكلة.

أولاً ، سنقوم بتحميل مجموعة البيانات. سنستخدم الإصدار بمليون تقييم.

movielens = tfds.as_numpy(tfds.load('movielens/1m-ratings', batch_size=-1))
GENRES = ['Action', 'Adventure', 'Animation', 'Children', 'Comedy',
          'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir',
          'Horror', 'IMAX', 'Musical', 'Mystery', 'Romance', 'Sci-Fi',
          'Thriller', 'Unknown', 'War', 'Western', '(no genres listed)']
Downloading and preparing dataset movielens/1m-ratings/0.1.0 (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0...
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Shuffling and writing examples to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0.incompleteYKA3TG/movielens-train.tfrecord
HBox(children=(FloatProgress(value=0.0, max=1000209.0), HTML(value='')))
Dataset movielens downloaded and prepared to /root/tensorflow_datasets/movielens/1m-ratings/0.1.0. Subsequent calls will reuse this data.

سنفعل بعض تجهيزها من مجموعة البيانات للحصول على ساعة مصفوفة \(W\).

raw_movie_ids = movielens['train']['movie_id']
raw_user_ids = movielens['train']['user_id']
genres = movielens['train']['movie_genres']

movie_ids, movie_labels = pd.factorize(movielens['train']['movie_id'])
user_ids, user_labels = pd.factorize(movielens['train']['user_id'])

num_movies = movie_ids.max() + 1
num_users = user_ids.max() + 1

movie_titles = dict(zip(movielens['train']['movie_id'],
                        movielens['train']['movie_title']))
movie_genres = dict(zip(movielens['train']['movie_id'],
                        genres))
movie_id_to_title = [movie_titles[movie_labels[id]].decode('utf-8')
                     for id in range(num_movies)]
movie_id_to_genre = [GENRES[movie_genres[movie_labels[id]][0]] for id in range(num_movies)]

watch_matrix = np.zeros((num_users, num_movies), bool)
watch_matrix[user_ids, movie_ids] = True
print(watch_matrix.shape)
(6040, 3706)

يمكننا تحديد نموذج توليدي ل \(W\)، وذلك باستخدام مصفوفة الاحتمالية نموذج توكيل تجاري بسيط. ونحن نفترض كامنة \(N \times D\) مصفوفة المستخدم \(U\) وكامنة \(M \times D\) الفيلم مصفوفة \(V\)، والتي عندما تضاعف إنتاج logits من برنولي للساعة مصفوفة \(W\). ونحن سوف تشمل أيضا ناقلات التحيز للمستخدمين والأفلام، \(u\) و \(v\).

\[ \begin{align*} U &\sim \mathcal{N}(0, 1) \quad u \sim \mathcal{N}(0, 1)\\ V &\sim \mathcal{N}(0, 1) \quad v \sim \mathcal{N}(0, 1)\\ W_{ij} &\sim \textrm{Bernoulli}\left(\sigma\left(\left(UV^T\right)_{ij} + u_i + v_j\right)\right) \end{align*} \]

هذه مصفوفة كبيرة جدًا ؛ 6040 مستخدم و 3706 فيلم يؤدي إلى مصفوفة بها أكثر من 22 مليون مدخل. كيف نقترب من تجزئة هذا النموذج؟ حسنا، إذا افترضنا أن \(N > M\) (أي هناك المزيد من المستخدمين من الأفلام)، فإنه سيكون من المنطقي أن شارد مصفوفة ساعة عبر محور المستخدم، بحيث كل جهاز سيكون له قطعة من الساعات المصفوفة المقابلة لمجموعة فرعية من المستخدمين . على عكس المثال السابق، ومع ذلك، سيكون لدينا أيضا إلى شارد حتى \(U\) المصفوفة، نظرا لما له من التضمين لكل مستخدم، بحيث كل جهاز ستكون مسؤولة عن كسرة فخارية من \(U\) وكسرة فخارية من \(W\). من ناحية أخرى، \(V\) سيكون unsharded وتكون متزامنة عبر الأجهزة.

sharded_watch_matrix = shard(watch_matrix)

قبل نكتب لدينا run ، دعونا بسرعة مناقشة تحديات إضافية مع عملية التجزئة المحلي متغير عشوائي \(U\). عند تشغيل مؤسسة حمد الطبية، والفانيليا tfp.mcmc.HamiltonianMonteCarlo نواة سوف تذوق العزم لكل عنصر من عناصر الدولة في السلسلة. في السابق ، كانت المتغيرات العشوائية غير المقيدة فقط جزءًا من تلك الحالة ، وكانت العزم هي نفسها على كل جهاز. عندما يكون لدينا الآن sharded \(U\)، نحن بحاجة إلى أخذ عينات العزم مختلفة على كل جهاز ل \(U\)، في حين أخذ عينات من نفس العزم ل \(V\). ولتحقيق ذلك، يمكننا استخدام tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo مع Sharded توزيع العزم. بينما نستمر في إجراء الحساب المتوازي من الدرجة الأولى ، يمكننا تبسيط ذلك ، على سبيل المثال عن طريق أخذ مؤشر الحدة إلى نواة HMC.

def make_run(*,
             axis_name,
             dim=20,
             num_chains=2,
             prior_variance=1.,
             step_size=1e-2,
             num_leapfrog_steps=100,
             num_burnin_steps=1000,
             num_results=500,
             ):
  @functools.partial(jax.pmap, in_axes=(None, 0), axis_name=axis_name)
  @jax.default_matmul_precision('tensorfloat32')
  def run(key, watch_matrix):
    num_users, num_movies = watch_matrix.shape

    Sharded = functools.partial(tfed.Sharded, shard_axis_name=axis_name)

    def prior_fn():
      user_embeddings = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users, dim]), name='user_embeddings'))
      user_bias = yield Root(Sharded(tfd.Sample(tfd.Normal(0., 1.), [num_users]), name='user_bias'))
      movie_embeddings = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies, dim], name='movie_embeddings'))
      movie_bias = yield Root(tfd.Sample(tfd.Normal(0., 1.), [num_movies], name='movie_bias'))
      return (user_embeddings, user_bias, movie_embeddings, movie_bias)
    prior = tfed.JointDistributionCoroutine(prior_fn)

    def model_fn():
      user_embeddings, user_bias, movie_embeddings, movie_bias = yield from prior_fn()
      logits = (jnp.einsum('...nd,...md->...nm', user_embeddings, movie_embeddings)
                + user_bias[..., :, None] + movie_bias[..., None, :])
      yield Sharded(tfd.Independent(tfd.Bernoulli(logits=logits), 2), name='watch')
    model = tfed.JointDistributionCoroutine(model_fn)

    init_key, sample_key = random.split(key)
    initial_state = prior.sample(seed=init_key, sample_shape=num_chains)

    def target_log_prob(*state):
      return model.log_prob((*state, watch_matrix))

    momentum_distribution = tfed.JointDistributionSequential([
      Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users, dim]), 1.), 2)),
      Sharded(tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_users]), 1.), 1)),
      tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies, dim]), 1.), 2),
      tfd.Independent(tfd.Normal(jnp.zeros([num_chains, num_movies]), 1.), 1),
    ])

    # We pass in momentum_distribution here to ensure that the momenta for 
    # user_embeddings and user_bias are also sharded
    kernel = tfem.PreconditionedHamiltonianMonteCarlo(target_log_prob, step_size,
                                                      num_leapfrog_steps,
                                                      momentum_distribution=momentum_distribution)

    num_adaptation_steps = int(0.8 * num_burnin_steps)
    kernel = tfm.DualAveragingStepSizeAdaptation(kernel, num_adaptation_steps)

    def trace_fn(state, pkr):
      return {
        'log_prob': target_log_prob(*state),
        'log_accept_ratio': pkr.inner_results.log_accept_ratio,
      }
    return tfm.sample_chain(
        num_results, initial_state,
        kernel=kernel,
        num_burnin_steps=num_burnin_steps,
        trace_fn=trace_fn,
        seed=sample_key)
  return run

ونحن مرة أخرى سوف تشغيله مرة واحدة إلى ذاكرة التخزين المؤقت المترجمة run .

%%time
run = make_run(axis_name='data')
output = run(random.PRNGKey(0), sharded_watch_matrix)
jax.tree_map(lambda x: x.block_until_ready(), output)
CPU times: user 56 s, sys: 1min 24s, total: 2min 20s
Wall time: 3min 35s

الآن سنقوم بتشغيله مرة أخرى دون تحميل النفقات العامة.

%%time
states, trace = run(random.PRNGKey(0), sharded_watch_matrix)
jax.tree_map(lambda x: x.block_until_ready(), trace)
CPU times: user 28.8 s, sys: 1min 16s, total: 1min 44s
Wall time: 3min 1s

يبدو أننا أكملنا حوالي 150000 خطوة قفزة في حوالي 3 دقائق ، أي حوالي 83 خطوة قفزة في الثانية! دعنا نرسم نسبة القبول وكثافة اللوغاريتمات لعيناتنا.

fig, axs = plt.subplots(1, len(trace), figsize=(5 * len(trace), 5))
for ax, (key, val) in zip(axs, trace.items()):
  ax.plot(val[0]) # Indexing into a sharded array, each element is the same
  ax.set_title(key);

بي إن جي

الآن بعد أن أصبح لدينا بعض العينات من سلسلة ماركوف ، فلنستخدمها لعمل بعض التنبؤات. أولاً ، دعنا نستخرج كل مكون. تذكر أن user_embeddings و user_bias هي الانقسام عبر الجهاز، لذلك نحن بحاجة إلى سلسلة لدينا ShardedArray للحصول على كل منهم. من ناحية أخرى، movie_embeddings و movie_bias هي نفسها على كل جهاز، حتى نتمكن من مجرد اختيار القيمة من قشرة الأولى. سنستخدم العادية numpy لنسخ القيم من الخلف TPUs إلى وحدة المعالجة المركزية.

user_embeddings = np.concatenate(np.array(states.user_embeddings, np.float32), axis=2)
user_bias = np.concatenate(np.array(states.user_bias, np.float32), axis=2)
movie_embeddings = np.array(states.movie_embeddings[0], dtype=np.float32)
movie_bias = np.array(states.movie_bias[0], dtype=np.float32)
samples = (user_embeddings, user_bias, movie_embeddings, movie_bias)
print(f'User embeddings: {user_embeddings.shape}')
print(f'User bias: {user_bias.shape}')
print(f'Movie embeddings: {movie_embeddings.shape}')
print(f'Movie bias: {movie_bias.shape}')
User embeddings: (500, 2, 6040, 20)
User bias: (500, 2, 6040)
Movie embeddings: (500, 2, 3706, 20)
Movie bias: (500, 2, 3706)

دعنا نحاول بناء نظام توصية بسيط يستخدم عدم اليقين الذي تم تسجيله في هذه العينات. لنكتب أولاً دالة تصنف الأفلام وفقًا لاحتمال المشاهدة.

@jax.jit
def recommend(sample, user_id):
  user_embeddings, user_bias, movie_embeddings, movie_bias = sample
  movie_logits = (
      jnp.einsum('d,md->m', user_embeddings[user_id], movie_embeddings)
      + user_bias[user_id] + movie_bias)
  return movie_logits.argsort()[::-1]

يمكننا الآن كتابة وظيفة تتكرر في جميع العينات ولكل منها ، تختار الفيلم الأعلى تصنيفًا الذي لم يشاهده المستخدم بالفعل. يمكننا بعد ذلك رؤية أعداد جميع الأفلام الموصى بها عبر العينات.

def get_recommendations(user_id): 
  movie_ids = []
  already_watched = set(jnp.arange(num_movies)[watch_matrix[user_id] == 1])
  for i in range(500):
    for j in range(2):
      sample = jax.tree_map(lambda x: x[i, j], samples)
      ranking = recommend(sample, user_id)
      for movie_id in ranking:
        if int(movie_id) not in already_watched:
          movie_ids.append(movie_id)
          break
  return movie_ids

def plot_recommendations(movie_ids, ax=None):
  titles = collections.Counter([movie_id_to_title[i] for i in movie_ids])
  ax = ax or plt.gca()
  names, counts = zip(*sorted(titles.items(), key=lambda x: -x[1]))
  ax.bar(names, counts)
  ax.set_xticklabels(names, rotation=90)

لنأخذ المستخدم الذي شاهد معظم الأفلام مقابل المستخدم الذي شاهد أقل عدد من الأفلام.

user_watch_counts = watch_matrix.sum(axis=1)
user_most = user_watch_counts.argmax()
user_least = user_watch_counts.argmin()
print(user_watch_counts[user_most], user_watch_counts[user_least])
2314 20

نأمل نظامنا أكبر من اليقين حول user_most من user_least ، بالنظر إلى أن لدينا المزيد من المعلومات حول ما يفرز من الأفلام user_most هو أكثر عرضة للمشاهدة.

fig, ax = plt.subplots(1, 2, figsize=(20, 10))
most_recommendations = get_recommendations(user_most)
plot_recommendations(most_recommendations, ax=ax[0])
ax[0].set_title('Recommendation for user_most')
least_recommendations = get_recommendations(user_least)
plot_recommendations(least_recommendations, ax=ax[1])
ax[1].set_title('Recommendation for user_least');

بي إن جي

ونحن نرى أن هناك المزيد من التباين في توصياتنا لل user_least يعكس لدينا المزيد من عدم اليقين في تفضيلات مراقبتهم.

يمكننا أيضًا الاطلاع على أنواع الأفلام الموصى بها.

most_genres = collections.Counter([movie_id_to_genre[i] for i in most_recommendations])
least_genres = collections.Counter([movie_id_to_genre[i] for i in least_recommendations])
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].bar(most_genres.keys(), most_genres.values())
ax[0].set_title('Genres recommended for user_most')
ax[1].bar(least_genres.keys(), least_genres.values())
ax[1].set_title('Genres recommended for user_least');

بي إن جي

user_most شهدت الكثير من الأفلام والموصى بها أكثر الأنواع المتخصصة مثل الغموض والجريمة في حين user_least لم يشاهد العديد من الأفلام، وأوصى المزيد من الأفلام السائدة، التي الكوميديا الانحراف والعمل.