Zoológico de distribuciones que se pueden aprender

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

En este colab mostramos varios ejemplos de cómo construir distribuciones aprendibles ("entrenables"). (No intentamos explicar las distribuciones; solo mostramos cómo construirlas).

import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static
tfb = tfp.bijectors
tfd = tfp.distributions
tf.enable_v2_behavior()
# Event dimensionality shared by every distribution, and the number of mixture
# components used by the mixture examples.
event_size, num_components = 4, 3

Normal multivariante aprendible con identidad escalada como chol(Cov)

# Diagonal Gaussian over `event_size` coordinates whose scale is a single
# trainable scalar shared by all coordinates, i.e. chol(Cov) = scale * I.
learnable_mvn_scaled_identity = tfd.Independent(
    distribution=tfd.Normal(
        loc=tf.Variable(initial_value=tf.zeros(event_size), name='loc'),
        # Shape [1] broadcasts against the [event_size] loc. Exp keeps the
        # scale positive; the stored (pre-transform) variable is log(scale).
        scale=tfp.util.TransformedVariable(
            initial_value=tf.ones([1]), bijector=tfb.Exp(), name='scale'),
    ),
    # Reinterpret Normal's single batch axis as the event axis.
    reinterpreted_batch_ndims=1,
    name='learnable_mvn_scaled_identity',
)

print(learnable_mvn_scaled_identity)
print(learnable_mvn_scaled_identity.trainable_variables)
tfp.distributions.Independent("learnable_mvn_scaled_identity", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'loc:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>, <tf.Variable 'scale:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>)

Normal multivariante aprendible con matriz diagonal como chol(Cov)

# Diagonal Gaussian with one independent trainable scale per coordinate,
# i.e. chol(Cov) = diag(scale).
learnable_mvndiag = tfd.Independent(
    distribution=tfd.Normal(
        loc=tf.Variable(initial_value=tf.zeros(event_size), name='loc'),
        # Softplus (rather than Exp) is used only to illustrate a different
        # positivity constraint.
        scale=tfp.util.TransformedVariable(
            initial_value=tf.ones(event_size),
            bijector=tfb.Softplus(),
            name='scale'),
    ),
    # Reinterpret Normal's single batch axis as the event axis.
    reinterpreted_batch_ndims=1,
    name='learnable_mvn_diag',
)

print(learnable_mvndiag)
print(learnable_mvndiag.trainable_variables)
tfp.distributions.Independent("learnable_mvn_diag", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'loc:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>, <tf.Variable 'scale:0' shape=(4,) dtype=float32, numpy=array([0.54132485, 0.54132485, 0.54132485, 0.54132485], dtype=float32)>)

Mezcla de normales multivariantes (esféricas)

# Mixture of `num_components` spherical Gaussians: each component has its own
# trainable mean and a single trainable scale shared by its coordinates.
learnable_mix_mvn_scaled_identity = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        logits=tf.Variable(
            # -log(1.) * k == 0 for every k, i.e. uniform weights. Replacing
            # the `1.` with a value > 1 initializes a geometric decay instead.
            initial_value=(
                -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32)),
            name='logits')),
    components_distribution=tfd.Independent(
        distribution=tfd.Normal(
            loc=tf.Variable(
                initial_value=tf.random.normal([num_components, event_size]),
                name='loc'),
            # Shape [num_components, 1]: one scale per component, broadcast
            # across the event axis. Softplus chosen just for variety.
            scale=tfp.util.TransformedVariable(
                initial_value=10. * tf.ones([num_components, 1]),
                bijector=tfb.Softplus(),
                name='scale')),
        reinterpreted_batch_ndims=1),
    name='learnable_mix_mvn_scaled_identity',
)

print(learnable_mix_mvn_scaled_identity)
print(learnable_mix_mvn_scaled_identity.trainable_variables)
tfp.distributions.MixtureSameFamily("learnable_mix_mvn_scaled_identity", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'logits:0' shape=(3,) dtype=float32, numpy=array([-0., -0., -0.], dtype=float32)>, <tf.Variable 'loc:0' shape=(3, 4) dtype=float32, numpy=
array([[ 0.21316044,  0.18825649,  1.3055958 , -1.4072137 ],
       [-1.6604203 , -0.9415946 , -1.1349488 , -0.4928658 ],
       [-0.9672405 ,  0.45094398, -2.615817  ,  3.7891428 ]],
      dtype=float32)>, <tf.Variable 'scale:0' shape=(3, 1) dtype=float32, numpy=
array([[9.999954],
       [9.999954],
       [9.999954]], dtype=float32)>)

Mezcla de normales multivariantes (esféricas) con el peso del primer componente de la mezcla fijo (no aprendible)

