Programação Probabilística em Oryx

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno
pip install -q -U jax jaxlib
pip install -q -Uq oryx -I
pip install -q tfp-nightly --upgrade
from functools import partial

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='white')

import jax
import jax.numpy as jnp
from jax import jit, vmap, grad
from jax import random

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

import oryx

Programação probabilística é a ideia de que podemos expressar modelos probabilísticos usando recursos de uma linguagem de programação. Tarefas como inferência bayesiana ou marginalização são fornecidas como recursos de linguagem e podem ser automatizadas.

Oryx fornece um sistema de programação probabilística no qual programas probabilísticos são expressos apenas como funções Python; esses programas são então transformados por meio de transformações de função combináveis, como aquelas em JAX! A ideia é começar com programas simples (como amostragem de um normal aleatório) e compô-los juntos para formar modelos (como uma rede neural bayesiana). Um ponto importante do projeto PPL do Oryx é permitir que os programas para se parecer com funções que você já ia escrever e uso de JAX, mas são anotados para fazer transformações conscientes deles.

Vamos primeiro importar a funcionalidade PPL central do Oryx.

from oryx.core.ppl import random_variable
from oryx.core.ppl import log_prob
from oryx.core.ppl import joint_sample
from oryx.core.ppl import joint_log_prob
from oryx.core.ppl import block
from oryx.core.ppl import intervene
from oryx.core.ppl import conditional
from oryx.core.ppl import graph_replace
from oryx.core.ppl import nest

O que são programas probabilísticos no Oryx?

No Oryx, os programas probabilísticos são apenas funções Python puras que operam em valores JAX e chaves pseudo-aleatórias e retornam uma amostra aleatória. Pelo projeto, eles são compatíveis com as transformações como jit e vmap . No entanto, o sistema de programação probabilística Oryx fornece ferramentas que permitem a você anotar suas funções de maneira útil.

Seguindo a filosofia JAX de funções puras, um programa probabilística Oryx é uma função Python que leva um JAX PRNGKey como seu primeiro argumento e qualquer número de argumentos condicionado subseqüentes. A saída da função é chamada uma "amostra" e as mesmas restrições que se aplicam a jit -ed e vmap funções -ED aplicar a programas de probabilidade (por exemplo, nenhum fluxo de dados dependente de controlo, não há efeitos secundários, etc.). Isso difere de muitos sistemas de programação probabilísticos imperativos em que uma 'amostra' é o rastreamento de execução inteiro, incluindo valores internos para a execução do programa. Veremos mais tarde como Oryx pode acessar valores internos usando o joint_sample , discutido abaixo.

Program :: PRNGKey -> ... -> Sample

Aqui está um programa "Olá mundo" que as amostras de uma distribuição log-normal .

def log_normal(key):
  return jnp.exp(random_variable(tfd.Normal(0., 1.))(key))

print(log_normal(random.PRNGKey(0)))
sns.distplot(jit(vmap(log_normal))(random.split(random.PRNGKey(0), 10000)))
plt.show()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
0.8139614
/home/kbuilder/.local/lib/python3.6/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)

png

O log_normal função é um wrapper fino em torno de um Tensorflow Probabilidade (TFP) de distribuição, mas em vez de chamar tfd.Normal(0., 1.).sample , nós usamos random_variable vez. Como veremos mais tarde, random_variable nos permite converter objetos em programas probabilísticos, juntamente com outras funcionalidades úteis.

Podemos converter log_normal em uma função de log-densidade usando o log_prob transformação:

print(log_prob(log_normal)(1.))
x = jnp.linspace(0., 5., 1000)
plt.plot(x, jnp.exp(vmap(log_prob(log_normal))(x)))
plt.show()
-0.9189385

png

