ความน่าจะเป็นของ TensorFlow บน JAX

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

TensorFlow ความน่าจะเป็น (TFP) เป็นห้องสมุดสำหรับเหตุผลน่าจะเป็นและการวิเคราะห์ทางสถิติที่ตอนนี้ยังทำงานบน JAX ! สำหรับผู้ที่ไม่คุ้นเคย JAX เป็นไลบรารีสำหรับการประมวลผลเชิงตัวเลขแบบเร่งความเร็วโดยอิงจากการแปลงฟังก์ชันที่เขียนได้

TFP บน JAX รองรับการทำงานที่มีประโยชน์มากที่สุดของ TFP ปกติ ในขณะที่ยังคงรักษา abstractions และ API ที่ผู้ใช้ TFP จำนวนมากรู้สึกสบายใจ

ติดตั้ง

TFP ใน JAX ไม่ขึ้นอยู่กับ TensorFlow; ถอนการติดตั้ง TensorFlow จาก Colab นี้โดยสิ้นเชิง

pip uninstall tensorflow -y -q

เราสามารถติดตั้ง TFP บน JAX ด้วย TFP รุ่นล่าสุดทุกคืน

pip install -Uq tfp-nightly[jax] > /dev/null

มานำเข้าไลบรารี Python ที่มีประโยชน์กันเถอะ

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn import datasets
sns.set(style='white')
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

มานำเข้าฟังก์ชัน JAX พื้นฐานกันด้วย

import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import random
from jax import value_and_grad
from jax import vmap

การนำเข้า TFP บน JAX

ที่จะใช้ใน TFP JAX เพียงแค่นำเข้า jax "สารตั้งต้น" และใช้เป็นคุณมักจะ tfp :

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

การสาธิต: การถดถอยโลจิสติกแบบเบย์

เพื่อแสดงให้เห็นว่าเราสามารถทำอะไรกับแบ็กเอนด์ JAX ได้บ้าง เราจะนำการถดถอยโลจิสติกแบบเบย์มาใช้กับชุดข้อมูล Iris แบบคลาสสิก

ขั้นแรก ให้นำเข้าชุดข้อมูล Iris และแยกข้อมูลเมตาบางส่วน

iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)

เราสามารถกำหนดรูปแบบการใช้ tfd.JointDistributionCoroutine เราจะใส่ไพรเออร์แบบปกติมาตรฐานทั้งน้ำหนักและระยะอคติแล้วเขียน target_log_prob ฟังก์ชั่นที่หมุดป้ายตัวอย่างข้อมูล

Root = tfd.JointDistributionCoroutine.Root
def model():
  w = yield Root(tfd.Sample(tfd.Normal(0., 1.),
                            sample_shape=(num_features, num_classes)))
  b = yield Root(
      tfd.Sample(tfd.Normal(0., 1.), sample_shape=(num_classes,)))
  logits = jnp.dot(features, w) + b
  yield tfd.Independent(tfd.Categorical(logits=logits),
                        reinterpreted_batch_ndims=1)


dist = tfd.JointDistributionCoroutine(model)
def target_log_prob(*params):
  return dist.log_prob(params + (labels,))

เราตัวอย่างจาก dist การผลิตสถานะเริ่มต้นสำหรับ MCMC จากนั้นเราสามารถกำหนดฟังก์ชันที่ใช้คีย์สุ่มและสถานะเริ่มต้น และสร้างตัวอย่าง 500 รายการจาก No-U-Turn-Sampler (NUTS) โปรดทราบว่าเราสามารถใช้การแปลง JAX เช่น jit รวบรวมตัวอย่างถั่วของเราโดยใช้ XLA

init_key, sample_key = random.split(random.PRNGKey(0))
init_params = tuple(dist.sample(seed=init_key)[:-1])

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      num_burnin_steps=500,
      seed=key)

states, log_probs = run_chain(sample_key, init_params)
plt.figure()
plt.plot(log_probs)
plt.ylabel('Target Log Prob')
plt.xlabel('Iterations of NUTS')
plt.show()

png

