¡Confirme su asistencia a su evento local de TensorFlow Everywhere hoy!
Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Un recorrido por Oryx

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

¿Qué es Oryx?

Oryx es una biblioteca experimental que extiende JAX a aplicaciones que van desde la construcción y entrenamiento de redes neuronales complejas hasta la inferencia bayesiana aproximada en modelos generativos profundos. Al igual que ofrece JAX jit , vmap , y grad , Oryx proporciona un conjunto de transformaciones de función componibles que permiten escribir código simple y su transformación para construir la complejidad que se hospedan completamente interoperable con JAX.

JAX solo puede transformar de forma segura código puro y funcional (es decir, código sin efectos secundarios). Mientras que el código puro puede ser más fácil de escribir y razonar, el código "impuro" a menudo puede ser más conciso y expresivo.

En esencia, Oryx es una biblioteca que permite "aumentar" el código funcional puro para realizar tareas como definir el estado o extraer valores intermedios. Su objetivo es ser una capa lo más fina posible sobre JAX, aprovechando el enfoque minimalista de JAX para la computación numérica. Oryx se divide conceptualmente en varias "capas", cada una de las cuales se basa en la que está debajo.

El código fuente de Oryx se puede encontrar en GitHub .

Preparar

pip install -q oryx 1>/dev/null
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='whitegrid')

import jax
import jax.numpy as jnp
from jax import random
from jax import vmap
from jax import jit
from jax import grad

import oryx

tfd = oryx.distributions

state = oryx.core.state
ppl = oryx.core.ppl

inverse = oryx.core.inverse
ildj = oryx.core.ildj
plant = oryx.core.plant
reap = oryx.core.reap
sow = oryx.core.sow
unzip = oryx.core.unzip

nn = oryx.experimental.nn
mcmc = oryx.experimental.mcmc
optimizers = oryx.experimental.optimizers

Capa 0: transformaciones de funciones base

En su base, Oryx define varias transformaciones de funciones nuevas. Estas transformaciones se implementan utilizando la maquinaria de seguimiento de JAX y son interoperables con las transformaciones JAX existentes como jit , grad , vmap , etc.

Inversión automática de funciones

oryx.core.inverse y oryx.core.ildj son transformaciones de funciones que pueden invertir programáticamente una función y calcular su inverso log-det Jacobian (ILDJ) respectivamente. Estas transformaciones son útiles en modelos probabilísticos para calcular probabilidades logarítmicas utilizando la fórmula de cambio de variable. Sin embargo, existen limitaciones en los tipos de funciones con las que son compatibles (consulte la documentación para obtener más detalles).

def f(x):
  return jnp.exp(x) + 2.
print(inverse(f)(4.))  # ln(2)
print(ildj(f)(4.)) # -ln(2)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

0.6931472
-0.6931472

Cosecha

oryx.core.harvest permite etiquetar valores en funciones junto con la capacidad de recopilarlos o "cosecharlos", y la capacidad de inyectar valores en su lugar, o "plantarlos". Marcamos los valores usando la función de sow .

def f(x):
  y = sow(x + 1., name='y', tag='intermediate')
  return y ** 2
print('Reap:', reap(f, tag='intermediate')(1.))  # Pulls out 'y'
print('Plant:', plant(f, tag='intermediate')(dict(y=5.), 1.))  # Injects 5. for 'y'
Reap: {'y': DeviceArray(2., dtype=float32)}
Plant: 25.0

Abrir la cremallera

oryx.core.unzip divide una función en dos a lo largo de un conjunto de valores etiquetados como intermedios, luego devuelve las funciones init_f y apply_f . init_f toma un argumento clave y devuelve los intermedios. apply_f devuelve una función que toma los intermedios y devuelve la salida de la función original.

def f(key, x):
  w = sow(random.normal(key), tag='variable', name='w')
  return w * x
init_f, apply_f = unzip(f, tag='variable')(random.PRNGKey(0), 1.)

La función init_f ejecuta f pero solo devuelve sus variables.

init_f(random.PRNGKey(0))
{'w': DeviceArray(-0.20584226, dtype=float32)}

apply_f toma un conjunto de variables como su primera entrada y ejecuta f con el conjunto de variables dado.

apply_f(dict(w=2.), 2.)  # Runs f with `w = 2`.
DeviceArray(4., dtype=float32)

Capa 1: transformaciones de nivel superior

Oryx se basa en las transformaciones de función inversa, recolección y descompresión de bajo nivel para ofrecer varias transformaciones de nivel superior para escribir cálculos con estado y para programación probabilística.

Funciones con estado ( core.state )

A menudo estamos interesados ​​en expresar cálculos con estado donde inicializamos un conjunto de parámetros y expresamos un cálculo en términos de parámetros. En oryx.core.state , Oryx proporciona una transformación init que convierte una función en una que inicializa un Module , un contenedor para el estado.

