Zoo de distributions apprenantes

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Dans ce colab, nous présentons divers exemples de construction de distributions apprenantes (« entraînables »). (Nous ne cherchons pas à expliquer les distributions elles-mêmes, seulement à montrer comment les construire.)

import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static
# Short aliases used throughout the examples below.
tfb = tfp.bijectors
tfd = tfp.distributions
# Opt in to TF2 behavior (needed because `tensorflow.compat.v2` is imported).
tf.enable_v2_behavior()
# Length of each distribution's event vector (used for every `loc` below).
event_size = 4
# Number of components in each of the mixture examples below.
num_components = 3

Normale multivariée apprenante avec identité mise à l'échelle pour chol(Cov)

# Trainable mean vector, one entry per event dimension.
mvn_scaled_identity_loc = tf.Variable(tf.zeros(event_size), name='loc')

# One shared scale (shape [1]) that broadcasts across all event dimensions.
# The trainable variable stores the pre-transform value, i.e. log(scale),
# and `tfb.Exp` maps it back to the scale actually used by the Normal.
mvn_scaled_identity_scale = tfp.util.TransformedVariable(
    tf.ones([1]), bijector=tfb.Exp(), name='scale')

# `Independent` folds the last batch dimension of the Normal into the event,
# which turns `event_size` scalar Normals into a single multivariate event.
learnable_mvn_scaled_identity = tfd.Independent(
    tfd.Normal(loc=mvn_scaled_identity_loc, scale=mvn_scaled_identity_scale),
    reinterpreted_batch_ndims=1,
    name='learnable_mvn_scaled_identity')

print(learnable_mvn_scaled_identity)
print(learnable_mvn_scaled_identity.trainable_variables)
tfp.distributions.Independent("learnable_mvn_scaled_identity", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'loc:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>, <tf.Variable 'scale:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>)

Normale multivariée apprenante avec diagonale pour chol(Cov)

# Trainable mean vector, one entry per event dimension.
mvndiag_loc = tf.Variable(tf.zeros(event_size), name='loc')

# One scale per event dimension (a diagonal scale). The trainable variable
# stores the pre-transform value and `tfb.Softplus` maps it to the scale;
# Softplus is just one possible choice of bijector here.
mvndiag_scale = tfp.util.TransformedVariable(
    tf.ones(event_size), bijector=tfb.Softplus(), name='scale')

# `Independent` folds the last batch dimension of the Normal into the event.
learnable_mvndiag = tfd.Independent(
    tfd.Normal(loc=mvndiag_loc, scale=mvndiag_scale),
    reinterpreted_batch_ndims=1,
    name='learnable_mvn_diag')

print(learnable_mvndiag)
print(learnable_mvndiag.trainable_variables)
tfp.distributions.Independent("learnable_mvn_diag", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'loc:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>, <tf.Variable 'scale:0' shape=(4,) dtype=float32, numpy=array([0.54132485, 0.54132485, 0.54132485, 0.54132485], dtype=float32)>)

Mélange de normales multivariées (sphériques)

# Unnormalized log mixture weights, one per component.
mix_scaled_identity_logits = tf.Variable(
    # Changing the `1.` initializes with a geometric decay; with `1.` every
    # logit starts at zero.
    -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32),
    name='logits')

# Randomly initialized mean vector for each component.
mix_scaled_identity_loc = tf.Variable(
    tf.random.normal([num_components, event_size]), name='loc')

# One scale per component (shape [num_components, 1]) that broadcasts over
# the event dimensions. Softplus is just one possible choice of bijector.
mix_scaled_identity_scale = tfp.util.TransformedVariable(
    10. * tf.ones([num_components, 1]),
    bijector=tfb.Softplus(),
    name='scale')

learnable_mix_mvn_scaled_identity = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=mix_scaled_identity_logits),
    components_distribution=tfd.Independent(
        tfd.Normal(
            loc=mix_scaled_identity_loc,
            scale=mix_scaled_identity_scale),
        reinterpreted_batch_ndims=1),
    name='learnable_mix_mvn_scaled_identity')