Porque nós temos anotado a função com random_variable , log_prob está ciente de que havia uma chamada para tfd.Normal(0., 1.).sample e usa tfd.Normal(0., 1.).log_prob para calcular a distribuição base log prob. Para lidar com a jnp.exp , ppl.log_prob calcula automaticamente densidades através de funções bijective, mantendo o controle de mudanças de volume no cálculo de mudança de variável.

Em Oryx, podemos tomar programas e transformá-los usando transformações de função - por exemplo, jax.jit ou log_prob . Oryx não pode fazer isso com qualquer programa; requer funções de amostragem que tenham registrado sua função de densidade de registro com Oryx. Felizmente, Oryx registra automaticamente TensorFlow Probabilidade distribuições (TFP) em seu sistema.

Ferramentas de programação probabilística do Oryx

Oryx tem várias transformações de função voltadas para a programação probabilística. Examinaremos a maioria deles e forneceremos alguns exemplos. No final, colocaremos tudo junto em um estudo de caso MCMC. Você também pode consultar a documentação para core.ppl.transformations para mais detalhes.

random_variable

random_variable tem duas principais peças de funcionalidade, ambos focados em anotar as funções do Python com informações que podem ser usadas em transformações.

  1. random_variable 'opera como a função identidade por padrão, mas pode usar registros de tipo específico de objetos converter em programs.` probabilística

    Para tipos que podem ser chamados (as funções do Python, lambdas, functools.partial s, etc.) e arbitrário object s (como JAX DeviceArray s) ele só vai voltar sua entrada.

    random_variable(x: object) == x
    random_variable(f: Callable[...]) == f
    

    Oryx regista automaticamente TensorFlow Probabilidade (TFP) distribuições, que são convertidos em programas probabilística que chamam a de distribuição de sample método.

    random_variable(tfd.Normal(0., 1.))(random.PRNGKey(0)) # ==> -0.20584235
    

    O Oryx adicionalmente incorpora informações sobre a distribuição TFP em rastreamentos JAX que permitem a computação automática de densidades de log.

  2. random_variable valores pode marcar com nomes, tornando-os úteis para transformações jusante, fornecendo um opcional name argumento palavra-chave para random_variable . Quando se passar uma matriz em random_variable juntamente com um name (por exemplo, random_variable(x, name='x') ), que apenas é marcado para o valor e devolve-lo. Se passar em um exigível ou distribuição TFP, random_variable retorna um programa que tags de sua amostra de saída com name .

Essas anotações não alterar a semântica do programa quando executado, mas só quando transformada (ou seja, o programa irá devolver o mesmo valor com ou sem o uso de random_variable ).

Vamos examinar um exemplo em que usamos as duas partes da funcionalidade juntas.

def latent_normal(key):
  z_key, x_key = random.split(key)
  z = random_variable(tfd.Normal(0., 1.), name='z')(z_key)
  return random_variable(tfd.Normal(z, 1e-1), name='x')(x_key)

Neste programa temos marcado a intermediários z e x , o que torna a transformações joint_sample , intervene , conditional e graph_replace ciente dos nomes 'z' e 'x' . Veremos exatamente como cada transformação usa nomes posteriormente.

log_prob

O log_prob transformação de função converte um programa probabilística Oryx para a sua função de log-densidade. Esta função de densidade de log pega uma amostra potencial do programa como entrada e retorna sua densidade de log sob a distribuição de amostragem subjacente.

log_prob :: Program -> (Sample -> LogDensity)

Como random_variable , ele funciona através de um registo de tipos de onde distribuições TFP são automaticamente registradas, de modo log_prob(tfd.Normal(0., 1.)) chama tfd.Normal(0., 1.).log_prob . Para as funções do Python, no entanto, log_prob traça o programa usando JAX e olha para a amostragem declarações. O log_prob transformação funciona na maioria dos programas que retornam variáveis aleatórias, directamente ou através de transformações inversíveis mas não em programas que valores de amostra internamente que não são devolvidos. Se ele não pode inverter as operações necessárias no programa, log_prob irá lançar um erro.

