


A wrapper to compute the pairwise distance between sources and targets.


distances = weights * distance_type(sources, targets)

This wrapper calculates the weighted distance between (sources, targets) pairs. When a vector-based distance is needed, it can instead return the distance as the sum of the element-wise differences along a given axis.
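The following is a minimal pure-Python sketch of this formula for 2-D inputs, with no reduction applied. It assumes that 'L2' means the element-wise squared difference (which is what the worked examples in this section compute) and that weights is a scalar. The helper names squared_diff and weighted_distances are hypothetical and are not part of the library.

```python
# Hypothetical sketch of: distances = weights * distance_type(sources, targets)
# Assumptions: 2-D inputs, scalar weights, 'L2' = element-wise squared difference.

def squared_diff(sources, targets):
    """Element-wise squared difference of two equal-shape 2-D lists."""
    return [[(a - b) ** 2 for a, b in zip(rs, rt)] for rs, rt in zip(sources, targets)]

def weighted_distances(sources, targets, weights=1.0, sum_over_last_axis=False):
    """Unreduced weighted distances.

    If sum_over_last_axis is True, each row's differences are summed first,
    giving one distance per (source, target) sample pair. The summed axis is
    kept with size 1 (a choice made for this sketch).
    """
    d = squared_diff(sources, targets)
    if sum_over_last_axis:
        d = [[sum(row)] for row in d]
    return [[weights * x for x in row] for row in d]

sources = [[1, 1], [2, 2], [0, 2], [5, 5]]
targets = [[1, 1], [0, 2], [4, 4], [1, 4]]
print(weighted_distances(sources, targets))
# [[0.0, 0.0], [4.0, 0.0], [16.0, 4.0], [16.0, 1.0]]
print(weighted_distances(sources, targets, sum_over_last_axis=True))
# [[0.0], [4.0], [20.0], [17.0]]
```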

For the usage of weights and reduction, refer to tf.losses. The following examples illustrate sum_over_axis:

Given target tensors of shape [batch_size, features], with reduction set to MEAN and sum_over_axis set to the last dimension, the weighted average distance over sample pairs is returned. For example, with distance_config('L2', sum_over_axis=-1), the distance between [[1, 1], [2, 2], [0, 2], [5, 5]] and [[1, 1], [0, 2], [4, 4], [1, 4]] is the squared differences summed per sample and then averaged over the 4 samples: {(0+0) + (4+0) + (16+4) + (16+1)}/4 = 10.25
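The arithmetic of this example can be checked with a few lines of plain Python. This sketch assumes unit weights and squared-difference 'L2'; it only reproduces the numbers and does not call the library.

```python
# Reproduce the sum_over_axis=-1, reduction=MEAN example (all weights = 1).
sources = [[1, 1], [2, 2], [0, 2], [5, 5]]
targets = [[1, 1], [0, 2], [4, 4], [1, 4]]

# Sum the squared differences along the last axis: one distance per sample pair.
per_sample = [sum((a - b) ** 2 for a, b in zip(s, t)) for s, t in zip(sources, targets)]
print(per_sample)  # [0, 4, 20, 17]

# MEAN with unit weights: divide by the number of sample pairs.
mean_distance = sum(per_sample) / len(per_sample)
print(mean_distance)  # 10.25
```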

If sum_over_axis is None, the weighted average distance over feature pairs (instead of sample pairs) is returned. For example, with distance_config('L2'), the distance between [[1, 1], [2, 2], [0, 2], [5, 5]] and [[1, 1], [0, 2], [4, 4], [1, 4]] is the squared differences averaged over all 8 feature pairs: {(0+0) + (4+0) + (16+4) + (16+1)}/8 = 5.125
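The same check for sum_over_axis=None: every feature pair contributes its own squared difference, so the average is taken over 8 values instead of 4. As before, this is a plain-Python sketch assuming unit weights and squared-difference 'L2', not a call into the library.

```python
# Reproduce the sum_over_axis=None, reduction=MEAN example (all weights = 1).
sources = [[1, 1], [2, 2], [0, 2], [5, 5]]
targets = [[1, 1], [0, 2], [4, 4], [1, 4]]

# No summation over an axis: one squared difference per feature pair.
per_feature = [(a - b) ** 2 for s, t in zip(sources, targets) for a, b in zip(s, t)]
print(per_feature)  # [0, 0, 4, 0, 16, 4, 16, 1]

# MEAN with unit weights: divide by the number of feature pairs.
mean_distance = sum(per_feature) / len(per_feature)
print(mean_distance)  # 5.125
```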

If transform_fn is not None, the transform function is applied to both sources and targets before the distance is computed. For example, distance_config('KL_DIVERGENCE', sum_over_axis=-1, transform_fn='SOFTMAX') treats sources and targets as logits and computes the KL divergence between the resulting probability distributions.
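A plain-Python sketch of the SOFTMAX-then-KL computation, applied row by row (the equivalent of sum_over_axis=-1) with no reduction. The helpers softmax, kl_divergence and softmax_kl are hypothetical. The sketch computes KL(p || q) with p taken from sources and q from targets; that direction is an assumption here, since this page does not state it.

```python
import math

def softmax(logits):
    """Softmax over a 1-D list; the max is subtracted for numerical stability."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def softmax_kl(source_logits, target_logits):
    """Per-row KL divergence after applying softmax to each row of both inputs."""
    return [kl_divergence(softmax(s), softmax(t))
            for s, t in zip(source_logits, target_logits)]

source_logits = [[0.0, 0.0], [1.0, 2.0]]
target_logits = [[0.0, math.log(3.0)], [2.0, 3.0]]
print(softmax_kl(source_logits, target_logits))
# Row 1: p = [0.5, 0.5], q = [0.25, 0.75], so KL = 0.5 * log(4/3), about 0.1438.
# Row 2: the logits differ only by a constant shift, so the distributions
#        match and KL is 0.
```

Because softmax is invariant to adding a constant to every logit, the second row shows that two logit vectors can differ and still have zero distance after the transform.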


Args:
  • sources: Tensor of type float32 or float64.
  • targets: Tensor of the same type and shape as sources.
  • weights: (optional) Tensor whose rank is either 0 or equal to the rank of targets, and which must be broadcastable to targets (i.e., each dimension must be either 1 or equal to the corresponding distance dimension).
  • distance_config: DistanceConfig containing the following configs (or hyper-parameters) for computing distances. If None, the default distance config is used.
      (a) 'distance_type': Type of distance function to apply.
      (b) 'reduction': Type of distance reduction. Refer to tf.losses.Reduction.
      (c) 'sum_over_axis': (optional) If set, the distance is summed over the differences along this axis. Note: if sum_over_axis is not None and weights has nonzero rank, the size of weights along sum_over_axis must be 1.
      (d) 'transform_fn': (optional) If set, both sources and targets are transformed before the distance is calculated. If set to 'SOFTMAX', the softmax is applied along the axis specified by 'sum_over_axis', or along axis -1 if that is not specified.
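To make the interaction between weights and reduction concrete, here is a sketch of a few tf.losses.Reduction modes applied to already-computed distances. It follows our reading of tf.compat.v1.losses.compute_weighted_loss (MEAN divides the weighted sum by the sum of the broadcast weights; SUM_BY_NONZERO_WEIGHTS divides it by the number of nonzero weights); consult the TensorFlow documentation for the authoritative definitions. Reductions are passed as plain strings, and broadcast_2d and reduce_weighted are hypothetical helpers. Per-sample weights of shape [batch_size, 1] illustrate the rule that weights must have size 1 along sum_over_axis.

```python
def broadcast_2d(weights, n_rows, n_cols):
    """Broadcast a scalar, or a 2-D list whose dims are each 1 or matching, to n_rows x n_cols."""
    if isinstance(weights, (int, float)):
        return [[float(weights)] * n_cols for _ in range(n_rows)]
    if len(weights) not in (1, n_rows) or len(weights[0]) not in (1, n_cols):
        raise ValueError("weights are not broadcastable to the distances")
    return [
        [float(weights[i if len(weights) > 1 else 0][j if len(weights[0]) > 1 else 0])
         for j in range(n_cols)]
        for i in range(n_rows)
    ]

def reduce_weighted(distances, weights=1.0, reduction="MEAN"):
    """Weight a 2-D list of distances, then reduce it (names mirror tf.losses.Reduction)."""
    n_rows, n_cols = len(distances), len(distances[0])
    w = broadcast_2d(weights, n_rows, n_cols)
    weighted = [[w[i][j] * distances[i][j] for j in range(n_cols)] for i in range(n_rows)]
    if reduction == "NONE":
        return weighted  # same shape as the distances
    total = sum(x for row in weighted for x in row)
    if reduction == "SUM":
        return total
    if reduction == "MEAN":
        denom = sum(x for row in w for x in row)
        return total / denom if denom != 0 else 0.0
    if reduction == "SUM_BY_NONZERO_WEIGHTS":
        nonzero = sum(1 for row in w for x in row if x != 0)
        return total / nonzero if nonzero else 0.0
    raise ValueError(f"unsupported reduction: {reduction}")

# Per-sample distances from the sum_over_axis=-1 example (summed axis kept with size 1).
per_sample = [[0], [4], [20], [17]]
print(reduce_weighted(per_sample))  # 10.25 (MEAN, unit weights)

weights = [[1], [1], [0], [2]]  # one weight per sample; size 1 along the summed axis
print(reduce_weighted(per_sample, weights, "SUM"))                     # 38.0
print(reduce_weighted(per_sample, weights, "MEAN"))                    # 9.5
print(reduce_weighted(per_sample, weights, "SUM_BY_NONZERO_WEIGHTS"))  # about 12.667
```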


Returns:
  The weighted distance as a scalar Tensor. If reduction is NONE, the result instead has the same shape as targets.


Raises:
  • ValueError: If the shape of targets doesn't match that of sources, or if the shape of weights is invalid.
  • TypeError: If the distance function gets an unexpected keyword argument.
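A hypothetical validate_shapes helper that mirrors the documented ValueError conditions, together with the rule that weights of nonzero rank must have size 1 along sum_over_axis. It works on plain shape tuples rather than Tensors and does not model the TypeError case; it illustrates the documented rules and is not the library's own check.

```python
def validate_shapes(source_shape, target_shape, weight_shape=(), sum_over_axis=None):
    """Raise ValueError for shape combinations the documentation describes as invalid.

    Shapes are tuples of ints; weight_shape == () stands for a scalar (rank-0) weight.
    """
    if tuple(source_shape) != tuple(target_shape):
        raise ValueError(f"shape mismatch: sources {source_shape} vs targets {target_shape}")
    if len(weight_shape) == 0:
        return  # a scalar weight always broadcasts
    if len(weight_shape) != len(target_shape):
        raise ValueError("weights must have rank 0 or the same rank as targets")
    for w_dim, t_dim in zip(weight_shape, target_shape):
        if w_dim not in (1, t_dim):
            raise ValueError(f"weights {weight_shape} not broadcastable to {target_shape}")
    if sum_over_axis is not None and weight_shape[sum_over_axis] != 1:
        raise ValueError("weights must have size 1 along sum_over_axis")

validate_shapes((4, 2), (4, 2), weight_shape=(4, 1), sum_over_axis=-1)  # passes
try:
    validate_shapes((4, 2), (4, 2), weight_shape=(4, 2), sum_over_axis=-1)
except ValueError as err:
    print("ValueError:", err)  # weights must have size 1 along sum_over_axis
```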