print(learnable_mix_mvn_scaled_identity)
print(learnable_mix_mvn_scaled_identity.trainable_variables)
tfp.distributions.MixtureSameFamily("learnable_mix_mvn_scaled_identity", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'logits:0' shape=(3,) dtype=float32, numpy=array([-0., -0., -0.], dtype=float32)>, <tf.Variable 'loc:0' shape=(3, 4) dtype=float32, numpy=
array([[ 0.21316044,  0.18825649,  1.3055958 , -1.4072137 ],
       [-1.6604203 , -0.9415946 , -1.1349488 , -0.4928658 ],
       [-0.9672405 ,  0.45094398, -2.615817  ,  3.7891428 ]],
      dtype=float32)>, <tf.Variable 'scale:0' shape=(3, 1) dtype=float32, numpy=
array([[9.999954],
       [9.999954],
       [9.999954]], dtype=float32)>)

Mélange de normales multivariées (sphériques) avec le premier poids de mélange non apprenable

learnable_mix_mvndiag_first_fixed = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        logits=tfp.util.TransformedVariable(
            # Initialize logits as geometric decay.
            -tf.math.log(1.5) * tf.range(num_components, dtype=tf.float32),
            # `Pad` prepends a constant 0 logit, so only the remaining
            # `num_components - 1` logits are stored as a trainable variable.
            bijector=tfb.Pad(paddings=[[1, 0]], constant_values=0),
            # NOTE: `name` belongs to the TransformedVariable. It used to be
            # passed to `tfd.Categorical` because of a misplaced parenthesis,
            # which left the trainable variable with the default name
            # 'Variable:0' instead of 'logits:0'.
            name='logits')),
    components_distribution=tfd.Independent(
        tfd.Normal(
            loc=tf.Variable(
                # Use Rademacher...cuz why not?
                tfp.random.rademacher([num_components, event_size]),
                name='loc'),
            scale=tfp.util.TransformedVariable(
                # One scale per component, broadcast over the event dims.
                10. * tf.ones([num_components, 1]),
                bijector=tfb.Softplus(),  # Use Softplus...cuz why not?
                name='scale')),
        reinterpreted_batch_ndims=1),
    name='learnable_mix_mvndiag_first_fixed')

print(learnable_mix_mvndiag_first_fixed)
print(learnable_mix_mvndiag_first_fixed.trainable_variables)
tfp.distributions.MixtureSameFamily("learnable_mix_mvndiag_first_fixed", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([-0.4054651, -0.8109302], dtype=float32)>, <tf.Variable 'loc:0' shape=(3, 4) dtype=float32, numpy=
array([[ 1.,  1., -1., -1.],
       [ 1., -1.,  1.,  1.],
       [-1.,  1., -1., -1.]], dtype=float32)>, <tf.Variable 'scale:0' shape=(3, 1) dtype=float32, numpy=
array([[9.999954],
       [9.999954],
       [9.999954]], dtype=float32)>)

Mélange de normales multivariées (Cov pleine)

# Unnormalized log mixture weights, one per component.
mix_mvntril_logits = tf.Variable(
    # Changing the `1.` initializes with a geometric decay; with `1.` every
    # logit starts at zero.
    -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32),
    name='logits')

# Mean vector for each component, initialized at the origin.
mix_mvntril_loc = tf.Variable(
    tf.zeros([num_components, event_size]), name='loc')

# Lower-triangular scale factor for each component, initialized to 10 * I.
# `FillScaleTriL` lets the trainable variable hold each triangular matrix as
# a flat vector of its lower-triangular entries.
mix_mvntril_scale_tril = tfp.util.TransformedVariable(
    10. * tf.eye(event_size, batch_shape=[num_components]),
    bijector=tfb.FillScaleTriL(),
    name='scale_tril')

