JAX . पर TensorFlow की संभावना

TensorFlow.org पर देखें Google Colab में चलाएं GitHub पर स्रोत देखें नोटबुक डाउनलोड करें

TensorFlow संभावना (टीएफपी) संभाव्य तर्क और सांख्यिकीय विश्लेषण है कि अब भी पर काम करता है के लिए एक पुस्तकालय है JAX ! उन लोगों के लिए जो परिचित नहीं हैं, JAX कंपोज़ेबल फंक्शन ट्रांसफ़ॉर्मेशन के आधार पर त्वरित संख्यात्मक कंप्यूटिंग के लिए एक पुस्तकालय है।

जेएक्स पर टीएफपी नियमित टीएफपी की सबसे उपयोगी कार्यक्षमता का समर्थन करता है जबकि अमूर्त और एपीआई को संरक्षित करते हुए कई टीएफपी उपयोगकर्ता अब सहज हैं।

सेट अप

TFP JAX पर TensorFlow पर निर्भर नहीं करता; आइए इस Colab से TensorFlow को पूरी तरह से अनइंस्टॉल कर दें।

pip uninstall tensorflow -y -q

हम TFP के नवीनतम रात्रिकालीन निर्माण के साथ JAX पर TFP स्थापित कर सकते हैं।

pip install -Uq tfp-nightly[jax] > /dev/null

आइए कुछ उपयोगी पायथन पुस्तकालयों को आयात करें।

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn import datasets
sns.set(style='white')
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

आइए कुछ बुनियादी JAX कार्यक्षमता भी आयात करें।

import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import random
from jax import value_and_grad
from jax import vmap

JAX . पर TFP आयात करना

JAX पर TFP का उपयोग करने के लिए बस आयात jax "सब्सट्रेट" और इसका इस्तेमाल के रूप में आप आमतौर पर होता tfp :

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

डेमो: बायेसियन लॉजिस्टिक रिग्रेशन

यह प्रदर्शित करने के लिए कि हम JAX बैकएंड के साथ क्या कर सकते हैं, हम क्लासिक आइरिस डेटासेट पर लागू बायेसियन लॉजिस्टिक रिग्रेशन को लागू करेंगे।

सबसे पहले, आइए आइरिस डेटासेट आयात करें और कुछ मेटाडेटा निकालें।

iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)

हम का उपयोग कर मॉडल को परिभाषित कर सकते tfd.JointDistributionCoroutine । हम दोनों वजन और पूर्वाग्रह अवधि पर मानक सामान्य महंतों डाल देता हूँ तो एक लिखने target_log_prob समारोह है कि पिन डेटा करने के लिए नमूने लेबल।

Root = tfd.JointDistributionCoroutine.Root
def model():
  w = yield Root(tfd.Sample(tfd.Normal(0., 1.),
                            sample_shape=(num_features, num_classes)))
  b = yield Root(
      tfd.Sample(tfd.Normal(0., 1.), sample_shape=(num_classes,)))
  logits = jnp.dot(features, w) + b
  yield tfd.Independent(tfd.Categorical(logits=logits),
                        reinterpreted_batch_ndims=1)


dist = tfd.JointDistributionCoroutine(model)
def target_log_prob(*params):
  return dist.log_prob(params + (labels,))

हम से नमूना dist एमसीएमसी के लिए एक प्रारंभिक राज्य निर्माण करने के लिए। फिर हम एक फ़ंक्शन को परिभाषित कर सकते हैं जो एक यादृच्छिक कुंजी और प्रारंभिक स्थिति लेता है, और नो-यू-टर्न-सैंपलर (एनयूटीएस) से 500 नमूने तैयार करता है। ध्यान दें कि हम जैसे JAX परिवर्तनों का उपयोग कर सकते jit XLA का उपयोग कर हमारे पागल नमूना संकलित करने के लिए।

init_key, sample_key = random.split(random.PRNGKey(0))
init_params = tuple(dist.sample(seed=init_key)[:-1])

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      num_burnin_steps=500,
      seed=key)

