
tf.estimator.MultiLabelHead

Creates a Head for multi-label classification.

Inherits From: Head

Multi-label classification handles the case where each example may have zero or more associated labels from a discrete set. This is distinct from MultiClassHead, which has exactly one label per example.

Uses sigmoid_cross_entropy loss, averaged over classes and combined as a weighted sum over the batch. That is, if the input logits have shape [batch_size, n_classes], the loss is averaged over n_classes and then summed, with per-example weights, over batch_size.
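To make the reduction order concrete, here is a minimal NumPy sketch, assuming the numerically stable element-wise form of sigmoid cross-entropy, max(x, 0) - x * z + log(1 + exp(-|x|)). The helper names `sigmoid_cross_entropy` and `multi_label_loss` are hypothetical and not part of tf.estimator. It averages over the last (class) axis and then takes a weighted sum over the remaining axes, using the logits and labels from the usage example.

```python
import numpy as np


def sigmoid_cross_entropy(labels, logits):
    """Element-wise sigmoid cross-entropy in its numerically stable form:
    max(x, 0) - x * z + log(1 + exp(-|x|))."""
    z = np.asarray(labels, dtype=np.float64)
    x = np.asarray(logits, dtype=np.float64)
    return np.maximum(x, 0.0) - x * z + np.log1p(np.exp(-np.abs(x)))


def multi_label_loss(labels, logits, weights=None):
    """Average over the class axis, then weighted sum over all leading axes."""
    # Shape [D0, ..., DN, n_classes] -> [D0, ..., DN].
    per_example = sigmoid_cross_entropy(labels, logits).mean(axis=-1)
    if weights is None:
        weights = np.ones_like(per_example)
    return float(np.sum(np.asarray(weights, dtype=np.float64) * per_example))


logits = np.array([[-1.0, 1.0], [-1.5, 1.5]])
labels = np.array([[1, 0], [1, 1]])

per_example = sigmoid_cross_entropy(labels, logits).mean(axis=-1)
total = multi_label_loss(labels, logits)
print(per_example)          # approximately [1.3133, 0.9514]
print(total / len(labels))  # approximately 1.13
```

The final division by the batch size mirrors the expected_loss comment in the usage example; how the head actually reduces the summed loss is governed by its loss_reduction setting.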

The head expects logits with shape [D0, D1, ... DN, n_classes]. In many applications, the shape is [batch_size, n_classes].

Labels can be:

  • A multi-hot tensor of shape [D0, D1, ... DN, n_classes]
  • An integer SparseTensor of class indices. The dense_shape must be [D0, D1, ... DN, ?] and the values within [0, n_classes).
  • If label_vocabulary is given, either a string SparseTensor, whose dense_shape must be [D0, D1, ... DN, ?] and whose values must be in label_vocabulary, or a multi-hot tensor of shape [D0, D1, ... DN, n_classes].
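The sparse label formats carry the same information as a multi-hot tensor. The NumPy sketch below illustrates that correspondence for the common [batch_size, ?] to [batch_size, n_classes] case; `indices_to_multi_hot` and `strings_to_multi_hot` are hypothetical helpers for illustration, not the head's internal implementation, and plain Python lists stand in for SparseTensor rows.

```python
import numpy as np


def indices_to_multi_hot(indices_per_example, n_classes):
    """Turn ragged per-example lists of class ids into a [batch, n_classes] multi-hot array."""
    multi_hot = np.zeros((len(indices_per_example), n_classes), dtype=np.int64)
    for row, indices in enumerate(indices_per_example):
        for idx in indices:
            # Integer labels must lie in [0, n_classes).
            if not 0 <= idx < n_classes:
                raise ValueError(f"class index {idx} outside [0, {n_classes})")
            multi_hot[row, idx] = 1
    return multi_hot


def strings_to_multi_hot(labels_per_example, label_vocabulary):
    """Map string labels through label_vocabulary, then build the multi-hot array."""
    lookup = {label: i for i, label in enumerate(label_vocabulary)}
    indices_per_example = []
    for labels in labels_per_example:
        unknown = [s for s in labels if s not in lookup]
        if unknown:
            raise ValueError(f"labels not in label_vocabulary: {unknown}")
        indices_per_example.append([lookup[s] for s in labels])
    return indices_to_multi_hot(indices_per_example, len(label_vocabulary))


# The third example has zero labels, which multi-label classification allows.
print(indices_to_multi_hot([[0], [0, 1], []], n_classes=2))
print(strings_to_multi_hot([["cat"], ["cat", "dog"], []], ["cat", "dog"]))
```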

If weight_column is specified, weights must be of shape [D0, D1, ... DN], or [D0, D1, ... DN, 1].
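Both accepted weight shapes describe one weight per example. The NumPy sketch below, using a hypothetical helper `weighted_batch_sum`, shows one way to treat them uniformly: a [D0, ..., DN] weight array gets a trailing axis so that it lines up with an unreduced loss of shape [D0, ..., DN, 1]. The per-example loss values are taken from the expected_loss comment in the usage example.

```python
import numpy as np


def weighted_batch_sum(per_example_loss, weights):
    """Weighted sum of per-example losses.

    per_example_loss: shape [D0, ..., DN, 1].
    weights: shape [D0, ..., DN] or [D0, ..., DN, 1].
    """
    per_example_loss = np.asarray(per_example_loss, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)
    if weights.ndim == per_example_loss.ndim - 1:
        # [D0, ..., DN] -> [D0, ..., DN, 1] so it aligns with the loss.
        weights = weights[..., np.newaxis]
    if weights.shape != per_example_loss.shape:
        raise ValueError(
            f"weights shape {weights.shape} incompatible with loss shape {per_example_loss.shape}"
        )
    return float(np.sum(weights * per_example_loss))


loss = np.array([[1.31326169], [0.9514133]])
flat = weighted_batch_sum(loss, [1.0, 2.0])         # weights of shape [batch_size]
column = weighted_batch_sum(loss, [[1.0], [2.0]])   # weights of shape [batch_size, 1]
print(flat, column)  # both equal 1.31326169 + 2 * 0.9514133
```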

Also supports a custom loss_fn, which takes (labels, logits) or (labels, logits, features) as arguments and returns the unreduced loss with shape [D0, D1, ... DN, 1]. loss_fn must support indicator (multi-hot) labels with shape [D0, D1, ... DN, n_classes], because the head applies label_vocabulary to the input labels before passing them to loss_fn.
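The loss_fn contract can be illustrated with a NumPy sketch. A real loss_fn passed as MultiLabelHead(n_classes, loss_fn=...) receives tensors, so it would need TensorFlow ops (for example tf.nn.sigmoid_cross_entropy_with_logits followed by a mean with keepdims=True); the version below only demonstrates the expected inputs and output shape. The positive-label weighting and the `POS_WEIGHT` constant are hypothetical choices made for illustration, and setting it to 1.0 recovers plain sigmoid cross-entropy averaged over classes.

```python
import numpy as np

POS_WEIGHT = 2.0  # hypothetical: extra weight on positive labels


def weighted_sigmoid_loss_fn(labels, logits, features=None):
    """Custom loss following the documented contract (NumPy sketch).

    labels: multi-hot indicator array, shape [D0, ..., DN, n_classes].
    logits: float array with the same shape as labels.
    features: unused; accepted to show the three-argument form.
    Returns the unreduced loss with shape [D0, ..., DN, 1].
    """
    z = np.asarray(labels, dtype=np.float64)
    x = np.asarray(logits, dtype=np.float64)
    # Stable log(sigmoid(x)) and log(1 - sigmoid(x)).
    log_p = -np.logaddexp(0.0, -x)
    log_not_p = -np.logaddexp(0.0, x)
    per_class = -(POS_WEIGHT * z * log_p + (1.0 - z) * log_not_p)
    # Average over classes, keeping a trailing axis of size 1.
    return per_class.mean(axis=-1, keepdims=True)


logits = np.array([[-1.0, 1.0], [-1.5, 1.5]])
labels = np.array([[1, 0], [1, 1]])
out = weighted_sigmoid_loss_fn(labels, logits)
print(out.shape)  # (2, 1)
```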

Usage:

import numpy as np
import tensorflow as tf

n_classes = 2
head = tf.estimator.MultiLabelHead(n_classes)
logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
features = {'x': np.array([[41], [42]], dtype=np.int32)}
# expected_loss = sum(_sigmoid_cross_entropy(labels, logits)) / batch_size
#               = sum(1.31326169, 0.9514133) / 2 = 1.13
loss = head.loss(labels, logits, features=features)
print('{:.2f}'.format(loss.numpy()))
# prints 1.13
eval_metrics = head.metrics()
updated_metrics = head.update_metrics(
  eval_metrics, features, logits, labels)
for k in sorted(updated_metrics):