ลองใช้ตัวอย่างของเราในการหาค่าเฉลี่ยแบบจำลองเบย์ (BMA) โดยหาค่าเฉลี่ยความน่าจะเป็นที่คาดการณ์ไว้ของน้ำหนักแต่ละชุด

ขั้นแรก ให้เขียนฟังก์ชันที่สำหรับชุดของพารามิเตอร์ที่กำหนดจะทำให้เกิดความน่าจะเป็นในแต่ละคลาส เราสามารถใช้ dist.sample_distributions ที่จะได้รับการกระจายสุดท้ายในรูปแบบ

def classifier_probs(params):
  dists, _ = dist.sample_distributions(seed=random.PRNGKey(0),
                                       value=params + (None,))
  return dists[-1].distribution.probs_parameter()

เราสามารถ vmap(classifier_probs) มากกว่าชุดของตัวอย่างที่จะได้รับความน่าจะเป็นระดับที่คาดการณ์ไว้สำหรับแต่ละตัวอย่างของเรา จากนั้นเราจะคำนวณความแม่นยำเฉลี่ยในแต่ละตัวอย่าง และความแม่นยำจากค่าเฉลี่ยของแบบจำลองเบย์

all_probs = jit(vmap(classifier_probs))(states)
print('Average accuracy:', jnp.mean(all_probs.argmax(axis=-1) == labels))
print('BMA accuracy:', jnp.mean(all_probs.mean(axis=0).argmax(axis=-1) == labels))
Average accuracy: 0.96952
BMA accuracy: 0.97999996

ดูเหมือนว่า BMA จะลดอัตราความผิดพลาดของเราลงเกือบหนึ่งในสาม!

พื้นฐาน

TFP ใน JAX มี API เหมือนกับ TF ที่แทนการรับวัตถุ TF เช่น tf.Tensor s ยอมรับอนาล็อก JAX ยกตัวอย่างเช่นใดก็ตาม tf.Tensor ถูกนำมาใช้ก่อนหน้านี้เป็น input API ในขณะนี้คาดว่าจะมี JAX DeviceArray แทนที่จะกลับ tf.Tensor วิธี TFP จะกลับ DeviceArray s TFP ใน JAX ยังทำงานร่วมกับโครงสร้างที่ซ้อนกันของวัตถุ JAX เช่นรายการหรือพจนานุกรมของ DeviceArray s

การกระจาย

การแจกแจงของ TFP ส่วนใหญ่ได้รับการสนับสนุนใน JAX โดยมีความหมายที่คล้ายคลึงกันมากกับคู่ของ TF พวกเขายังได้รับการจดทะเบียนเป็น JAX Pytrees เพื่อให้พวกเขาสามารถเป็นปัจจัยการผลิตและผลของฟังก์ชั่น JAX-เปลี่ยน

การแจกแจงพื้นฐาน

log_prob วิธีการสำหรับการกระจายการทำงานเดียวกัน

dist = tfd.Normal(0., 1.)
print(dist.log_prob(0.))
-0.9189385

การสุ่มตัวอย่างจากการกระจายต้องใช้อย่างชัดเจนผ่านใน PRNGKey (หรือรายการของจำนวนเต็ม) ในขณะที่ seed โต้แย้งคำหลัก การไม่ส่งผ่านเมล็ดพันธุ์อย่างชัดเจนจะทำให้เกิดข้อผิดพลาด

tfd.Normal(0., 1.).sample(seed=random.PRNGKey(0))
DeviceArray(-0.20584226, dtype=float32)

ความหมายรูปร่างสำหรับการกระจายยังคงเหมือนเดิมใน JAX ที่กระจายแต่ละคนจะมี event_shape และ batch_shape และการวาดภาพตัวอย่างจำนวนมากจะเพิ่มเพิ่มเติม sample_shape มิติ

ยกตัวอย่างเช่น tfd.MultivariateNormalDiag กับพารามิเตอร์เวกเตอร์จะมีรูปร่างเหตุการณ์เวกเตอร์และรูปทรงชุดที่ว่างเปล่า

dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros(5),
    scale_diag=jnp.ones(5)
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: (5,)
Batch shape: ()