Aqui estão alguns exemplos de log_prob aplicadas a vários programas.

  1. log_prob funciona em programas que diretamente amostra de distribuições TFP (ou outros tipos registrados) e retornam seus valores.
def normal(key):
  return random_variable(tfd.Normal(0., 1.))(key)
print(log_prob(normal)(0.))
-0.9189385
  1. log_prob é capaz de calcular log-densidades das amostras de programas que transformam variates aleatórios utilizando funções bijective (por exemplo jnp.exp , jnp.tanh , jnp.split ).
def log_normal(key):
  return 2 * jnp.exp(random_variable(tfd.Normal(0., 1.))(key))
print(log_prob(log_normal)(1.))
-1.159165

Para calcular uma amostra de log_normal 's log-densidade, primeiro precisamos inverter a exp , tomando o log da amostra, e depois adicionar uma correção volume-mudança usando o log-det inversa Jacobian de exp (ver a mudança de variável fórmula de Wikipedia).

  1. log_prob obras com programas que as estruturas de amostras de saída gosta, dicionários Python ou tuplas.
def normal_2d(key):
  x = random_variable(
    tfd.MultivariateNormalDiag(jnp.zeros(2), jnp.ones(2)))(key)
  x1, x2 = jnp.split(x, 2, 0)
  return dict(x1=x1, x2=x2)
sample = normal_2d(random.PRNGKey(0))
print(sample)
print(log_prob(normal_2d)(sample))
{'x1': DeviceArray([-0.7847661], dtype=float32), 'x2': DeviceArray([0.8564447], dtype=float32)}
-2.5125546
  1. log_prob caminha o gráfico traçado cálculo da função, computação ambos os valores para a frente e inversos (e o seu log-det Jacobianos), quando necessário, numa tentativa de ligar valores devolvidos com os seus valores de amostra de base através de uma mudança bem definido de variáveis. Veja o seguinte programa de exemplo:
def complex_program(key):
  k1, k2 = random.split(key)
  z = random_variable(tfd.Normal(0., 1.))(k1)
  x = random_variable(tfd.Normal(jax.nn.relu(z), 1.))(k2)
  return jnp.exp(z), jax.nn.sigmoid(x)
sample = complex_program(random.PRNGKey(0))
print(sample)
print(log_prob(complex_program)(sample))
(DeviceArray(1.1547576, dtype=float32), DeviceArray(0.24830955, dtype=float32))
-1.0967848

Neste programa, nós amostra x condicionalmente sobre z , ou seja, nós precisamos o valor de z antes de podermos calcular o log-densidade de x . No entanto, a fim de calcular z , primeiro temos que inverta a jnp.exp aplicada a z . Assim, a fim de calcular os log-densidades de x e z , log_prob necessidades a primeira invertido a primeira saída, e depois passá-lo para a frente através do jax.nn.relu para calcular a média de p(x | z) .

Para mais informações sobre log_prob , você pode se referir a core.interpreters.log_prob . Na execução, log_prob está intimamente baseado fora do inverse transformação JAX; para saber mais sobre inverse , consulte core.interpreters.inverse .

joint_sample

Para definir programas mais complexos e interessantes, usaremos algumas variáveis ​​aleatórias latentes, ou seja, variáveis ​​aleatórias com valores não observados. Vamos referem-se a latent_normal programa que as amostras de um valor aleatório z que é usado como a média de um outro valor aleatório x .

def latent_normal(key):
  z_key, x_key = random.split(key)
  z = random_variable(tfd.Normal(0., 1.), name='z')(z_key)
  return random_variable(tfd.Normal(z, 1e-1), name='x')(x_key)

Neste programa, z é tão latente, se fôssemos apenas chamar latent_normal(random.PRNGKey(0)) não saberíamos o valor real de z que é responsável pela geração x .

joint_sample é uma transformação que transforma um programa para outro programa que retorna um dicionário nomes cadeia de mapeamento (tags) para seus valores. Para funcionar, precisamos ter certeza de marcar as variáveis ​​latentes para garantir que elas apareçam na saída da função transformada.

