Probabilité TensorFlow sur JAX

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Tensorflow Probabilité (TFP) est une bibliothèque pour le raisonnement probabiliste et l' analyse statistique qui fonctionne également sur JAX ! Pour ceux qui ne sont pas familiers, JAX est une bibliothèque de calcul numérique accéléré basée sur des transformations de fonctions composables.

TFP sur JAX prend en charge de nombreuses fonctionnalités les plus utiles de TFP standard tout en préservant les abstractions et les API avec lesquelles de nombreux utilisateurs de TFP sont désormais à l'aise.

Installer

TFP sur JAX ne dépend pas de tensorflow; désinstallons entièrement TensorFlow de ce Colab.

pip uninstall tensorflow -y -q

Nous pouvons installer TFP sur JAX avec les dernières versions nocturnes de TFP.

pip install -Uq tfp-nightly[jax] > /dev/null

Importons quelques bibliothèques Python utiles.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn import datasets
sns.set(style='white')
/usr/local/lib/python3.6/dist-packages/statsmodels/tools/_testing.py:19: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.
  import pandas.util.testing as tm

Importons également quelques fonctionnalités JAX de base.

import jax.numpy as jnp
from jax import grad
from jax import jit
from jax import random
from jax import value_and_grad
from jax import vmap

Importation de TFP sur JAX

Pour utiliser TFP sur JAX, il suffit d' importer le jax « substrat » et l' utiliser comme vous le feriez habituellement tfp :

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

Démo : Régression logistique bayésienne

Pour démontrer ce que nous pouvons faire avec le backend JAX, nous allons implémenter une régression logistique bayésienne appliquée à l'ensemble de données Iris classique.

Tout d'abord, importons l'ensemble de données Iris et extrayons quelques métadonnées.

iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)

Nous pouvons définir le modèle en utilisant tfd.JointDistributionCoroutine . Nous allons mettre prieurs normale standard sur les poids et le terme de polarisation , puis écrire une target_log_prob fonction que les broches étiquettes échantillonnées aux données.

Root = tfd.JointDistributionCoroutine.Root
def model():
  w = yield Root(tfd.Sample(tfd.Normal(0., 1.),
                            sample_shape=(num_features, num_classes)))
  b = yield Root(
      tfd.Sample(tfd.Normal(0., 1.), sample_shape=(num_classes,)))
  logits = jnp.dot(features, w) + b
  yield tfd.Independent(tfd.Categorical(logits=logits),
                        reinterpreted_batch_ndims=1)


dist = tfd.JointDistributionCoroutine(model)
def target_log_prob(*params):
  return dist.log_prob(params + (labels,))

Nous prélevons de dist pour produire un état initial pour MCMC. Nous pouvons alors définir une fonction qui prend une clé aléatoire et un état initial, et produit 500 échantillons à partir d'un No-U-Turn-Sampler (NUTS). Notez que nous pouvons utiliser des transformations JAX comme jit compiler notre NUTS sampler en utilisant XLA.

init_key, sample_key = random.split(random.PRNGKey(0))
init_params = tuple(dist.sample(seed=init_key)[:-1])

@jit
def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-3)
  return tfp.mcmc.sample_chain(500,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      num_burnin_steps=500,
      seed=key)

states, log_probs = run_chain(sample_key, init_params)
plt.figure()
plt.plot(log_probs)
plt.ylabel('Target Log Prob')
plt.xlabel('Iterations of NUTS')
plt.show()

png

Utilisons nos échantillons pour effectuer une moyenne du modèle bayésien (BMA) en faisant la moyenne des probabilités prédites de chaque ensemble de poids.

Écrivons d'abord une fonction qui, pour un ensemble donné de paramètres, produira les probabilités sur chaque classe. Nous pouvons utiliser dist.sample_distributions pour obtenir la distribution finale dans le modèle.

def classifier_probs(params):
  dists, _ = dist.sample_distributions(seed=random.PRNGKey(0),
                                       value=params + (None,))
  return dists[-1].distribution.probs_parameter()

Nous pouvons vmap(classifier_probs) sur l'ensemble des échantillons pour obtenir les probabilités de classe prévues pour chacun de nos échantillons. Nous calculons ensuite la précision moyenne sur chaque échantillon et la précision à partir de la moyenne du modèle bayésien.