# Mixture of spherical Gaussians whose first mixture logit is pinned to 0;
# only the remaining `num_components - 1` logits are trainable.
learnable_mix_mvndiag_first_fixed = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        logits=tfp.util.TransformedVariable(
            # Initialize logits as a geometric decay: [0, -log 1.5, -2 log 1.5].
            -tf.math.log(1.5) * tf.range(num_components, dtype=tf.float32),
            # The inverse strips the leading logit, so the stored variable has
            # shape [num_components - 1]; the forward pass re-inserts a 0.
            bijector=tfb.Pad(paddings=[[1, 0]], constant_values=0),
            # `name` must label the variable itself, not the Categorical.
            name='logits')),
    components_distribution=tfd.Independent(
        tfd.Normal(
            loc=tf.Variable(
                # Rademacher (+/-1) initial means, just for variety.
                tfp.random.rademacher([num_components, event_size]),
                name='loc'),
            # One shared scale per component, broadcast over the event axis.
            scale=tfp.util.TransformedVariable(
                10. * tf.ones([num_components, 1]),
                bijector=tfb.Softplus(),
                name='scale')),
        reinterpreted_batch_ndims=1),
    name='learnable_mix_mvndiag_first_fixed')

print(learnable_mix_mvndiag_first_fixed)
print(learnable_mix_mvndiag_first_fixed.trainable_variables)
tfp.distributions.MixtureSameFamily("learnable_mix_mvndiag_first_fixed", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([-0.4054651, -0.8109302], dtype=float32)>, <tf.Variable 'loc:0' shape=(3, 4) dtype=float32, numpy=
array([[ 1.,  1., -1., -1.],
       [ 1., -1.,  1.,  1.],
       [-1.,  1., -1., -1.]], dtype=float32)>, <tf.Variable 'scale:0' shape=(3, 1) dtype=float32, numpy=
array([[9.999954],
       [9.999954],
       [9.999954]], dtype=float32)>)

Mezcla de normales multivariantes (covarianza completa)

# Mixture of full-covariance Gaussians. Each component's lower-triangular scale
# is stored as an unconstrained vector and filled by FillScaleTriL.
learnable_mix_mvntril = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        logits=tf.Variable(
            # -log(1.) * k == 0 for every k, i.e. uniform weights. Replacing
            # the `1.` with a value > 1 initializes a geometric decay instead.
            initial_value=(
                -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32)),
            name='logits')),
    components_distribution=tfd.MultivariateNormalTriL(
        loc=tf.Variable(
            initial_value=tf.zeros([num_components, event_size]), name='loc'),
        # Initialized to 10 * I for every component.
        scale_tril=tfp.util.TransformedVariable(
            initial_value=10. * tf.eye(event_size, batch_shape=[num_components]),
            bijector=tfb.FillScaleTriL(),
            name='scale_tril')),
    name='learnable_mix_mvntril',
)

print(learnable_mix_mvntril)
print(learnable_mix_mvntril.trainable_variables)
tfp.distributions.MixtureSameFamily("learnable_mix_mvntril", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'loc:0' shape=(3, 4) dtype=float32, numpy=
array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)>, <tf.Variable 'scale_tril:0' shape=(3, 10) dtype=float32, numpy=
array([[9.999945, 0.      , 0.      , 0.      , 9.999945, 9.999945,

        0.      , 0.      , 0.      , 9.999945],
       [9.999945, 0.      , 0.      , 0.      , 9.999945, 9.999945,
        0.      , 0.      , 0.      , 9.999945],
       [9.999945, 0.      , 0.      , 0.      , 9.999945, 9.999945,
        0.      , 0.      , 0.      , 9.999945]], dtype=float32)>, <tf.Variable 'logits:0' shape=(3,) dtype=float32, numpy=array([-0., -0., -0.], dtype=float32)>)

Mezcla de normales multivariantes (covarianza completa) con el peso del primer componente y el propio primer componente fijos (no aprendibles)

# Build a bijector that prepends a fixed identity ("eye") component to the batch of tril vectors that FillScaleTriL would otherwise fill.
def num_tril_nonzero(num_rows):
  """Return the number of entries on or below the diagonal of a square matrix.

  Args:
    num_rows: Number of rows (equal to the number of columns) of the matrix.

  Returns:
    `num_rows * (num_rows + 1) // 2`, the length of the vector that fills a
    `num_rows x num_rows` lower-triangular matrix.
  """
  return num_rows * (num_rows + 1) // 2

def num_tril_rows(nnz):
  """Invert `num_tril_nonzero`: recover the matrix size from a tril entry count.

  Solves `nnz = n * (n + 1) / 2` for `n` as `n = sqrt(1/4 + 2 * nnz) - 1/2`.
  For a valid `nnz` the radicand is the perfect square `(n + 1/2) ** 2`. The
  float32 result is cast (truncated) back to int32.

  Args:
    nnz: Number of lower-triangular entries (static int or tensor).

  Returns:
    The number of rows `n`, as int32.
  """
  return prefer_static.cast(
      prefer_static.sqrt(0.25 + 2. * prefer_static.cast(nnz, tf.float32)) - 0.5,
      tf.int32)