states, log_probs = run_chain(sample_key, init_params)
plt.figure()
plt.plot(log_probs)
plt.ylabel('Target Log Prob')
plt.xlabel('Iterations of NUTS')
plt.show()

पीएनजी

आइए हमारे नमूनों का उपयोग वजन के प्रत्येक सेट की अनुमानित संभावनाओं के औसत से बायेसियन मॉडल औसत (बीएमए) करने के लिए करें।

आइए पहले एक फ़ंक्शन लिखें जो दिए गए मापदंडों के सेट के लिए प्रत्येक वर्ग पर संभावनाओं का उत्पादन करेगा। हम उपयोग कर सकते हैं dist.sample_distributions मॉडल में अंतिम वितरण प्राप्त करने के लिए।

def classifier_probs(params):
  dists, _ = dist.sample_distributions(seed=random.PRNGKey(0),
                                       value=params + (None,))
  return dists[-1].distribution.probs_parameter()

हम कर सकते हैं vmap(classifier_probs) नमूने के समूह के ऊपर हमारे नमूने से प्रत्येक के लिए भविष्यवाणी की वर्ग संभावनाओं को पाने के लिए। फिर हम प्रत्येक नमूने में औसत सटीकता की गणना करते हैं, और बायेसियन मॉडल औसत से सटीकता की गणना करते हैं।

all_probs = jit(vmap(classifier_probs))(states)
print('Average accuracy:', jnp.mean(all_probs.argmax(axis=-1) == labels))
print('BMA accuracy:', jnp.mean(all_probs.mean(axis=0).argmax(axis=-1) == labels))
Average accuracy: 0.96952
BMA accuracy: 0.97999996

ऐसा लगता है कि BMA हमारी त्रुटि दर को लगभग एक तिहाई कम कर देता है!

बुनियादी बातों

TFP JAX पर TF के लिए एक समान एपीआई जहां TF वस्तुओं स्वीकार करने के बजाय की तरह है tf.Tensor है यह JAX एनालॉग स्वीकार करता है। उदाहरण के लिए, जहाँ भी एक tf.Tensor पहले से इनपुट के रूप में इस्तेमाल किया गया था, एपीआई अब एक JAX उम्मीद DeviceArray । इसके बजाय एक लौटने का tf.Tensor , TFP तरीकों वापस आ जाएगी DeviceArray रों। TFP JAX पर भी JAX वस्तुओं की नेस्टेड संरचनाओं, की एक सूची या शब्दकोश की तरह साथ काम करता है DeviceArray रों।

वितरण

TFP के अधिकांश वितरण JAX में उनके TF समकक्षों के समान समानार्थक शब्दों के साथ समर्थित हैं। उन्होंने यह भी रूप में पंजीकृत हैं JAX Pytrees , तो वे इनपुट और JAX-बदल कार्यों के आउटपुट हो सकता है।

बुनियादी वितरण

log_prob वितरण के लिए विधि एक ही काम करता है।

dist = tfd.Normal(0., 1.)
print(dist.log_prob(0.))
-0.9189385

एक वितरण से नमूना स्पष्ट रूप से एक में गुजर आवश्यकता PRNGKey के रूप में (पूर्णांकों की सूची या) seed कीवर्ड तर्क। एक बीज में स्पष्ट रूप से पारित करने में विफल होने पर एक त्रुटि होगी।

tfd.Normal(0., 1.).sample(seed=random.PRNGKey(0))
DeviceArray(-0.20584226, dtype=float32)

वितरण के लिए आकार अर्थ विज्ञान JAX, जहां वितरण प्रत्येक एक होगा में ही रहते हैं event_shape और एक batch_shape और कई नमूने ड्राइंग अतिरिक्त जोड़ देगा sample_shape आयाम।

उदाहरण के लिए, एक tfd.MultivariateNormalDiag वेक्टर मानकों के साथ एक वेक्टर घटना आकार और खाली बैच आकार होगा।

dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros(5),
    scale_diag=jnp.ones(5)
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: (5,)
Batch shape: ()

