tfr.keras.losses.CoupledRankDistilLoss

Computes the Rank Distil loss between y_true and y_pred.

The Coupled-RankDistil loss (Reddi et al., 2021) is the cross-entropy between the k-Plackett's probability of the logits (student) and that of the labels (teacher).

Standalone usage:

tf.random.set_seed(1)
y_true = [[0., 2., 1.], [1., 0., 2.]]
ln = tf.math.log
y_pred = [[0., ln(3.), ln(2.)], [0., ln(2.), ln(3.)]]
loss = tfr.keras.losses.CoupledRankDistilLoss(topk=2, sample_size=1)
loss(y_true, y_pred).numpy()
2.138333

Usage with the compile() API:

model.compile(optimizer='sgd',
              loss=tfr.keras.losses.CoupledRankDistilLoss())

Definition:

The k-Plackett's probability model is defined as:

\[ \mathcal{P}_k(\pi|s) = \frac{1}{(N-k)!} \prod_{i=1}^k \frac{\exp(s_{\pi(i)})}{\sum_{j=i}^N \exp(s_{\pi(j)})}. \]
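
For intuition, the following is a minimal NumPy sketch of this probability for a single permutation; the helper k_plackett_prob and its signature are illustrative and not part of the library:

import math
import numpy as np

def k_plackett_prob(scores, perm, k):
  # Illustrative sketch: probability of permutation `perm` under the
  # k-Plackett model with scores `scores`.
  s = np.asarray(scores, dtype=np.float64)[list(perm)]
  prob = 1.0
  for i in range(k):
    # At rank i, item perm[i] is chosen among the items not yet placed.
    prob *= np.exp(s[i]) / np.sum(np.exp(s[i:]))
  # The (N - k)! orderings of the remaining items are all equally likely.
  return prob / math.factorial(len(scores) - k)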

The Coupled-RankDistil loss is defined as:

\[ \mathcal{L}(y, s) = -\sum_{\pi} \mathcal{P}_k(\pi|y) \log \mathcal{P}_k(\pi|s) = \mathbb{E}_{\pi \sim \mathcal{P}_k(\cdot|y)} \left[ -\log \mathcal{P}_k(\pi|s) \right]. \]
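
In practice the expectation is approximated by sampling sample_size permutations from the teacher distribution and averaging the student's negative log-likelihood. The following is a minimal NumPy sketch of that estimate, reusing k_plackett_prob from above; the Gumbel-trick sampling and the function name are illustrative, and unlike the library this sketch ignores temperature and the penalty on non top-k items:

def coupled_rank_distil_loss(y_true, y_pred, topk, sample_size, seed=0):
  # Illustrative Monte Carlo estimate of
  # E_{pi ~ P_k(.|y_true)}[-log P_k(pi|y_pred)].
  rng = np.random.default_rng(seed)
  teacher = np.asarray(y_true, dtype=np.float64)
  nll = 0.0
  for _ in range(sample_size):
    # Gumbel trick: sorting teacher scores perturbed with Gumbel noise draws
    # a permutation from the teacher's Plackett-Luce distribution.
    perm = np.argsort(-(teacher + rng.gumbel(size=teacher.shape)))
    # Student negative log-likelihood of the sampled ordering.
    nll -= np.log(k_plackett_prob(y_pred, perm, topk))
  return nll / sample_size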

Args

reduction (Optional) The tf.keras.losses.Reduction to use (see tf.keras.losses.Loss).
name (Optional) The name for the op.
ragged (Optional) If True, this loss will accept ragged tensors. If False, this loss will accept dense tensors.
sample_size (Optional) Number of permutations to sample from the teacher scores. Defaults to 8.
topk (Optional) The number of top-scored entries over which the order is matched. A penalty is applied to items outside the top-k. Defaults to None, which treats all entries in the list as top-k.
temperature (Optional) A float used to scale the logits as logits = logits / temperature. Defaults to 1.

Methods

from_config

Instantiates a Loss from its config (output of get_config()).

Args
config Output of get_config().

Returns
A Loss instance.

get_config

Returns the config dictionary for a Loss instance.
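
For example, a config round trip that serializes a loss and re-creates it (a minimal sketch):

loss = tfr.keras.losses.CoupledRankDistilLoss(topk=2, sample_size=1)
config = loss.get_config()
restored_loss = tfr.keras.losses.CoupledRankDistilLoss.from_config(config)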

__call__

See tf.keras.losses.Loss.