# TFP doesn't have a concat bijector, so we roll our own.
class PadEye(tfb.Bijector):
  """Prepends one constant "identity" component along axis -2.

  Forward maps `x` with shape `[..., k, nnz]` to `[..., k + 1, nnz]`. The new
  leading row is `tril_fn.inverse(eye)`, the vector that `tril_fn` maps back to
  the identity matrix. Inverse drops that leading row. Because the prepended
  row is constant, the Jacobian is declared constant and its log-determinant
  is reported as zero.
  """

  def __init__(self, tril_fn=None):
    """Builds the bijector.

    Args:
      tril_fn: Bijector that fills a lower-triangular matrix from a vector
        (default: `tfb.FillScaleTriL()`). Its `inverse` is used to vectorize
        the identity matrix; an object without `inverse` is called directly.
    """
    if tril_fn is None:
      tril_fn = tfb.FillScaleTriL()
    self._tril_fn = getattr(tril_fn, 'inverse', tril_fn)
    super(PadEye, self).__init__(
        forward_min_event_ndims=2,
        inverse_min_event_ndims=2,
        is_constant_jacobian=True,
        name='PadEye')

  def _forward(self, x):
    # Matrix size n from nnz = n * (n + 1) / 2; needs a static last dimension.
    num_rows = int(num_tril_rows(tf.compat.dimension_value(x.shape[-1])))
    # Build the identity in x's dtype so the concat also works beyond float32.
    eye = tf.eye(num_rows, batch_shape=prefer_static.shape(x)[:-2],
                 dtype=x.dtype)
    return tf.concat([self._tril_fn(eye)[..., tf.newaxis, :], x],
                     axis=prefer_static.rank(x) - 2)

  def _inverse(self, y):
    return y[..., 1:, :]

  def _forward_log_det_jacobian(self, x):
    return tf.zeros([], dtype=x.dtype)

  def _inverse_log_det_jacobian(self, y):
    return tf.zeros([], dtype=y.dtype)

  @staticmethod
  def _add_rows(shape, delta):
    """Returns the static `shape` with its axis -2 size changed by `delta`."""
    shape = tf.TensorShape(shape)
    if shape.rank is None:
      return shape
    rows = tf.compat.dimension_value(shape[-2])
    cols = tf.compat.dimension_value(shape[-1])
    return shape[:-2].concatenate(
        [None if rows is None else rows + delta, cols])

  def _forward_event_shape(self, in_shape):
    # Static path: `in_shape` is a `tf.TensorShape`, for which `+` means
    # concatenation, so the row count is adjusted explicitly.
    return self._add_rows(in_shape, 1)

  def _inverse_event_shape(self, out_shape):
    return self._add_rows(out_shape, -1)

  def _forward_event_shape_tensor(self, in_shape):
    # Dynamic path: `in_shape` is an int32 shape vector; add 1 at index -2.
    n = prefer_static.size(in_shape)
    return in_shape + prefer_static.one_hot(n - 2, depth=n, dtype=tf.int32)

  def _inverse_event_shape_tensor(self, out_shape):
    # Dynamic path: subtract 1 at index -2.
    n = prefer_static.size(out_shape)
    return out_shape - prefer_static.one_hot(n - 2, depth=n, dtype=tf.int32)


# Vector -> lower-triangular scale, with a Softplus-transformed diagonal.
tril_bijector = tfb.FillScaleTriL(diag_bijector=tfb.Softplus())

# Mixture whose component 0 is frozen. Its logit, mean and scale_tril are
# re-inserted as constants by the Pad / PadEye bijectors, so only the
# remaining `num_components - 1` components own trainable variables.
learnable_mix_mvntril_fixed_first = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        logits=tfp.util.TransformedVariable(
            # -log(1.) * k == 0 for every k, i.e. uniform weights. Replacing
            # the `1.` with a value > 1 initializes a geometric decay instead.
            initial_value=(
                -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32)),
            bijector=tfb.Pad(paddings=[(1, 0)]),
            name='logits')),
    components_distribution=tfd.MultivariateNormalTriL(
        loc=tfp.util.TransformedVariable(
            initial_value=tf.zeros([num_components, event_size]),
            # Pad along the component axis (-2) rather than the event axis.
            bijector=tfb.Pad(paddings=[(1, 0)], axis=-2),
            name='loc'),
        scale_tril=tfp.util.TransformedVariable(
            initial_value=10. * tf.eye(event_size, batch_shape=[num_components]),
            # Forward: PadEye prepends the identity's tril vector, then
            # tril_bijector fills every component's matrix.
            bijector=tfb.Chain([tril_bijector, PadEye(tril_bijector)]),
            name='scale_tril')),
    name='learnable_mix_mvntril_fixed_first',
)


print(learnable_mix_mvntril_fixed_first)
print(learnable_mix_mvntril_fixed_first.trainable_variables)
tfp.distributions.MixtureSameFamily("learnable_mix_mvntril_fixed_first", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'loc:0' shape=(2, 4) dtype=float32, numpy=
array([[0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)>, <tf.Variable 'scale_tril:0' shape=(2, 10) dtype=float32, numpy=
array([[9.999945, 0.      , 0.      , 0.      , 9.999945, 9.999945,

        0.      , 0.      , 0.      , 9.999945],
       [9.999945, 0.      , 0.      , 0.      , 9.999945, 9.999945,
        0.      , 0.      , 0.      , 9.999945]], dtype=float32)>, <tf.Variable 'logits:0' shape=(2,) dtype=float32, numpy=array([-0., -0.], dtype=float32)>)