दूसरी ओर, एक tfd.Normal वैक्टर साथ parameterized एक अदिश घटना आकार और वेक्टर बैच आकार होगा।

dist = tfd.Normal(
    loc=jnp.ones(5),
    scale=jnp.ones(5),
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: ()
Batch shape: (5,)

लेने का अर्थ विज्ञान log_prob नमूनों की भी JAX में एक ही काम करता है।

dist =  tfd.Normal(jnp.zeros(5), jnp.ones(5))
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)

dist =  tfd.Independent(tfd.Normal(jnp.zeros(5), jnp.ones(5)), 1)
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)
(10, 2, 5)
(10, 2)

क्योंकि JAX DeviceArray रों NumPy और matplotlib तरह पुस्तकालयों के साथ संगत कर रहे हैं, हम एक साजिश रचने समारोह में सीधे नमूने फ़ीड कर सकते हैं।

sns.distplot(tfd.Normal(0., 1.).sample(1000, seed=random.PRNGKey(0)))
plt.show()

पीएनजी

Distribution के तरीकों JAX परिवर्तनों के साथ संगत कर रहे हैं।

sns.distplot(jit(vmap(lambda key: tfd.Normal(0., 1.).sample(seed=key)))(
    random.split(random.PRNGKey(0), 2000)))
plt.show()

पीएनजी

x = jnp.linspace(-5., 5., 100)
plt.plot(x, jit(vmap(grad(tfd.Normal(0., 1.).prob)))(x))
plt.show()

पीएनजी

TFP वितरण JAX pytree नोड्स के रूप में पंजीकृत हैं, इसलिए हम आदानों या आउटपुट के रूप में वितरण के साथ काम करता है लिख सकते हैं और का उपयोग कर उन्हें बदलने jit , लेकिन वे अभी तक तर्क के रूप में समर्थित नहीं हैं vmap एड कार्य करता है।

@jit
def random_distribution(key):
  loc_key, scale_key = random.split(key)
  loc, log_scale = random.normal(loc_key), random.normal(scale_key)
  return tfd.Normal(loc, jnp.exp(log_scale))
random_dist = random_distribution(random.PRNGKey(0))
print(random_dist.mean(), random_dist.variance())
0.14389051 0.081832744

रूपांतरित वितरण

बदल वितरण यानी वितरण जिसका नमूने एक के माध्यम से पारित कर रहे हैं Bijector भी बॉक्स से बाहर काम करते हैं (bijectors भी काम करते हैं! नीचे देखें)।

dist = tfd.TransformedDistribution(
    tfd.Normal(0., 1.),
    tfb.Sigmoid()
)
sns.distplot(dist.sample(1000, seed=random.PRNGKey(0)))
plt.show()

पीएनजी

संयुक्त वितरण

TFP प्रदान करता है JointDistribution रों कई यादृच्छिक परिवर्तनीय आने वाले एक वितरण में घटक वितरण के संयोजन सक्षम करने के लिए। वर्तमान में, TFP प्रस्तावों तीन मुख्य वेरिएंट ( JointDistributionSequential , JointDistributionNamed , और JointDistributionCoroutine ) जो सभी के JAX में समर्थित हैं। AutoBatched वेरिएंट भी सभी समर्थित हैं।

dist = tfd.JointDistributionSequential([
  tfd.Normal(0., 1.),
  lambda x: tfd.Normal(x, 1e-1)
])
plt.scatter(*dist.sample(1000, seed=random.PRNGKey(0)), alpha=0.5)
plt.show()

पीएनजी