learnable_mix_mvntril = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=mix_mvntril_logits),
    components_distribution=tfd.MultivariateNormalTriL(
        loc=mix_mvntril_loc,
        scale_tril=mix_mvntril_scale_tril),
    name='learnable_mix_mvntril')

print(learnable_mix_mvntril)
print(learnable_mix_mvntril.trainable_variables)
tfp.distributions.MixtureSameFamily("learnable_mix_mvntril", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'loc:0' shape=(3, 4) dtype=float32, numpy=
array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)>, <tf.Variable 'scale_tril:0' shape=(3, 10) dtype=float32, numpy=
array([[9.999945, 0.      , 0.      , 0.      , 9.999945, 9.999945,
        0.      , 0.      , 0.      , 9.999945],
       [9.999945, 0.      , 0.      , 0.      , 9.999945, 9.999945,
        0.      , 0.      , 0.      , 9.999945],
       [9.999945, 0.      , 0.      , 0.      , 9.999945, 9.999945,
        0.      , 0.      , 0.      , 9.999945]], dtype=float32)>, <tf.Variable 'logits:0' shape=(3,) dtype=float32, numpy=array([-0., -0., -0.], dtype=float32)>)

Mélange de normales multivariées (Cov pleine) avec le premier poids de mélange et la première composante non apprenables

# Make a bijector which pads an eye to what otherwise fills a tril.
def num_tril_nonzero(num_rows):
  """Returns the number of lower-triangular entries of a square matrix.

  Args:
    num_rows: Number of rows (and columns) of the square matrix.

  Returns:
    `num_rows * (num_rows + 1) // 2`, i.e. the count of entries on and below
    the diagonal.
  """
  return num_rows * (num_rows + 1) // 2


def num_tril_rows(nnz):
  """Inverts `num_tril_nonzero`: recovers the row count from the entry count.

  Solves `n * (n + 1) / 2 == nnz` for `n`, which gives
  `n = sqrt(0.25 + 2 * nnz) - 0.5`.

  Args:
    nnz: Number of lower-triangular entries.

  Returns:
    The row count, cast to `tf.int32` (the cast truncates any fractional part).
  """
  nnz_float = prefer_static.cast(nnz, tf.float32)
  return prefer_static.cast(
      prefer_static.sqrt(0.25 + 2. * nnz_float) - 0.5,
      tf.int32)

# TFP doesn't have a concat bijector, so we roll out our own.
class PadEye(tfb.Bijector):
  """Prepends one fixed row, derived from an identity matrix, along axis -2.

  The forward pass builds an identity matrix, maps it through `tril_fn`, and
  concatenates the result in front of the input along the second-to-last
  axis. The inverse pass simply drops that first row again. Because the
  prepended row is a constant, it is never part of the trainable state.
  """

  def __init__(self, tril_fn=None):
    """Creates the bijector.

    Args:
      tril_fn: Either an object with an `inverse` method (e.g. a bijector), in
        which case that `inverse` is used, or a plain callable. It is applied
        to an identity matrix to produce the prepended row. Defaults to
        `tfb.FillScaleTriL()`.
    """
    if tril_fn is None:
      tril_fn = tfb.FillScaleTriL()
    # Use the bijector's `inverse` when available; otherwise call `tril_fn`
    # itself.
    self._tril_fn = getattr(tril_fn, 'inverse', tril_fn)
    super(PadEye, self).__init__(
        forward_min_event_ndims=2,
        inverse_min_event_ndims=2,
        is_constant_jacobian=True,
        name='PadEye')

  def _forward(self, x):
    """Returns `x` with the identity-derived row prepended along axis -2."""
    # The static size of the last axis determines how many rows the matrix has.
    num_rows = int(num_tril_rows(tf.compat.dimension_value(x.shape[-1])))
    # Build the identity in `x`'s dtype so the concat below also works for
    # non-float32 inputs (`tf.eye` defaults to float32).
    eye = tf.eye(
        num_rows,
        batch_shape=prefer_static.shape(x)[:-2],
        dtype=x.dtype)
    # Insert a length-1 axis so the mapped identity lines up as a single row.
    eye_row = self._tril_fn(eye)[..., tf.newaxis, :]
    return tf.concat([eye_row, x], axis=prefer_static.rank(x) - 2)

  def _inverse(self, y):
    """Drops the first row along axis -2 (the one `_forward` prepended)."""
    return y[..., 1:, :]

  def _forward_log_det_jacobian(self, x):
    """Returns a scalar zero in `x`'s dtype."""
    return tf.zeros([], dtype=x.dtype)

  def _inverse_log_det_jacobian(self, y):
    """Returns a scalar zero in `y`'s dtype."""
    return tf.zeros([], dtype=y.dtype)

  def _forward_event_shape(self, in_shape):
    """Returns `in_shape` with its second-to-last dimension increased by 1."""
    n = prefer_static.size(in_shape)
    return in_shape + prefer_static.one_hot(n - 2, depth=n, dtype=tf.int32)

  def _inverse_event_shape(self, out_shape):
    """Returns `out_shape` with its second-to-last dimension decreased by 1."""
    n = prefer_static.size(out_shape)
    return out_shape - prefer_static.one_hot(n - 2, depth=n, dtype=tf.int32)