Module s asemejan Pytorch y TensorFlow Module s excepto que son inmutables.

def make_dense(dim_out):
  def forward(x, init_key=None):
    w_key, b_key = random.split(init_key)
    dim_in = x.shape[0]
    w = state.variable(random.normal(w_key, (dim_in, dim_out)), name='w')
    b = state.variable(random.normal(w_key, (dim_out,)), name='b')
    return jnp.dot(x, w) + b
  return forward

layer = state.init(make_dense(5))(random.PRNGKey(0), jnp.zeros(2))
print('layer:', layer)
print('layer.w:', layer.w)
print('layer.b:', layer.b)
layer: FunctionModule(dict_keys(['w', 'b']))
layer.w: [[-2.6105583   0.03385283  1.0863334  -1.4802988   0.48895672]
 [ 1.062516    0.5417484   0.0170228   0.2722685   0.30522448]]
layer.b: [0.59902626 0.2172144  2.4202902  0.03266738 1.2164948 ]

Module se registran como pytrees JAX y se pueden usar como entradas para funciones transformadas JAX. Oryx proporciona una función de call conveniente que ejecuta un Module .

vmap(state.call, in_axes=(None, 0))(layer, jnp.ones((5, 2)))
DeviceArray([[-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ],
             [-0.94901603,  0.7928156 ,  3.5236464 , -1.1753628 ,
               2.010676  ]], dtype=float32)

La API de state también permite escribir actualizaciones con estado (como la ejecución de promedios) mediante la función de assign . El Module resultante tiene una función de update con una firma de entrada que es la misma que la __call__ del Module pero crea una nueva copia del Module con un estado actualizado.

def counter(x, init_key=None):
  count = state.variable(0., key=init_key, name='count')
  count = state.assign(count + 1., name='count')
  return x + count
layer = state.init(counter)(random.PRNGKey(0), 0.)
print(layer.count)
updated_layer = layer.update(0.)
print(updated_layer.count) # Count has advanced!
print(updated_layer.call(1.))
0.0
1.0
3.0

Programación probabilística

En oryx.core.ppl , Oryx proporciona un conjunto de herramientas construidas sobre harvest e inverse que tienen como objetivo hacer que escribir y transformar programas probabilísticos sea intuitivo y fácil.

En Oryx, un programa probabilístico es una función JAX que toma una fuente de aleatoriedad como primer argumento y devuelve una muestra de una distribución, es decir, f :: Key -> Sample . Para escribir estos programas, Oryx envuelve las distribuciones de probabilidad de TensorFlow y proporciona una función simple random_variable que convierte una distribución en un programa probabilístico.

def sample(key):
  return ppl.random_variable(tfd.Normal(0., 1.))(key)
sample(random.PRNGKey(0))
DeviceArray(-0.20584235, dtype=float32)

¿Qué podemos hacer con los programas probabilísticos? Lo más simple sería tomar un programa probabilístico (es decir, una función de muestreo) y convertirlo en uno que proporcione la densidad logarítmica de una muestra.

ppl.log_prob(sample)(1.)
DeviceArray(-1.4189385, dtype=float32)

La nueva función logarítmica de probabilidad es compatible con otras transformaciones JAX como vmap y grad .

grad(lambda s: vmap(ppl.log_prob(sample))(s).sum())(jnp.arange(10.))
DeviceArray([-0., -1., -2., -3., -4., -5., -6., -7., -8., -9.], dtype=float32)

Usando la transformación ildj , podemos calcular log_prob de programas que transforman muestras de manera invertible.

def sample(key):
  x = ppl.random_variable(tfd.Normal(0., 1.))(key)
  return jnp.exp(x / 2.) + 2.
_, ax = plt.subplots(2)
ax[0].hist(jit(vmap(sample))(random.split(random.PRNGKey(0), 1000)),
    bins='auto')
x = jnp.linspace(0, 8, 100)
ax[1].plot(x, jnp.exp(jit(vmap(ppl.log_prob(sample)))(x)))
plt.show()

png

Podemos etiquetar valores intermedios en un programa probabilístico con nombres y obtener muestras conjuntas y funciones log-prob conjuntas.

def sample(key):
  z_key, x_key = random.split(key)
  z = ppl.random_variable(tfd.Normal(0., 1.), name='z')(z_key)
  x = ppl.random_variable(tfd.Normal(z, 1.), name='x')(x_key)
  return x
ppl.joint_sample(sample)(random.PRNGKey(0))
{'x': DeviceArray(-1.1076484, dtype=float32),
 'z': DeviceArray(0.14389044, dtype=float32)}

Oryx también tiene una función joint_log_prob que compone log_prob con joint_sample .