บนมืออื่น ๆ ที่เป็น tfd.Normal แปรกับพาหะจะมีรูปร่างเหตุการณ์สเกลาร์และเวกเตอร์ชุดรูปร่าง

dist = tfd.Normal(
    loc=jnp.ones(5),
    scale=jnp.ones(5),
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: ()
Batch shape: (5,)

ความหมายของการ log_prob ตัวอย่างการทำงานเดียวกันใน JAX เกินไป

dist =  tfd.Normal(jnp.zeros(5), jnp.ones(5))
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)

dist =  tfd.Independent(tfd.Normal(jnp.zeros(5), jnp.ones(5)), 1)
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)
(10, 2, 5)
(10, 2)

เพราะ JAX DeviceArray s เข้ากันได้กับห้องสมุดเช่น NumPy และ Matplotlib เราสามารถให้อาหารตัวอย่างโดยตรงในฟังก์ชั่นการวางแผน

sns.distplot(tfd.Normal(0., 1.).sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

Distribution วิธีการเข้ากันได้กับการเปลี่ยนแปลง JAX

sns.distplot(jit(vmap(lambda key: tfd.Normal(0., 1.).sample(seed=key)))(
    random.split(random.PRNGKey(0), 2000)))
plt.show()

png

x = jnp.linspace(-5., 5., 100)
plt.plot(x, jit(vmap(grad(tfd.Normal(0., 1.).prob)))(x))
plt.show()

png

เพราะการกระจาย TFP มีการจดทะเบียนเป็นโหนด pytree JAX เราสามารถเขียนฟังก์ชั่นที่มีการกระจายเป็นปัจจัยการผลิตหรือเอาท์พุทและเปลี่ยนพวกเขาโดยใช้ jit แต่พวกเขายังไม่ได้รับการสนับสนุนเป็นข้อโต้แย้งที่จะ vmap ฟังก์ชั่น -ed

@jit
def random_distribution(key):
  loc_key, scale_key = random.split(key)
  loc, log_scale = random.normal(loc_key), random.normal(scale_key)
  return tfd.Normal(loc, jnp.exp(log_scale))
random_dist = random_distribution(random.PRNGKey(0))
print(random_dist.mean(), random_dist.variance())
0.14389051 0.081832744

การกระจายแบบแปลงร่าง

กระจายเปลี่ยนคือการกระจายตัวอย่างซึ่งจะผ่าน Bijector ยังทำงานออกจากกล่อง (bijectors ทำงานมากเกินไป! ดูด้านล่าง)

dist = tfd.TransformedDistribution(
    tfd.Normal(0., 1.),
    tfb.Sigmoid()
)
sns.distplot(dist.sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

การกระจายร่วม

TFP มี JointDistribution ที่จะช่วยให้การรวมแจกแจงองค์ประกอบเป็นกระจายเดียวมากกว่าหลายตัวแปรสุ่ม ปัจจุบัน TFP ข้อเสนอสามสายพันธุ์หลัก ( JointDistributionSequential , JointDistributionNamed และ JointDistributionCoroutine ) ซึ่งทั้งหมดได้รับการสนับสนุนใน JAX AutoBatched สายพันธุ์นอกจากนี้ยังได้รับการสนับสนุนทั้งหมด

dist = tfd.JointDistributionSequential([
  tfd.Normal(0., 1.),
  lambda x: tfd.Normal(x, 1e-1)
])
plt.scatter(*dist.sample(1000, seed=random.PRNGKey(0)), alpha=0.5)
plt.show()

png

joint = tfd.JointDistributionNamed(dict(
    e=             tfd.Exponential(rate=1.),
    n=             tfd.Normal(loc=0., scale=2.),
    m=lambda n, e: tfd.Normal(loc=n, scale=e),
    x=lambda    m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
))
joint.sample(seed=random.PRNGKey(0))
{'e': DeviceArray(3.376818, dtype=float32),
 'm': DeviceArray(2.5449684, dtype=float32),
 'n': DeviceArray(-0.6027825, dtype=float32),
 'x': DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
Root = tfd.JointDistributionCoroutine.Root
def model():
  e = yield Root(tfd.Exponential(rate=1.))
  n = yield Root(tfd.Normal(loc=0, scale=2.))
  m = yield tfd.Normal(loc=n, scale=e)
  x = yield tfd.Sample(tfd.Bernoulli(logits=m), 12)

joint = tfd.JointDistributionCoroutine(model)

joint.sample(seed=random.PRNGKey(0))
StructTuple(var0=DeviceArray(0.17315261, dtype=float32), var1=DeviceArray(-3.290489, dtype=float32), var2=DeviceArray(-3.1949058, dtype=float32), var3=DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))

การกระจายอื่น ๆ

กระบวนการเกาส์เซียนยังทำงานในโหมด JAX!

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
observation_noise_variance = 0.01
f = lambda x: jnp.sin(10*x[..., 0]) * jnp.exp(-x[..., 0]**2)
observation_index_points = random.uniform(
    k1, [50], minval=-1.,maxval= 1.)[..., jnp.newaxis]
observations = f(observation_index_points) + tfd.Normal(
    loc=0., scale=jnp.sqrt(observation_noise_variance)).sample(seed=k2)

index_points = jnp.linspace(-1., 1., 100)[..., jnp.newaxis]

kernel = tfpk.ExponentiatedQuadratic(length_scale=0.1)

gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=index_points,
    observation_index_points=observation_index_points,
    observations=observations,
    observation_noise_variance=observation_noise_variance)

samples = gprm.sample(10, seed=k3)
for i in range(10):
  plt.plot(index_points, samples[i], alpha=0.5)
plt.plot(observation_index_points, observations, marker='o', linestyle='')
plt.show()

png

รองรับโมเดล Markov ที่ซ่อนอยู่

initial_distribution = tfd.Categorical(probs=[0.8, 0.2])
transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                 [0.2, 0.8]])

observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

model = tfd.HiddenMarkovModel(
    initial_distribution=initial_distribution,
    transition_distribution=transition_distribution,
    observation_distribution=observation_distribution,
    num_steps=7)

print(model.mean())
print(model.log_prob(jnp.zeros(7)))
print(model.sample(seed=random.PRNGKey(0)))
[3.       6.       7.5      8.249999 8.625001 8.812501 8.90625 ]
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/substrates/jax/distributions/hidden_markov_model.py:483: UserWarning: HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug in which the transition model was applied prior to the initial step. This bug has been fixed. You may observe a slight change in behavior.
  'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
-19.855635
[ 1.3641367  0.505798   1.3626463  3.6541772  2.272286  15.10309
 22.794212 ]

กระจายน้อยเช่น PixelCNN ยังไม่สนับสนุนเนื่องจากการอ้างอิงที่เข้มงวดเกี่ยวกับ TensorFlow หรือ XLA กันไม่ได้

Bijectors

bijectors ของ TFP ส่วนใหญ่รองรับ JAX แล้ววันนี้!

tfb.Exp().inverse(1.)
DeviceArray(0., dtype=float32)
bij = tfb.Shift(1.)(tfb.Scale(3.))
print(bij.forward(jnp.ones(5)))
print(bij.inverse(jnp.ones(5)))
[4. 4. 4. 4. 4.]
[0. 0. 0. 0. 0.]
b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
print(b.forward(x=[0., 0., 0.]))
print(b.inverse(y=[[1., 0], [.5, 2]]))
[[1. 0.]
 [0. 1.]]
[0.6931472 0.5       0.       ]
b = tfb.Chain([tfb.Exp(), tfb.Softplus()])
# or:
# b = tfb.Exp()(tfb.Softplus())
print(b.forward(-jnp.ones(5)))
[1.3678794 1.3678794 1.3678794 1.3678794 1.3678794]

Bijectors เข้ากันได้กับการเปลี่ยนแปลง JAX เช่น jit , grad และ vmap

jit(vmap(tfb.Exp().inverse))(jnp.arange(4.))
DeviceArray([     -inf, 0.       , 0.6931472, 1.0986123], dtype=float32)
x = jnp.linspace(0., 1., 100)
plt.plot(x, jit(grad(lambda x: vmap(tfb.Sigmoid().inverse)(x).sum()))(x))
plt.show()