# Maps a flat vector to a lower-triangular matrix, applying Softplus to the
# diagonal entries.
tril_bijector = tfb.FillScaleTriL(diag_bijector=tfb.Softplus())

# Mixture logits. `Pad` prepends one constant entry, so only the remaining
# `num_components - 1` logits are stored as a trainable variable.
fixed_first_logits = tfp.util.TransformedVariable(
    # Changing the `1.` initializes with a geometric decay; with `1.` every
    # logit starts at zero.
    -tf.math.log(1.) * tf.range(num_components, dtype=tf.float32),
    bijector=tfb.Pad(paddings=[(1, 0)]),
    name='logits')

# Component means. Padding along axis -2 prepends one constant row, so the
# first component's mean is not part of the trainable variable.
fixed_first_loc = tfp.util.TransformedVariable(
    tf.zeros([num_components, event_size]),
    bijector=tfb.Pad(paddings=[(1, 0)], axis=-2),
    name='loc')

# Component scale factors. `Chain` applies its bijectors right-to-left:
# `PadEye` first prepends the fixed identity-derived row, then
# `tril_bijector` turns every row into a lower-triangular matrix.
fixed_first_scale_tril = tfp.util.TransformedVariable(
    10. * tf.eye(event_size, batch_shape=[num_components]),
    bijector=tfb.Chain([tril_bijector, PadEye(tril_bijector)]),
    name='scale_tril')

learnable_mix_mvntril_fixed_first = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=fixed_first_logits),
    components_distribution=tfd.MultivariateNormalTriL(
        loc=fixed_first_loc,
        scale_tril=fixed_first_scale_tril),
    name='learnable_mix_mvntril_fixed_first')


print(learnable_mix_mvntril_fixed_first)
print(learnable_mix_mvntril_fixed_first.trainable_variables)
tfp.distributions.MixtureSameFamily("learnable_mix_mvntril_fixed_first", batch_shape=[], event_shape=[4], dtype=float32)
(<tf.Variable 'loc:0' shape=(2, 4) dtype=float32, numpy=
array([[0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)>, <tf.Variable 'scale_tril:0' shape=(2, 10) dtype=float32, numpy=
array([[9.999945, 0.      , 0.      , 0.      , 9.999945, 9.999945,
        0.      , 0.      , 0.      , 9.999945],
       [9.999945, 0.      , 0.      , 0.      , 9.999945, 9.999945,
        0.      , 0.      , 0.      , 9.999945]], dtype=float32)>, <tf.Variable 'logits:0' shape=(2,) dtype=float32, numpy=array([-0., -0.], dtype=float32)>)