
tf.contrib.losses.metric_learning.cluster_loss


Computes the clustering loss.

tf.contrib.losses.metric_learning.cluster_loss(
    labels,
    embeddings,
    margin_multiplier,
    enable_pam_finetuning=True,
    margin_type='nmi',
    print_losses=False
)
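As we read the paper linked from this page, the loss is a structured hinge over candidate medoid sets S. F is the facility-location score of S, F̃ scores the ground-truth clustering y* using the best single medoid per class, g(S) assigns every point in the batch V to its nearest medoid, and |𝒴| is the number of classes in the batch. Identifying γ with margin_multiplier and Δ with 1 minus the score selected by margin_type is our interpretation of the argument descriptions below, not something this page states explicitly.

```latex
F(X, S; \theta) = -\sum_{i=1}^{|X|} \min_{j \in S} \bigl\lVert f(X_i; \theta) - f(X_j; \theta) \bigr\rVert

\tilde{F}(X, y^*; \theta) = \sum_{k=1}^{|\mathcal{Y}|} \max_{j \in \{i \,:\, y^*[i] = k\}} F\bigl(X_{\{i \,:\, y^*[i] = k\}}, \{j\}; \theta\bigr)

\ell(X, y^*) = \Bigl[\, \max_{S \subset V,\ |S| = |\mathcal{Y}|} \bigl\{ F(X, S; \theta) + \gamma\, \Delta\bigl(g(S), y^*\bigr) \bigr\} - \tilde{F}(X, y^*; \theta) \Bigr]_+
```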

The following structured margins are supported:

  • nmi: normalized mutual information
  • ami: adjusted mutual information
  • ari: adjusted Rand index
  • vmeasure: V-measure
  • const: indicator checking whether the two clusterings are the same
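The scores behind these margins can be sketched in plain Python. This is a minimal sketch, assuming the structured margin is 1 minus the score (so a clustering identical to the ground truth costs nothing); it is not the library's internal code, which relies on sklearn. The helper names (nmi_score, ari_score, same_partition, structured_margin) are hypothetical, NMI uses arithmetic-mean normalization (an assumption), "const" is interpreted as "same partition up to renaming cluster ids" (also an assumption), and "ami" is omitted because its expected-mutual-information correction is long (sklearn.metrics.adjusted_mutual_info_score computes it).

```python
import math
from collections import Counter

# Hypothetical helpers; labelings are equal-length sequences of hashable cluster ids.


def _entropy(labels):
    """Shannon entropy (natural log) of a labeling."""
    n = len(labels)
    return -sum((c / n) * math.log(c / n) for c in Counter(labels).values())


def _mutual_info(a, b):
    """Mutual information between two labelings of the same points."""
    n = len(a)
    count_a, count_b = Counter(a), Counter(b)
    return sum(
        (n_ij / n) * math.log(n * n_ij / (count_a[i] * count_b[j]))
        for (i, j), n_ij in Counter(zip(a, b)).items()
    )


def nmi_score(a, b):
    """NMI with arithmetic-mean normalization: I(A; B) / ((H(A) + H(B)) / 2)."""
    h_a, h_b = _entropy(a), _entropy(b)
    if h_a + h_b == 0.0:
        return 1.0  # both labelings put every point in a single cluster
    return _mutual_info(a, b) / ((h_a + h_b) / 2)


def _pair_count(counts):
    """Number of unordered point pairs inside groups of the given sizes."""
    return sum(math.comb(c, 2) for c in counts)


def ari_score(a, b):
    """Adjusted Rand index from pair counts (assumes at least 2 points)."""
    index = _pair_count(Counter(zip(a, b)).values())
    sum_a = _pair_count(Counter(a).values())
    sum_b = _pair_count(Counter(b).values())
    expected = sum_a * sum_b / math.comb(len(a), 2)
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:
        return 1.0  # degenerate case, e.g. both labelings trivial
    return (index - expected) / (max_index - expected)


def same_partition(a, b):
    """True if a and b group the points identically, up to renaming cluster ids."""
    pairs = set(zip(a, b))
    return len(pairs) == len(set(a)) == len(set(b))


SCORES = {
    "nmi": nmi_score,
    "ari": ari_score,
    "vmeasure": nmi_score,  # V-measure (beta=1) equals arithmetic-mean NMI
    "const": lambda a, b: 1.0 if same_partition(a, b) else 0.0,
}


def structured_margin(predicted, labels, margin_type="nmi"):
    """Assumed margin: 1 - score, so a perfect clustering costs nothing."""
    return 1.0 - SCORES[margin_type](predicted, labels)
```

With this sketch, renaming the clusters ([1, 1, 0, 0] against [0, 0, 1, 1]) leaves every margin at 0, while [0, 1, 0, 1], which cuts across both classes, gets an NMI margin of 1. Because ARI can be negative, the ARI margin for that case is 1.5, i.e. it can exceed 1.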

Args:

  • labels: 2-D Tensor of labels of shape [batch size, 1].
  • embeddings: 2-D Tensor of embeddings of shape [batch size, embedding dimension]. Embeddings should be L2-normalized (unit Euclidean norm per row).
  • margin_multiplier: float32 scalar. Multiplier on the structured margin term. See section 3.2 of the paper for discussion.
  • enable_pam_finetuning: Boolean. Whether to run local PAM (partitioning around medoids) refinement. See section 3.4 of the paper for discussion.
  • margin_type: Type of structured margin to use. One of 'nmi', 'ami', 'ari', 'vmeasure', or 'const'. See section 3.2 of the paper for discussion.
  • print_losses: Boolean. Whether to print the loss.
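The shape and normalization requirements on labels and embeddings can be met before the data reaches the graph. This is a minimal plain-Python sketch with hypothetical helper names (l2_normalize_rows, as_label_column); inside a graph, tf.nn.l2_normalize along the embedding axis would be the usual in-graph way to normalize. The commented call at the end only uses the arguments documented on this page and is illustrative, not executed here.

```python
import math


def l2_normalize_rows(embeddings, eps=1e-12):
    """Scale each embedding row to unit L2 norm (eps guards all-zero rows)."""
    normalized = []
    for row in embeddings:
        norm = math.sqrt(sum(x * x for x in row))
        normalized.append([x / max(norm, eps) for x in row])
    return normalized


def as_label_column(labels):
    """Reshape a flat list of class ids into the documented [batch size, 1] layout."""
    return [[y] for y in labels]


embeddings = l2_normalize_rows([[3.0, 4.0], [0.0, 2.0], [1.0, 1.0]])
labels = as_label_column([0, 1, 1])

# In a TF 1.x graph these could then be passed to the documented call, e.g.:
# loss = tf.contrib.losses.metric_learning.cluster_loss(
#     labels=tf.constant(labels),
#     embeddings=tf.constant(embeddings),
#     margin_multiplier=1.0)
```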

Paper: https://arxiv.org/abs/1612.01213
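For a handful of points, the paper's objective can be evaluated exactly by enumerating every candidate medoid set, which makes the role of margin_multiplier (gamma below) easy to see. This is a hedged reference sketch for illustration only: exhaustive search is exponential in the batch size, distances are assumed to be plain Euclidean, and all names are hypothetical. The default const_margin treats "const" as "0 if the partitions match, else 1" (an assumption); any function with the same (predicted, labels) signature can be passed as margin instead.

```python
import itertools
import math


def facility_location(points, medoids):
    """F(X, S) = -sum_i min_{j in S} ||x_i - x_j||, with medoids given as indices."""
    return -sum(min(math.dist(p, points[j]) for j in medoids) for p in points)


def assign(points, medoids):
    """g(S): index (into medoids) of the nearest medoid for each point."""
    return [min(range(len(medoids)), key=lambda k: math.dist(p, points[medoids[k]]))
            for p in points]


def const_margin(predicted, labels):
    """Hypothetical 'const'-style margin: 0 if the partitions match, else 1."""
    pairs = set(zip(predicted, labels))
    same = len(pairs) == len(set(predicted)) == len(set(labels))
    return 0.0 if same else 1.0


def gt_score(points, labels):
    """F~(X, y*): best single-medoid facility location within each true class."""
    total = 0.0
    for k in set(labels):
        members = [p for p, y in zip(points, labels) if y == k]
        total += max(facility_location(members, [j]) for j in range(len(members)))
    return total


def brute_force_cluster_loss(points, labels, gamma, margin=const_margin):
    """[max_S {F(X, S) + gamma * margin(g(S), y*)} - F~(X, y*)]_+ by exhaustive search."""
    n_clusters = len(set(labels))
    augmented = max(
        facility_location(points, s) + gamma * margin(assign(points, s), labels)
        for s in itertools.combinations(range(len(points)), n_clusters)
    )
    return max(0.0, augmented - gt_score(points, labels))
```

For two tight classes at [1, 0] and [0, 1] (two points each), this sketch returns 0 with gamma=1.0 but 3 - 2√2 ≈ 0.17 with gamma=3.0: once the margin is large enough, a wrong clustering that lumps all four points together outscores the correct one, so the hinge becomes active.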

Returns:

  • clustering_loss: A float32 scalar Tensor.

Raises:

  • ImportError: If sklearn dependency is not installed.