joint_sample(latent_normal)(random.PRNGKey(0))
{'x': DeviceArray(0.01873656, dtype=float32),
 'z': DeviceArray(0.14389044, dtype=float32)}

Note-se que joint_sample transforma um programa para outro programa que as amostras a distribuição conjunta sobre os seus valores latentes, para que possamos transformar-la ainda mais. Para algoritmos como MCMC e VI, é comum calcular a probabilidade de log da distribuição conjunta como parte do procedimento de inferência. log_prob(latent_normal) não funciona porque requer marginalizando a z , mas podemos usar log_prob(joint_sample(latent_normal)) .

print(log_prob(joint_sample(latent_normal))(dict(x=0., z=1.)))
print(log_prob(joint_sample(latent_normal))(dict(x=0., z=-10.)))
-50.03529
-5049.535

Uma vez que este é um teste padrão comum tal, Oryx também tem um joint_log_prob transformação que é apenas a composição de log_prob e joint_sample .

print(joint_log_prob(latent_normal)(dict(x=0., z=1.)))
print(joint_log_prob(latent_normal)(dict(x=0., z=-10.)))
-50.03529
-5049.535

block

O block transformação se em um programa e uma sequência de nomes e retorna um programa que se comporta de forma idêntica, excepto que em transformações a jusante (como joint_sample ), os nomes fornecidos são ignorados. Um exemplo de onde block é útil é a conversão de um conjunto de distribuição para uma prévia sobre as variáveis latentes por "bloqueio" os valores amostrados na probabilidade. Por exemplo, tomar latent_normal , que em primeiro lugar chama a um z ~ N(0, 1) , em seguida, um x | z ~ N(z, 1e-1) . block(latent_normal, names=['x']) é um programa que esconde o x nome, então se fizermos joint_sample(block(latent_normal, names=['x'])) , obtemos um dicionário com apenas z nele .

blocked = block(latent_normal, names=['x'])
joint_sample(blocked)(random.PRNGKey(0))
{'z': DeviceArray(0.14389044, dtype=float32)}

intervene

O intervene amostras clobbers transformação em um programa probabilística com valores a partir do exterior. Voltando ao nosso latent_normal programa, vamos dizer que estavam interessados em executar o mesmo programa, mas queria z a fixar a 4. Ao invés de escrever um novo programa, podemos usar intervene para substituir o valor de z .

intervened = intervene(latent_normal, z=4.)
sns.distplot(vmap(intervened)(random.split(random.PRNGKey(0), 10000)))
plt.show();
/home/kbuilder/.local/lib/python3.6/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)

png

Os intervened amostras de função de p(x | do(z = 4)) , que é apenas uma distribuição normal padrão centrado em 4. Quando intervene em um determinado valor, que o valor já não é considerada uma variável aleatória. Isto significa que uma z valor não vai ser marcadas durante a execução de intervened .

conditional

conditional transforma um programa que amostras latente valores em que condições sobre esses valores latentes. Voltando ao nosso latent_normal programa, que amostras p(x) com uma latente z , podemos convertê-lo em um programa condicional p(x | z) .

cond_program = conditional(latent_normal, 'z')
print(cond_program(random.PRNGKey(0), 100.))
print(cond_program(random.PRNGKey(0), 50.))
sns.distplot(vmap(lambda key: cond_program(key, 1.))(random.split(random.PRNGKey(0), 10000)))
sns.distplot(vmap(lambda key: cond_program(key, 2.))(random.split(random.PRNGKey(0), 10000)))
plt.show()
99.87485
49.874847
/home/kbuilder/.local/lib/python3.6/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/home/kbuilder/.local/lib/python3.6/site-packages/seaborn/distributions.py:2551: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)

png

nest