joint = tfd.JointDistributionNamed(dict(
    e=             tfd.Exponential(rate=1.),
    n=             tfd.Normal(loc=0., scale=2.),
    m=lambda n, e: tfd.Normal(loc=n, scale=e),
    x=lambda    m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
))
joint.sample(seed=random.PRNGKey(0))
{'e': DeviceArray(3.376818, dtype=float32),
 'm': DeviceArray(2.5449684, dtype=float32),
 'n': DeviceArray(-0.6027825, dtype=float32),
 'x': DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
Root = tfd.JointDistributionCoroutine.Root
def model():
  e = yield Root(tfd.Exponential(rate=1.))
  n = yield Root(tfd.Normal(loc=0, scale=2.))
  m = yield tfd.Normal(loc=n, scale=e)
  x = yield tfd.Sample(tfd.Bernoulli(logits=m), 12)

joint = tfd.JointDistributionCoroutine(model)

joint.sample(seed=random.PRNGKey(0))
StructTuple(var0=DeviceArray(0.17315261, dtype=float32), var1=DeviceArray(-3.290489, dtype=float32), var2=DeviceArray(-3.1949058, dtype=float32), var3=DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))

अन्य वितरण

गाऊसी प्रक्रियाएं भी JAX मोड में काम करती हैं!

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
observation_noise_variance = 0.01
f = lambda x: jnp.sin(10*x[..., 0]) * jnp.exp(-x[..., 0]**2)
observation_index_points = random.uniform(
    k1, [50], minval=-1.,maxval= 1.)[..., jnp.newaxis]
observations = f(observation_index_points) + tfd.Normal(
    loc=0., scale=jnp.sqrt(observation_noise_variance)).sample(seed=k2)

index_points = jnp.linspace(-1., 1., 100)[..., jnp.newaxis]

kernel = tfpk.ExponentiatedQuadratic(length_scale=0.1)

gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=index_points,
    observation_index_points=observation_index_points,
    observations=observations,
    observation_noise_variance=observation_noise_variance)

samples = gprm.sample(10, seed=k3)
for i in range(10):
  plt.plot(index_points, samples[i], alpha=0.5)
plt.plot(observation_index_points, observations, marker='o', linestyle='')
plt.show()

पीएनजी

हिडन मार्कोव मॉडल भी समर्थित हैं।

initial_distribution = tfd.Categorical(probs=[0.8, 0.2])
transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                 [0.2, 0.8]])

observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

model = tfd.HiddenMarkovModel(
    initial_distribution=initial_distribution,
    transition_distribution=transition_distribution,
    observation_distribution=observation_distribution,
    num_steps=7)

print(model.mean())
print(model.log_prob(jnp.zeros(7)))
print(model.sample(seed=random.PRNGKey(0)))
[3.       6.       7.5      8.249999 8.625001 8.812501 8.90625 ]
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/substrates/jax/distributions/hidden_markov_model.py:483: UserWarning: HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug in which the transition model was applied prior to the initial step. This bug has been fixed. You may observe a slight change in behavior.
  'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
-19.855635
[ 1.3641367  0.505798   1.3626463  3.6541772  2.272286  15.10309
 22.794212 ]

की तरह कुछ वितरण PixelCNN TensorFlow या XLA असंगतियां पर सख्त निर्भरता की वजह से अभी तक समर्थित नहीं हैं।

बिजेक्टर

TFP के अधिकांश बायजेक्टर आज JAX में समर्थित हैं!

tfb.Exp().inverse(1.)
DeviceArray(0., dtype=float32)
bij = tfb.Shift(1.)(tfb.Scale(3.))
print(bij.forward(jnp.ones(5)))
print(bij.inverse(jnp.ones(5)))
[4. 4. 4. 4. 4.]
[0. 0. 0. 0. 0.]
b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
print(b.forward(x=[0., 0., 0.]))
print(b.inverse(y=[[1., 0], [.5, 2]]))
[[1. 0.]
 [0. 1.]]
[0.6931472 0.5       0.       ]
b = tfb.Chain([tfb.Exp(), tfb.Softplus()])
# or:
# b = tfb.Exp()(tfb.Softplus())
print(b.forward(-jnp.ones(5)))
[1.3678794 1.3678794 1.3678794 1.3678794 1.3678794]

Bijectors तरह JAX परिवर्तनों के साथ संगत कर रहे हैं jit , grad और vmap