ppl.joint_log_prob(sample)(dict(x=0., z=0.))
DeviceArray(-1.837877, dtype=float32)

Para obtener más información, consulte la documentación .

Capa 2: minibibliotecas

Sobre la base de las capas que manejan la programación probabilística y de estado, Oryx proporciona minibibliotecas experimentales diseñadas para aplicaciones específicas como el aprendizaje profundo y la inferencia bayesiana.

Redes neuronales

En oryx.experimental.nn , Oryx proporciona un conjunto de Layer redes neuronales comunes que encajan perfectamente en la API state . Estas capas se crean para ejemplos individuales (no lotes) pero anulan los comportamientos de los lotes para manejar patrones como los promedios de ejecución en la normalización de lotes. También permiten pasar argumentos de palabras clave como training=True/False en módulos.

Layer se inicializan desde una Template como nn.Dense(200) usando state.init .

layer = state.init(nn.Dense(200))(random.PRNGKey(0), jnp.zeros(50))
print(layer, layer.params.kernel.shape, layer.params.bias.shape)
Dense(200) (50, 200) (200,)

Una Layer tiene un método de call que ejecuta su pase hacia adelante.

layer.call(jnp.ones(50)).shape
(200,)

Oryx también proporciona un combinador Serial .

mlp_template = nn.Serial([
  nn.Dense(200), nn.Relu(),
  nn.Dense(200), nn.Relu(),
  nn.Dense(10), nn.Softmax()
])
# OR
mlp_template = (
    nn.Dense(200) >> nn.Relu()
    >> nn.Dense(200) >> nn.Relu()
    >> nn.Dense(10) >> nn.Softmax())
mlp = state.init(mlp_template)(random.PRNGKey(0), jnp.ones(784))
mlp(jnp.ones(784))
DeviceArray([0.16362445, 0.21150257, 0.14715882, 0.10425295, 0.05952952,
             0.07531884, 0.08368199, 0.0376978 , 0.0159679 , 0.10126514],            dtype=float32)

Podemos intercalar funciones y combinadores para crear un "metalenguaje" de red neuronal flexible.

def resnet(template):
  def forward(x, init_key=None):
    layer = state.init(template, name='layer')(init_key, x)
    return x + layer(x)
  return forward

big_resnet_template = nn.Serial([
  nn.Dense(50)
  >> resnet(nn.Dense(50) >> nn.Relu())
  >> resnet(nn.Dense(50) >> nn.Relu())
  >> nn.Dense(10)
])
network = state.init(big_resnet_template)(random.PRNGKey(0), jnp.ones(784))
network(jnp.ones(784))
DeviceArray([-0.03828401,  0.9046303 ,  1.6083915 , -0.17005858,
              3.889552  ,  1.7427744 , -1.0567027 ,  3.0192878 ,
              0.28983995,  1.7103616 ], dtype=float32)

Optimizadores

En oryx.experimental.optimizers , Oryx proporciona un conjunto de optimizadores de primer orden, construidos utilizando la API state . Su diseño se basa en la biblioteca optix de JAX, donde los optimizadores mantienen el estado de un conjunto de actualizaciones de gradiente. La versión de Oryx administra el estado mediante la API state .

network_key, opt_key = random.split(random.PRNGKey(0))
def autoencoder_loss(network, x):
  return jnp.square(network.call(x) - x).mean()
network = state.init(nn.Dense(200) >> nn.Relu() >> nn.Dense(2))(network_key, jnp.zeros(2))
opt = state.init(optimizers.adam(1e-4))(opt_key, network, network)
g = grad(autoencoder_loss)(network, jnp.zeros(2))

g, opt = opt.call_and_update(network, g)
network = optimizers.optix.apply_updates(network, g)

Cadena de Markov Monte Carlo

En oryx.experimental.mcmc , Oryx proporciona un conjunto de núcleos de Markov Chain Monte Carlo (MCMC). MCMC es un enfoque para aproximar la inferencia bayesiana en la que extraemos muestras de una cadena de Markov cuya distribución estacionaria es la distribución posterior de interés.

La biblioteca MCMC de Oryx se basa en la API state y ppl .

def model(key):
  return jnp.exp(ppl.random_variable(tfd.MultivariateNormalDiag(
      jnp.zeros(2), jnp.ones(2)))(key))

Metrópolis a pie aleatorio

samples = jit(mcmc.sample_chain(mcmc.metropolis(
    ppl.log_prob(model),
    mcmc.random_walk()), 1000))(random.PRNGKey(0), jnp.ones(2))
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)
plt.show()

png

Montecarlo hamiltoniano

samples = jit(mcmc.sample_chain(mcmc.hmc(
    ppl.log_prob(model)), 1000))(random.PRNGKey(0), jnp.ones(2))
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)
plt.show()

png