Quando começamos a compor programas probabilísticos para construir outros mais complexos, é comum reutilizar funções que possuem alguma lógica importante. Por exemplo, se nós gostaríamos de construir uma rede neural Bayesian, pode haver um importante dense programa que amostras pesos e executa um passe para frente.

Se reutilizar funções, no entanto, pode acabar com os valores marcados duplicados no programa final, que é rejeitado por transformações como joint_sample . Podemos usar o nest para criar tag "escopos", onde todas as amostras dentro de um escopo nomeado será inserido em um dicionário aninhado.

def f(key):
  return random_variable(tfd.Normal(0., 1.), name='x')(key)

def g(key):
  k1, k2 = random.split(key)
  return nest(f, scope='x1')(k1) + nest(f, scope='x2')(k2)
joint_sample(g)(random.PRNGKey(0))
{'x1': {'x': DeviceArray(0.14389044, dtype=float32)},
 'x2': {'x': DeviceArray(-1.2515389, dtype=float32)} }

Estudo de caso: rede neural bayesiana

Vamos tentar a nossa mão de treinamento de uma rede neural Bayesian para classificar o clássico Fisher Iris conjunto de dados. É relativamente pequeno e de baixa dimensão, então podemos tentar amostrar diretamente a parte posterior com MCMC.

Primeiro, vamos importar o conjunto de dados e alguns utilitários adicionais do Oryx.

from sklearn import datasets
iris = datasets.load_iris()
features, labels = iris['data'], iris['target']

num_features = features.shape[-1]
num_classes = len(iris.target_names)

from oryx.experimental import mcmc
from oryx.util import summary, get_summaries

Começamos implementando uma camada densa, que terá antecedentes normais sobre os pesos e o viés. Para fazer isso, em primeiro lugar definir uma dense função de ordem superior que leva na função de dimensão e de saída de activação desejada. O dense função retorna um programa probabilística que representa uma distribuição condicional p(h | x) , onde h é a saída de uma camada densa e x é a sua entrada. It primeiras amostras do peso e de polarização e, em seguida, aplica-se-lhes a x .

def dense(dim_out, activation=jax.nn.relu):
  def forward(key, x):
    dim_in = x.shape[-1]
    w_key, b_key = random.split(key)
    w = random_variable(
          tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out, dim_in)),
          name='w')(w_key)
    b = random_variable(
          tfd.Sample(tfd.Normal(0., 1.), sample_shape=(dim_out,)),
          name='b')(b_key)
    return activation(jnp.dot(w, x) + b)
  return forward

Para compor vários dense camadas juntos, vamos implementar um mlp (perceptron multicamadas) função de ordem superior que o leva em uma lista de tamanhos escondidas e um número de classes. Ele retorna um programa que repetidamente chama dense usando o apropriado hidden_size e finalmente retorna logits para cada classe na camada final. Note-se a utilização de nest que cria âmbitos nome para cada camada.

def mlp(hidden_sizes, num_classes):
  num_hidden = len(hidden_sizes)
  def forward(key, x):
    keys = random.split(key, num_hidden + 1)
    for i, (subkey, hidden_size) in enumerate(zip(keys[:-1], hidden_sizes)):
      x = nest(dense(hidden_size), scope=f'layer_{i + 1}')(subkey, x)
    logits = nest(dense(num_classes, activation=lambda x: x),
                  scope=f'layer_{num_hidden + 1}')(keys[-1], x)
    return logits
  return forward

Para implementar o modelo completo, precisaremos modelar os rótulos como variáveis ​​aleatórias categóricas. Vamos definir uma predict função que leva em um conjunto de dados de xs (os recursos) que são então passadas para um mlp utilizando vmap . Quando usamos vmap(partial(mlp, mlp_key)) , que amostra um único conjunto de pesos, mas mapear a passagem para a frente sobre toda a entrada de xs . Isto produz um conjunto de logits que parametriza distribuições categóricas independentes.