jit(vmap(tfb.Exp().inverse))(jnp.arange(4.))
DeviceArray([     -inf, 0.       , 0.6931472, 1.0986123], dtype=float32)
x = jnp.linspace(0., 1., 100)
plt.plot(x, jit(grad(lambda x: vmap(tfb.Sigmoid().inverse)(x).sum()))(x))
plt.show()

पीएनजी

कुछ bijectors, जैसे RealNVP और FFJORD अभी तक समर्थित नहीं हैं।

एमसीएमसी

हम पोर्ट किया है tfp.mcmc रूप में अच्छी तरह JAX के लिए, तो हम Hamiltonian मोंटे कार्लो (एचएमसी) और JAX में नो-यू-टर्न-नमूना (पागल) की तरह एल्गोरिदम चला सकते हैं।

target_log_prob = tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)).log_prob

TFP TF पर विपरीत, हम एक पास करना आवश्यक है PRNGKey में sample_chain का उपयोग कर seed कीवर्ड तर्क।

def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-1)
  return tfp.mcmc.sample_chain(1000,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      seed=key)
states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros(2))
plt.figure()
plt.scatter(*states.T, alpha=0.5)
plt.figure()
plt.plot(log_probs)
plt.show()

पीएनजी

पीएनजी

कई चेन चलाने के लिए, हम या तो राज्यों का एक बैच में पारित कर सकते हैं sample_chain या उपयोग vmap (हालांकि हम अभी तक दो दृष्टिकोणों के बीच प्रदर्शन अंतर का पता लगाया नहीं किया है)।

states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros([10, 2]))
plt.figure()
for i in range(10):
  plt.scatter(*states[:, i].T, alpha=0.5)
plt.figure()
for i in range(10):
  plt.plot(log_probs[:, i], alpha=0.5)
plt.show()

पीएनजी

पीएनजी

अनुकूलक

जेएक्स पर टीएफपी बीएफजीएस और एल-बीएफजीएस जैसे कुछ महत्वपूर्ण अनुकूलकों का समर्थन करता है। आइए एक साधारण स्केल्ड द्विघात हानि फ़ंक्शन सेट करें।

minimum = jnp.array([1.0, 1.0])  # The center of the quadratic bowl.
scales = jnp.array([2.0, 3.0])  # The scales along the two axes.

# The objective function and the gradient.
def quadratic_loss(x):
  return jnp.sum(scales * jnp.square(x - minimum))

start = jnp.array([0.6, 0.8])  # Starting point for the search.

बीएफजीएस इस नुकसान का न्यूनतम पता लगा सकता है।

optim_results = tfp.optimizer.bfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

तो एल-बीएफजीएस कर सकते हैं।

optim_results = tfp.optimizer.lbfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

करने के लिए vmap एल BFGS, सेट एक समारोह है कि एक ही प्रारंभिक बिंदु के लिए नुकसान का अनुकूलन अप करते हैं।