png

bijectors บางอย่างเช่น RealNVP และ FFJORD ยังไม่สนับสนุน

MCMC

เราได้รังเพลิง tfp.mcmc เพื่อ JAX เป็นอย่างดีเพื่อให้เราสามารถเรียกใช้ขั้นตอนวิธีการเช่นมิล Monte Carlo (HMC) และ No-U-Turn-Sampler (ถั่ว) ใน JAX

target_log_prob = tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)).log_prob

ซึ่งแตกต่างจาก TFP ใน TF เราจะต้องผ่านการ PRNGKey เข้า sample_chain ใช้ seed โต้แย้งคำหลัก

def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-1)
  return tfp.mcmc.sample_chain(1000,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      seed=key)
states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros(2))
plt.figure()
plt.scatter(*states.T, alpha=0.5)
plt.figure()
plt.plot(log_probs)
plt.show()

png

png

เมื่อต้องการเรียกใช้โซ่หลายรายการเราทั้งสามารถส่งผ่านชุดของรัฐเข้าไป sample_chain หรือการใช้ vmap (แม้ว่าเรายังไม่ได้สำรวจความแตกต่างของผลการดำเนินงานระหว่างทั้งสองวิธี)

states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros([10, 2]))
plt.figure()
for i in range(10):
  plt.scatter(*states[:, i].T, alpha=0.5)
plt.figure()
for i in range(10):
  plt.plot(log_probs[:, i], alpha=0.5)
plt.show()

png

png

เครื่องมือเพิ่มประสิทธิภาพ

TFP บน JAX รองรับเครื่องมือเพิ่มประสิทธิภาพที่สำคัญบางตัว เช่น BFGS และ L-BFGS มาตั้งค่าฟังก์ชันการสูญเสียกำลังสองที่ปรับขนาดอย่างง่ายกัน

minimum = jnp.array([1.0, 1.0])  # The center of the quadratic bowl.
scales = jnp.array([2.0, 3.0])  # The scales along the two axes.

# The objective function and the gradient.
def quadratic_loss(x):
  return jnp.sum(scales * jnp.square(x - minimum))

start = jnp.array([0.6, 0.8])  # Starting point for the search.

BFGS สามารถค้นหาความสูญเสียขั้นต่ำนี้ได้

optim_results = tfp.optimizer.bfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

L-BFGS ก็ทำได้เช่นกัน

optim_results = tfp.optimizer.lbfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

เพื่อ vmap L-BFGS ให้ชุดของฟังก์ชั่นที่เพิ่มประสิทธิภาพการขาดทุนสำหรับจุดเริ่มต้นเดียว

