tfr.keras.losses.ClickEMLoss

Computes click EM loss between y_true and y_pred.

Implementation of click EM loss (Wang et al., 2018). This loss assumes that a click is generated by a factorized model \(P(\text{examination}) \cdot P(\text{relevance})\), where examination and relevance are latent variables determined by exam_logits and rel_logits respectively.
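
Written out, the factorized model can be sketched as follows. This assumes each logit is mapped to a probability by a sigmoid \(\sigma\), which the description above does not state explicitly; the symbols \(\theta\), \(\gamma\), \(C\), \(E\) and \(R\) are introduced here for illustration. The posteriors for an unclicked item follow from Bayes' rule, as in the EM formulation of Wang et al.:

\[
P(C=1) = \theta \cdot \gamma, \qquad \theta = \sigma(\text{exam\_logits}), \quad \gamma = \sigma(\text{rel\_logits})
\]
\[
P(E=1 \mid C=0) = \frac{\theta\,(1-\gamma)}{1-\theta\gamma}, \qquad P(R=1 \mid C=0) = \frac{(1-\theta)\,\gamma}{1-\theta\gamma}
\]

For a clicked item both posteriors equal 1, because under this model a click requires both examination and relevance.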

Standalone usage:

import tensorflow as tf
import tensorflow_ranking as tfr

y_true = [[1., 0.]]
y_pred = [[[0.6, 0.9], [0.8, 0.2]]]
loss = tfr.keras.losses.ClickEMLoss()
loss(y_true, y_pred).numpy()
# 1.1462884

# Using ragged tensors
y_true = tf.ragged.constant([[1., 0.], [0., 1., 0.]])
y_pred = tf.ragged.constant([[[0.6, 0.9], [0.8, 0.2]],
    [[0.5, 0.9], [0.8, 0.2], [0.4, 0.8]]])
loss = tfr.keras.losses.ClickEMLoss(ragged=True)
loss(y_true, y_pred).numpy()
# 1.0770882
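
The dense result above can be approximated without TensorFlow. The following is a NumPy sketch, not the library's implementation. It assumes sigmoid links, that the last axis of y_pred is ordered (exam_logits, rel_logits), Bayes-rule posteriors for unclicked items as the E-step, a sigmoid cross-entropy M-step, and a mean over all items. The names click_em_loss_sketch, sigmoid and bce are hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def bce(target, prob):
    # Binary cross-entropy between a soft target and a predicted probability.
    return -(target * np.log(prob) + (1.0 - target) * np.log(1.0 - prob))


def click_em_loss_sketch(clicks, logits, exam_loss_weight=1.0, rel_loss_weight=1.0):
    """Illustrative click EM loss; assumptions are listed in the text above."""
    clicks = np.asarray(clicks, dtype=np.float64)
    logits = np.asarray(logits, dtype=np.float64)
    # Assumed layout: last axis is (exam_logits, rel_logits).
    p_exam = sigmoid(logits[..., 0])
    p_rel = sigmoid(logits[..., 1])

    # E-step: posterior of each latent variable given the observed click.
    # Clicked items are both examined and relevant, so the posterior is 1.
    p_no_click = 1.0 - p_exam * p_rel
    exam_post = np.where(clicks > 0, 1.0, p_exam * (1.0 - p_rel) / p_no_click)
    rel_post = np.where(clicks > 0, 1.0, (1.0 - p_exam) * p_rel / p_no_click)

    # M-step: fit each probability to its posterior, treated as a fixed target.
    per_item = (exam_loss_weight * bce(exam_post, p_exam)
                + rel_loss_weight * bce(rel_post, p_rel))
    return per_item.mean()


loss_value = click_em_loss_sketch([[1., 0.]], [[[0.6, 0.9], [0.8, 0.2]]])
print(round(float(loss_value), 4))  # 1.1463
```

Under these assumptions the sketch agrees with the documented dense value 1.1462884 to about six decimal places. The ragged case is not covered here.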

Usage with the compile() API:

model.compile(optimizer='sgd', loss=tfr.keras.losses.ClickEMLoss())

Args
reduction (Optional) The tf.keras.losses.Reduction to use (see tf.keras.losses.Loss).
name (Optional) The name for the op.
exam_loss_weight (Optional) Weight of examination logits.
rel_loss_weight (Optional) Weight of relevance logits.
ragged (Optional) If True, this loss will accept ragged tensors. If False, this loss will accept dense tensors.

Methods

from_config

Instantiates a Loss from its config (output of get_config()).

Args
config Output of get_config().

Returns
A Loss instance.

get_config

Returns the config dictionary for a Loss instance.

__call__

See tf.keras.losses.Loss.