all_probs = jit(vmap(classifier_probs))(states)
print('Average accuracy:', jnp.mean(all_probs.argmax(axis=-1) == labels))
print('BMA accuracy:', jnp.mean(all_probs.mean(axis=0).argmax(axis=-1) == labels))
Average accuracy: 0.96952
BMA accuracy: 0.97999996

On dirait que BMA réduit notre taux d'erreur de près d'un tiers !

Fondamentaux

TFP sur JAX a une API identique à TF où au lieu d'accepter des objets TF comme tf.Tensor s il accepte l'analogue de JAX. Par exemple, chaque fois qu'un tf.Tensor était auparavant utilisé comme entrée, l'API attend maintenant un JAX DeviceArray . Au lieu de retourner un tf.Tensor , les méthodes TFP retourneront DeviceArray s. TFP sur JAX travaille également avec des structures imbriquées d'objets JAX, comme une liste ou un dictionnaire de DeviceArray s.

Répartition

La plupart des distributions de TFP sont supportées en JAX avec une sémantique très similaire à leurs homologues de TF. Ils sont également inscrits comme JAX Pytrees , afin qu'ils puissent être entrées et sorties des fonctions transformées JAX.

Répartitions de base

La log_prob méthode de distribution fonctionne de la même.

dist = tfd.Normal(0., 1.)
print(dist.log_prob(0.))
-0.9189385

L' échantillonnage d'une distribution nécessite le passage explicitement dans une PRNGKey (ou une liste d'entiers) comme seed argument mot - clé. Ne pas transmettre explicitement une graine générera une erreur.

tfd.Normal(0., 1.).sample(seed=random.PRNGKey(0))
DeviceArray(-0.20584226, dtype=float32)

La sémantique de forme pour les distributions restent les mêmes dans JAX, où les distributions auront chacun un event_shape et un batch_shape et dessin de nombreux échantillons ajoutera supplémentaires sample_shape dimensions.

Par exemple, un tfd.MultivariateNormalDiag avec des paramètres de vecteur aura une forme d'événement vecteur et la forme de lot vide.

dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros(5),
    scale_diag=jnp.ones(5)
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: (5,)
Batch shape: ()

D'autre part, un tfd.Normal paramétrés avec des vecteurs aura une forme d'événement scalaire et vecteur forme lot.

dist = tfd.Normal(
    loc=jnp.ones(5),
    scale=jnp.ones(5),
)
print('Event shape:', dist.event_shape)
print('Batch shape:', dist.batch_shape)
Event shape: ()
Batch shape: (5,)

La sémantique de la prise log_prob des échantillons fonctionne de la même dans JAX aussi.

dist =  tfd.Normal(jnp.zeros(5), jnp.ones(5))
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)

dist =  tfd.Independent(tfd.Normal(jnp.zeros(5), jnp.ones(5)), 1)
s = dist.sample(sample_shape=(10, 2), seed=random.PRNGKey(0))
print(dist.log_prob(s).shape)
(10, 2, 5)
(10, 2)

Parce que JAX DeviceArray s sont compatibles avec les bibliothèques comme NumPy et Matplotlib, nous pouvons nourrir des échantillons directement dans une fonction de traçage.