def predict(mlp):
  def forward(key, xs):
    mlp_key, label_key = random.split(key)
    logits = vmap(partial(mlp, mlp_key))(xs)
    return random_variable(
        tfd.Independent(tfd.Categorical(logits=logits), 1), name='y')(label_key)
  return forward

Esse é o modelo completo! Vamos usar MCMC para amostrar a posteriori dos pesos BNN dados dados; Primeiro, construir uma BNN "template" usando mlp .

bnn = mlp([200, 200], num_classes)

Para construir um ponto de partida para a cadeia de Markov, podemos usar joint_sample com uma entrada de manequim.

weights = joint_sample(bnn)(random.PRNGKey(0), jnp.ones(num_features))
print(weights.keys())
dict_keys(['layer_1', 'layer_2', 'layer_3'])

O cálculo da probabilidade do log de distribuição conjunta é suficiente para muitos algoritmos de inferência. Vamos agora dizer que observamos x e quer provar a posterior p(z | x) . Para distribuições complexas, não será capaz de marginalizar fora x (embora para latent_normal podemos), mas podemos calcular uma densidade unnormalized registo log p(z, x) em que x é fixado para um valor particular. Podemos usar a probabilidade de log não normalizada com MCMC para amostrar a posterior. Vamos escrever esta função de log prob "fixada".

def target_log_prob(weights):
  return joint_log_prob(predict(bnn))(dict(weights, y=labels), features)

Agora podemos usar tfp.mcmc para provar a posterior usando a nossa função densidade log unnormalized. Note que nós vamos ter que usar uma versão "achatada" dos nossos pesos aninhados dicionário para ser compatível com tfp.mcmc , por isso usamos utilitários árvores de JAX para achatar e unflatten.

@jit
def run_chain(key, weights):
  flat_state, sample_tree = jax.tree_flatten(weights)

  def flat_log_prob(*states):
    return target_log_prob(jax.tree_unflatten(sample_tree, states))

  def trace_fn(_, results):
    return results.inner_results.accepted_results.target_log_prob

  flat_states, log_probs = tfp.mcmc.sample_chain(
    1000,
    num_burnin_steps=9000,
    kernel=tfp.mcmc.DualAveragingStepSizeAdaptation(
        tfp.mcmc.HamiltonianMonteCarlo(flat_log_prob, 1e-3, 100),
        9000, target_accept_prob=0.7),
    trace_fn=trace_fn,
    current_state=flat_state,
    seed=key)
  samples = jax.tree_unflatten(sample_tree, flat_states)
  return samples, log_probs
posterior_weights, log_probs = run_chain(random.PRNGKey(0), weights)
plt.plot(log_probs)
plt.show()

png

Podemos usar nossas amostras para obter uma estimativa de média do modelo bayesiano (BMA) da precisão do treinamento. Para calculá-lo, podemos usar intervene com bnn "inserir" posterior pesos no lugar dos que são amostrados a partir da chave. Para calcular logits para cada ponto de dados para cada amostra posterior, podemos dobrar vmap sobre posterior_weights e features .

output_logits = vmap(lambda weights: vmap(lambda x: intervene(bnn, **weights)(
    random.PRNGKey(0), x))(features))(posterior_weights)
output_probs = jax.nn.softmax(output_logits)
print('Average sample accuracy:', (
    output_probs.argmax(axis=-1) == labels[None]).mean())
print('BMA accuracy:', (
    output_probs.mean(axis=0).argmax(axis=-1) == labels[None]).mean())
Average sample accuracy: 0.9874067
BMA accuracy: 0.99333334

Conclusão

No Oryx, os programas probabilísticos são apenas funções JAX que recebem (pseudo-) aleatoriedade como uma entrada. Por causa da forte integração do Oryx com o sistema de transformação de função JAX, podemos escrever e manipular programas probabilísticos como se estivéssemos escrevendo o código JAX. Isso resulta em um sistema simples, mas flexível para construir modelos complexos e fazer inferências.