def optimize_single(start):
  return tfp.optimizer.lbfgs_minimize(
      value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

all_results = jit(vmap(optimize_single))(
    random.normal(random.PRNGKey(0), (10, 2)))
assert all(all_results.converged)
for i in range(10):
  np.testing.assert_allclose(optim_results.position[i], minimum)
print("Function evaluations: %s" % all_results.num_objective_evaluations)
Function evaluations: [6 6 9 6 6 8 6 8 5 9]

คำเตือน

มีความแตกต่างพื้นฐานบางอย่างระหว่าง TF และ JAX ลักษณะการทำงานของ TFP บางอย่างจะแตกต่างกันระหว่างพื้นผิวทั้งสอง และไม่รองรับฟังก์ชันการทำงานทั้งหมด ตัวอย่างเช่น,

  • TFP ใน JAX ไม่สนับสนุนอะไรเช่น tf.Variable ตั้งแต่ไม่มีอะไรเหมือนมันมีอยู่ใน JAX นอกจากนี้ยังหมายความสาธารณูปโภคเช่น tfp.util.TransformedVariable ยังไม่ได้รับการสนับสนุนอย่างใดอย่างหนึ่ง
  • tfp.layers ไม่ได้รับการสนับสนุนในส่วนหลัง ๆ เนื่องจากการพึ่งพา Keras และ tf.Variable s
  • tfp.math.minimize ไม่ทำงานใน TFP ใน JAX เพราะการพึ่งพา tf.Variable
  • ด้วย TFP บน JAX รูปร่างเทนเซอร์จะเป็นค่าจำนวนเต็มที่เป็นรูปธรรมเสมอ และจะไม่เป็นที่รู้จัก/เป็นไดนามิกเหมือนใน TFP บน TF
  • การสุ่มหลอกได้รับการจัดการแตกต่างกันใน TF และ JAX (ดูภาคผนวก)
  • ห้องสมุดใน tfp.experimental จะไม่รับประกันว่าจะมีอยู่ในสารตั้งต้น JAX
  • กฎการเลื่อนระดับ Dtype นั้นแตกต่างกันระหว่าง TF และ JAX TFP บน JAX พยายามเคารพความหมาย dtype ของ TF ภายในเพื่อความสอดคล้อง
  • Bijectors ยังไม่ได้ลงทะเบียนเป็น JAX pytrees

หากต้องการดูรายการที่สมบูรณ์ของสิ่งที่ได้รับการสนับสนุนใน TFP ใน JAX, โปรดดูที่ เอกสาร API

บทสรุป

เราได้ย้ายคุณสมบัติมากมายของ TFP ไปยัง JAX และรู้สึกตื่นเต้นที่จะได้เห็นสิ่งที่ทุกคนจะสร้าง ฟังก์ชันบางอย่างยังไม่ได้รับการสนับสนุน ถ้าเราได้พลาดสิ่งที่สำคัญกับคุณ (หรือถ้าคุณพบข้อผิดพลาด!) โปรดติดต่อเรา - คุณสามารถส่งอีเมล tfprobability@tensorflow.org หรือไฟล์ปัญหาใน repo Github ของเรา

ภาคผนวก: การสุ่มเทียมใน JAX

JAX หมายเลข pseudorandom รุ่น (PRNG) รุ่นไร้สัญชาติ ต่างจากโมเดล stateful ไม่มีสถานะโกลบอลที่ไม่แน่นอนที่วิวัฒนาการหลังจากการสุ่มแต่ละครั้ง ในรูปแบบ JAX ของเราเริ่มต้นด้วยคีย์ PRNG ซึ่งทำหน้าที่เหมือนคู่ของจำนวนเต็ม 32 บิต เราสามารถสร้างปุ่มเหล่านี้โดยใช้ jax.random.PRNGKey

key = random.PRNGKey(0)  # Creates a key with value [0, 0]
print(key)
[0 0]

ฟังก์ชั่นสุ่มใน JAX ใช้กุญแจสำคัญในการผลิต deterministically ตัวแปรสุ่มหมายถึงพวกเขาไม่ควรนำมาใช้อีกครั้ง ตัวอย่างเช่นเราสามารถใช้ key ที่จะลิ้มลองค่าการกระจายตามปกติ แต่เราไม่ควรใช้ key อีกครั้งอื่น ๆ นอกจากนี้การส่งผ่านค่าเดียวกันใน random.normal จะผลิตค่าเดียวกัน

print(random.normal(key))
-0.20584226

แล้วเราจะวาดตัวอย่างหลายตัวอย่างจากคีย์เดียวได้อย่างไร คำตอบคือแยกที่สำคัญ ความคิดพื้นฐานคือการที่เราสามารถแยก PRNGKey ออกเป็นหลาย ๆ และแต่ละคีย์ใหม่สามารถรักษาได้เป็นแหล่งที่เป็นอิสระจากการสุ่ม

key1, key2 = random.split(key, num=2)
print(key1, key2)
[4146024105  967050713] [2718843009 1272950319]

การแยกคีย์เป็นสิ่งที่กำหนดได้ แต่มีความโกลาหล ดังนั้นตอนนี้แต่ละคีย์ใหม่สามารถใช้เพื่อสุ่มตัวอย่างที่แตกต่างกันได้

print(random.normal(key1), random.normal(key2))
0.14389051 -1.2515389

สำหรับรายละเอียดเพิ่มเติมเกี่ยวกับรูปแบบที่สำคัญแยก JAX ของกำหนดดู คำแนะนำนี้