sns.distplot(tfd.Normal(0., 1.).sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

Distribution méthodes sont compatibles avec les transformations de JAX.

sns.distplot(jit(vmap(lambda key: tfd.Normal(0., 1.).sample(seed=key)))(
    random.split(random.PRNGKey(0), 2000)))
plt.show()

png

x = jnp.linspace(-5., 5., 100)
plt.plot(x, jit(vmap(grad(tfd.Normal(0., 1.).prob)))(x))
plt.show()

png

Parce que les distributions TFP sont enregistrées en tant que JAX nœuds de pytree, nous pouvons écrire des fonctions avec des distributions comme entrées ou sorties et de les transformer en utilisant jit , mais ils ne sont pas encore pris en charge comme arguments pour vmap fonctions -ed.

@jit
def random_distribution(key):
  loc_key, scale_key = random.split(key)
  loc, log_scale = random.normal(loc_key), random.normal(scale_key)
  return tfd.Normal(loc, jnp.exp(log_scale))
random_dist = random_distribution(random.PRNGKey(0))
print(random_dist.mean(), random_dist.variance())
0.14389051 0.081832744

Distributions transformées

Distributions distributions dont Transformé -à- dire les échantillons sont passés à travers un Bijector travaillent également hors de la boîte (bijectors trop de travail! Voir ci - dessous).

dist = tfd.TransformedDistribution(
    tfd.Normal(0., 1.),
    tfb.Sigmoid()
)
sns.distplot(dist.sample(1000, seed=random.PRNGKey(0)))
plt.show()

png

Distributions conjointes

TFP offre JointDistribution s pour permettre la combinaison des distributions de composants dans une distribution unique sur plusieurs variables aléatoires. À l' heure actuelle, TFP propose trois variantes de base ( JointDistributionSequential , JointDistributionNamed et JointDistributionCoroutine ) qui sont tous pris en charge par JAX. Les AutoBatched variantes sont tous pris en charge.

dist = tfd.JointDistributionSequential([
  tfd.Normal(0., 1.),
  lambda x: tfd.Normal(x, 1e-1)
])
plt.scatter(*dist.sample(1000, seed=random.PRNGKey(0)), alpha=0.5)
plt.show()

png

joint = tfd.JointDistributionNamed(dict(
    e=             tfd.Exponential(rate=1.),
    n=             tfd.Normal(loc=0., scale=2.),
    m=lambda n, e: tfd.Normal(loc=n, scale=e),
    x=lambda    m: tfd.Sample(tfd.Bernoulli(logits=m), 12),
))
joint.sample(seed=random.PRNGKey(0))
{'e': DeviceArray(3.376818, dtype=float32),
 'm': DeviceArray(2.5449684, dtype=float32),
 'n': DeviceArray(-0.6027825, dtype=float32),
 'x': DeviceArray([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)}
Root = tfd.JointDistributionCoroutine.Root
def model():
  e = yield Root(tfd.Exponential(rate=1.))
  n = yield Root(tfd.Normal(loc=0, scale=2.))
  m = yield tfd.Normal(loc=n, scale=e)
  x = yield tfd.Sample(tfd.Bernoulli(logits=m), 12)

joint = tfd.JointDistributionCoroutine(model)

joint.sample(seed=random.PRNGKey(0))
StructTuple(var0=DeviceArray(0.17315261, dtype=float32), var1=DeviceArray(-3.290489, dtype=float32), var2=DeviceArray(-3.1949058, dtype=float32), var3=DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))

Autres répartitions

Les processus gaussiens fonctionnent également en mode JAX !

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
observation_noise_variance = 0.01
f = lambda x: jnp.sin(10*x[..., 0]) * jnp.exp(-x[..., 0]**2)
observation_index_points = random.uniform(
    k1, [50], minval=-1.,maxval= 1.)[..., jnp.newaxis]
observations = f(observation_index_points) + tfd.Normal(
    loc=0., scale=jnp.sqrt(observation_noise_variance)).sample(seed=k2)

index_points = jnp.linspace(-1., 1., 100)[..., jnp.newaxis]

kernel = tfpk.ExponentiatedQuadratic(length_scale=0.1)

gprm = tfd.GaussianProcessRegressionModel(
    kernel=kernel,
    index_points=index_points,
    observation_index_points=observation_index_points,
    observations=observations,
    observation_noise_variance=observation_noise_variance)

samples = gprm.sample(10, seed=k3)
for i in range(10):
  plt.plot(index_points, samples[i], alpha=0.5)
plt.plot(observation_index_points, observations, marker='o', linestyle='')
plt.show()

png

Les modèles de Markov cachés sont également pris en charge.

initial_distribution = tfd.Categorical(probs=[0.8, 0.2])
transition_distribution = tfd.Categorical(probs=[[0.7, 0.3],
                                                 [0.2, 0.8]])

observation_distribution = tfd.Normal(loc=[0., 15.], scale=[5., 10.])

model = tfd.HiddenMarkovModel(
    initial_distribution=initial_distribution,
    transition_distribution=transition_distribution,
    observation_distribution=observation_distribution,
    num_steps=7)

print(model.mean())
print(model.log_prob(jnp.zeros(7)))
print(model.sample(seed=random.PRNGKey(0)))
[3.       6.       7.5      8.249999 8.625001 8.812501 8.90625 ]
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/substrates/jax/distributions/hidden_markov_model.py:483: UserWarning: HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug in which the transition model was applied prior to the initial step. This bug has been fixed. You may observe a slight change in behavior.
  'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
-19.855635
[ 1.3641367  0.505798   1.3626463  3.6541772  2.272286  15.10309
 22.794212 ]

Quelques distributions comme PixelCNN ne sont pas encore pris en charge en raison de dépendances strictes sur les incompatibilités tensorflow ou XLA.

Bijecteurs

La plupart des bijecteurs de TFP sont supportés en JAX aujourd'hui !

tfb.Exp().inverse(1.)
DeviceArray(0., dtype=float32)
bij = tfb.Shift(1.)(tfb.Scale(3.))
print(bij.forward(jnp.ones(5)))
print(bij.inverse(jnp.ones(5)))
[4. 4. 4. 4. 4.]
[0. 0. 0. 0. 0.]
b = tfb.FillScaleTriL(diag_bijector=tfb.Exp(), diag_shift=None)
print(b.forward(x=[0., 0., 0.]))
print(b.inverse(y=[[1., 0], [.5, 2]]))
[[1. 0.]
 [0. 1.]]
[0.6931472 0.5       0.       ]
b = tfb.Chain([tfb.Exp(), tfb.Softplus()])
# or:
# b = tfb.Exp()(tfb.Softplus())
print(b.forward(-jnp.ones(5)))
[1.3678794 1.3678794 1.3678794 1.3678794 1.3678794]

Bijectors sont compatibles avec les transformations de JAX comme jit , grad et vmap .

jit(vmap(tfb.Exp().inverse))(jnp.arange(4.))
DeviceArray([     -inf, 0.       , 0.6931472, 1.0986123], dtype=float32)
x = jnp.linspace(0., 1., 100)
plt.plot(x, jit(grad(lambda x: vmap(tfb.Sigmoid().inverse)(x).sum()))(x))
plt.show()

png

Certains bijectors, comme RealNVP et FFJORD ne sont pas encore pris en charge.

MCMC

Nous avons porté tfp.mcmc à JAX aussi bien, afin que nous puissions exécuter des algorithmes comme hamiltonien Monte Carlo (HMC) et le No-U-Turn-Sampler (NUTS) en JAX.

target_log_prob = tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)).log_prob