def optimize_single(start):
  return tfp.optimizer.lbfgs_minimize(
      value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

all_results = jit(vmap(optimize_single))(
    random.normal(random.PRNGKey(0), (10, 2)))
assert all(all_results.converged)
for i in range(10):
  np.testing.assert_allclose(optim_results.position[i], minimum)
print("Function evaluations: %s" % all_results.num_objective_evaluations)
Function evaluations: [6 6 9 6 6 8 6 8 5 9]

चेतावनियां

TF और JAX के बीच कुछ मूलभूत अंतर हैं, कुछ TFP व्यवहार दो सबस्ट्रेट्स के बीच भिन्न होंगे और सभी कार्यक्षमता समर्थित नहीं हैं। उदाहरण के लिए,

  • TFP JAX पर ऐसा कुछ का समर्थन नहीं करता tf.Variable कुछ भी नहीं के बाद से है जैसे कि यह JAX में मौजूद है। यह भी मतलब है की तरह उपयोगिताओं tfp.util.TransformedVariable या तो समर्थित नहीं हैं।
  • tfp.layers पर Keras और अपनी निर्भरता की वजह से अभी तक बैकएंड में समर्थित नहीं है, tf.Variable रों।
  • tfp.math.minimize पर अपनी निर्भरता की वजह से JAX पर TFP में काम नहीं करता tf.Variable
  • JAX पर TFP के साथ, टेंसर आकार हमेशा ठोस पूर्णांक मान होते हैं और TF पर TFP की तरह कभी भी अज्ञात/गतिशील नहीं होते हैं।
  • छद्म यादृच्छिकता को TF और JAX (परिशिष्ट देखें) में अलग तरह से नियंत्रित किया जाता है।
  • में पुस्तकालय tfp.experimental JAX सब्सट्रेट में मौजूद गारंटी नहीं है।
  • TF और JAX के बीच Dtype पदोन्नति नियम भिन्न हैं। जेएक्स पर टीएफपी स्थिरता के लिए आंतरिक रूप से टीएफ के डीटाइप सेमेन्टिक्स का सम्मान करने का प्रयास करता है।
  • बिजेक्टर को अभी तक JAX pytrees के रूप में पंजीकृत नहीं किया गया है।

क्या JAX पर TFP में समर्थित है की पूरी सूची देखने के लिए, कृपया को देखें API दस्तावेज़

निष्कर्ष

हमने TFP की बहुत सी विशेषताओं को JAX में पोर्ट किया है और यह देखने के लिए उत्साहित हैं कि हर कोई क्या बनाएगा। कुछ कार्यक्षमता अभी तक समर्थित नहीं है; हम कुछ आप के लिए महत्वपूर्ण नहीं छूटा है अगर (या यदि आप एक बग मिल!) हमें से संपर्क करें - आप ईमेल कर सकते हैं tfprobability@tensorflow.org या पर एक मुद्दा फ़ाइल हमारे Github रेपो

परिशिष्ट: JAX . में छद्म यादृच्छिकता

JAX के कूट-यादृच्छिक संख्या पीढ़ी (PRNG) मॉडल राज्यविहीन है। एक स्टेटफुल मॉडल के विपरीत, कोई भी परिवर्तनशील वैश्विक स्थिति नहीं है जो प्रत्येक यादृच्छिक ड्रा के बाद विकसित होती है। JAX के मॉडल में, हम एक PRNG कुंजी है, जो 32-बिट पूर्णांकों की एक जोड़ी की तरह काम करता साथ शुरू करते हैं। हम का उपयोग करके इन कुंजियों का निर्माण कर सकते jax.random.PRNGKey

key = random.PRNGKey(0)  # Creates a key with value [0, 0]
print(key)
[0 0]

JAX में रैंडम कार्यों एक महत्वपूर्ण उपभोग करने के लिए निर्धारणात्मक एक यादृच्छिक variate उत्पादन, जिसका अर्थ है कि वे फिर से नहीं किया जाना चाहिए। उदाहरण के लिए, हम उपयोग कर सकते हैं key एक सामान्य रूप से वितरित मूल्य नमूने के लिए है, लेकिन हम उपयोग नहीं करना चाहिए key फिर कहीं और। इसके अलावा, में एक ही मूल्य गुजर random.normal एक ही मूल्य का उत्पादन करेगा।

print(random.normal(key))
-0.20584226

तो हम कभी भी एक ही कुंजी से कई नमूने कैसे खींच सकते हैं? उत्तर कुंजी बंटवारे है। मूल विचार है कि हम एक विभाजित कर सकते हैं है PRNGKey कई में, और नए चाबियों का प्रत्येक अनियमितता के एक स्वतंत्र स्रोत के रूप में इलाज किया जा सकता।

key1, key2 = random.split(key, num=2)
print(key1, key2)
[4146024105  967050713] [2718843009 1272950319]

कुंजी विभाजन नियतात्मक है लेकिन अराजक है, इसलिए प्रत्येक नई कुंजी का उपयोग अब एक अलग यादृच्छिक नमूना बनाने के लिए किया जा सकता है।

print(random.normal(key1), random.normal(key2))
0.14389051 -1.2515389

JAX के नियतात्मक कुंजी बंटवारे मॉडल के बारे में अधिक जानकारी के लिए, इस गाइड