Contrairement à TFP sur TF, nous sommes tenus de passer un PRNGKey en sample_chain en utilisant la seed argument mot - clé.

def run_chain(key, state):
  kernel = tfp.mcmc.NoUTurnSampler(target_log_prob, 1e-1)
  return tfp.mcmc.sample_chain(1000,
      current_state=state,
      kernel=kernel,
      trace_fn=lambda _, results: results.target_log_prob,
      seed=key)
states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros(2))
plt.figure()
plt.scatter(*states.T, alpha=0.5)
plt.figure()
plt.plot(log_probs)
plt.show()

png

png

Pour exécuter plusieurs chaînes, nous pouvons soit passer un lot d'états dans sample_chain ou l' utilisation vmap (bien que nous n'avons pas encore exploré les différences de performance entre les deux approches).

states, log_probs = jit(run_chain)(random.PRNGKey(0), jnp.zeros([10, 2]))
plt.figure()
for i in range(10):
  plt.scatter(*states[:, i].T, alpha=0.5)
plt.figure()
for i in range(10):
  plt.plot(log_probs[:, i], alpha=0.5)
plt.show()

png

png

Optimiseurs

TFP sur JAX prend en charge certains optimiseurs importants tels que BFGS et L-BFGS. Mettons en place une simple fonction de perte quadratique mise à l'échelle.

minimum = jnp.array([1.0, 1.0])  # The center of the quadratic bowl.
scales = jnp.array([2.0, 3.0])  # The scales along the two axes.

# The objective function and the gradient.
def quadratic_loss(x):
  return jnp.sum(scales * jnp.square(x - minimum))

start = jnp.array([0.6, 0.8])  # Starting point for the search.

BFGS peut trouver le minimum de cette perte.

optim_results = tfp.optimizer.bfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

Le L-BFGS aussi.

optim_results = tfp.optimizer.lbfgs_minimize(
    value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

# Check that the search converged
assert(optim_results.converged)
# Check that the argmin is close to the actual value.
np.testing.assert_allclose(optim_results.position, minimum)
# Print out the total number of function evaluations it took. Should be 5.
print("Function evaluations: %d" % optim_results.num_objective_evaluations)
Function evaluations: 5

Pour vmap L-BFGS, Fixons une fonction qui permet d' optimiser la perte d'un seul point de départ.

def optimize_single(start):
  return tfp.optimizer.lbfgs_minimize(
      value_and_grad(quadratic_loss), initial_position=start, tolerance=1e-8)

all_results = jit(vmap(optimize_single))(
    random.normal(random.PRNGKey(0), (10, 2)))
assert all(all_results.converged)
for i in range(10):
  np.testing.assert_allclose(optim_results.position[i], minimum)
print("Function evaluations: %s" % all_results.num_objective_evaluations)
Function evaluations: [6 6 9 6 6 8 6 8 5 9]

Mises en garde

Il existe des différences fondamentales entre TF et JAX, certains comportements TFP seront différents entre les deux substrats et toutes les fonctionnalités ne sont pas prises en charge. Par example,

  • TFP sur JAX ne supporte pas quelque chose comme tf.Variable puisque rien comme il existe dans JAX. Cela signifie également des utilitaires comme tfp.util.TransformedVariable ne sont pas pris en charge non plus .
  • tfp.layers est pas pris en charge dans le back - end encore, en raison de sa dépendance à l' égard Keras et tf.Variable s.
  • tfp.math.minimize ne fonctionne pas dans le TFP JAX en raison de sa dépendance à l' égard tf.Variable .
  • Avec TFP sur JAX, les formes tensorielles sont toujours des valeurs entières concrètes et ne sont jamais inconnues/dynamiques comme dans TFP sur TF.
  • Le pseudo-aléatoire est géré différemment dans TF et JAX (voir annexe).
  • Les bibliothèques de tfp.experimental ne sont pas garantis d'exister dans le substrat JAX.
  • Les règles de promotion Dtype sont différentes entre TF et JAX. TFP sur JAX essaie de respecter la sémantique dtype de TF en interne, par souci de cohérence.
  • Les bijecteurs n'ont pas encore été enregistrés en tant que pytrees JAX.

Pour voir la liste complète de ce qui est pris en charge dans TFP sur JAX, s'il vous plaît se référer à la documentation de l' API .

Conclusion

Nous avons porté de nombreuses fonctionnalités de TFP sur JAX et sommes impatients de voir ce que tout le monde va construire. Certaines fonctionnalités ne sont pas encore prises en charge ; si nous avons manqué quelque chose d' important pour vous (ou si vous trouvez un bug!) s'il vous plaît nous rejoindre - vous pouvez envoyer tfprobability@tensorflow.org ou déposer une question sur notre repo Github .

Annexe : pseudo-aléatoire dans JAX

Le modèle de génération de nombres pseudo - aléatoires (PRNG) de Jax est apatride. Contrairement à un modèle avec état, il n'y a pas d'état global mutable qui évolue après chaque tirage aléatoire. Dans le modèle de JAX, nous commençons par une clé PRNG, qui agit comme une paire d'entiers 32 bits. Nous pouvons construire ces clés en utilisant jax.random.PRNGKey .

key = random.PRNGKey(0)  # Creates a key with value [0, 0]
print(key)
[0 0]

Fonctions aléatoires JAX consomment une clé pour produire un nombre aléatoire déterministe, ce qui signifie qu'ils ne doivent pas être utilisés à nouveau. Par exemple, nous pouvons utiliser la key pour échantillonner une valeur distribuée normalement, mais il ne faut pas utiliser la key à nouveau ailleurs. De plus, en passant la même valeur en random.normal produira la même valeur.

print(random.normal(key))
-0.20584226

Alors, comment pouvons-nous tirer plusieurs échantillons à partir d'une seule clé ? La réponse est la scission de clé. L'idée de base est que l' on peut diviser un PRNGKey en plusieurs, et chacune des nouvelles clés peuvent être traitées comme une source indépendante de caractère aléatoire.

key1, key2 = random.split(key, num=2)
print(key1, key2)
[4146024105  967050713] [2718843009 1272950319]

Le fractionnement des clés est déterministe mais chaotique, de sorte que chaque nouvelle clé peut désormais être utilisée pour tirer un échantillon aléatoire distinct.

print(random.normal(key1), random.normal(key2))
0.14389051 -1.2515389

Pour plus de détails sur le modèle de partage de clé déterministe de JAX, consultez ce guide .