nsl.lib.pairwise_distance_wrapper

A wrapper to compute the pairwise distance between sources and targets.

distances = weights * distance_config.distance_type(sources, targets)

This wrapper computes the weighted distance between (sources, targets) pairs. When a vector-based distance is needed, it can optionally return the distance as the sum over the difference along a given axis.

For how weights and reduction are applied, refer to tf.losses. The following examples illustrate sum_over_axis:

Given target tensors with shape [batch_size, features], reduction set to tf.compat.v1.losses.Reduction.MEAN, and sum_over_axis set to the last dimension, the weighted average distance over sample pairs is returned. For example, with distance_config('L2', sum_over_axis=-1), the distance between [[1, 1], [2, 2], [0, 2], [5, 5]] and [[1, 1], [0, 2], [4, 4], [1, 4]] is ((0+0) + (4+0) + (16+4) + (16+1)) / 4 = 41 / 4 = 10.25. Each term is a squared element-wise difference, and the divisor is the number of samples.
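The arithmetic above can be checked with a plain-Python sketch. It assumes, as the numbers in the example imply, that 'L2' contributes the squared difference of each element pair and that all weights are 1. mean_distance_over_samples is a hypothetical helper that mirrors the described semantics; it is not part of NSL.

```python
from typing import List

Matrix = List[List[float]]


def mean_distance_over_samples(sources: Matrix, targets: Matrix) -> float:
    """Hypothetical sketch of the sum_over_axis=-1, MEAN case with unit weights.

    Assumes 'L2' means a squared element-wise difference.
    """
    if len(sources) != len(targets):
        raise ValueError("sources and targets must have the same shape")
    per_sample = []
    for src_row, tgt_row in zip(sources, targets):
        if len(src_row) != len(tgt_row):
            raise ValueError("sources and targets must have the same shape")
        # Sum the squared differences along the last (feature) axis.
        per_sample.append(sum((s - t) ** 2 for s, t in zip(src_row, tgt_row)))
    # MEAN with unit weights: divide by the number of samples.
    return sum(per_sample) / len(per_sample)


sources = [[1, 1], [2, 2], [0, 2], [5, 5]]
targets = [[1, 1], [0, 2], [4, 4], [1, 4]]
print(mean_distance_over_samples(sources, targets))  # 10.25
```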

If sum_over_axis is None, the weighted average distance over feature pairs (instead of sample pairs) is returned. For example, with distance_config('L2'), the distance between [[1, 1], [2, 2], [0, 2], [5, 5]] and [[1, 1], [0, 2], [4, 4], [1, 4]] is ((0+0) + (4+0) + (16+4) + (16+1)) / 8 = 41 / 8 = 5.125. The element-wise terms are the same as before, but they are now averaged over all 8 feature pairs.
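A plain-Python sketch of the sum_over_axis=None case follows, under the same assumptions as the worked example: unit weights and a squared-difference 'L2' term. mean_distance_over_features is a hypothetical helper; the only change from the per-sample case is that every element pair becomes its own term, so the divisor grows from batch_size to batch_size * features.

```python
from typing import List

Matrix = List[List[float]]


def mean_distance_over_features(sources: Matrix, targets: Matrix) -> float:
    """Hypothetical sketch of the sum_over_axis=None, MEAN case with unit weights.

    Each (source, target) element pair is a separate term, so the mean is
    taken over batch_size * features terms.
    """
    terms = [
        (s - t) ** 2
        for src_row, tgt_row in zip(sources, targets)
        for s, t in zip(src_row, tgt_row)
    ]
    return sum(terms) / len(terms)


sources = [[1, 1], [2, 2], [0, 2], [5, 5]]
targets = [[1, 1], [0, 2], [4, 4], [1, 4]]
print(mean_distance_over_features(sources, targets))  # 5.125
```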

If transform_fn is not None, the transform is applied to both sources and targets before the distance is computed. For example, distance_config('KL_DIVERGENCE', sum_over_axis=-1, transform_fn='SOFTMAX') treats sources and targets as logits and computes the KL divergence between the two resulting probability distributions.
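The softmax-then-KL pipeline can be sketched in plain Python. This is an illustration under stated assumptions, not NSL's implementation: it assumes the divergence is KL(p_sources || p_targets) (the direction is an assumption), applies softmax along the last axis, sums over that axis per sample, and uses MEAN reduction with unit weights. softmax and mean_kl_of_logits are hypothetical helpers.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mean_kl_of_logits(sources: Matrix, targets: Matrix) -> float:
    """Hypothetical sketch: softmax each row, KL per sample, MEAN over samples.

    Assumes KL(p || q) with p from sources and q from targets.
    """
    per_sample = []
    for src_row, tgt_row in zip(sources, targets):
        p = softmax(src_row)  # transform_fn='SOFTMAX' on the last axis
        q = softmax(tgt_row)
        # Sum over the last axis (sum_over_axis=-1).
        per_sample.append(sum(pi * math.log(pi / qi) for pi, qi in zip(p, q)))
    return sum(per_sample) / len(per_sample)


# The first pair has identical logits, so it contributes 0 to the mean.
print(mean_kl_of_logits([[2.0, 0.0], [1.0, 1.0]], [[2.0, 0.0], [0.0, 3.0]]))
```

Because softmax is invariant to adding a constant to every logit, shifting a row of logits leaves the distance unchanged.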

Args:
  sources: Tensor of type float32 or float64.
  targets: Tensor of the same type and shape as sources.
  weights: (optional) Tensor whose rank is either 0 or the same as that of targets, and which must be broadcastable to targets (i.e., each dimension must be either 1 or the same as the corresponding distance dimension).
  distance_config: An instance of nsl.configs.DistanceConfig containing the following configuration (or hyperparameters) for computing distances. If None, the default distance config is used.
    (a) distance_type: Type of distance function to apply.
    (b) reduction: Type of distance reduction. See tf.losses.Reduction.
    (c) sum_over_axis: (optional) If set, the distance is the sum over the difference along the specified axis. Note that if sum_over_axis is not None and the rank of weights is non-zero, the size of weights along sum_over_axis must be 1.
    (d) transform_fn: (optional) If set, both sources and targets are transformed before the distance is calculated. If set to 'SOFTMAX', the softmax is applied along the axis specified by sum_over_axis, or along axis -1 if no axis is specified.
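The rule in (c), that weights must have size 1 along sum_over_axis, can be illustrated with a plain-Python sketch of the sum_over_axis=-1 case using per-sample weights of shape [batch_size, 1]. It assumes a squared-difference 'L2' term and that MEAN divides the weighted sum by the sum of the weights, as tf.compat.v1.losses does. weighted_mean_over_samples is a hypothetical helper, not an NSL function, and it only handles this one weights shape.

```python
from typing import List

Matrix = List[List[float]]


def weighted_mean_over_samples(sources: Matrix, targets: Matrix, weights: Matrix) -> float:
    """Hypothetical sketch: sum_over_axis=-1, MEAN, weights of shape [batch_size, 1].

    Assumes MEAN = weighted sum / sum of weights (tf.compat.v1.losses style).
    """
    if len(weights) != len(sources) or any(len(w) != 1 for w in weights):
        # Along sum_over_axis, the weights must have size 1.
        raise ValueError("weights must have shape [batch_size, 1]")
    weighted_sum = 0.0
    weight_total = 0.0
    for src_row, tgt_row, (w,) in zip(sources, targets, weights):
        # Per-sample distance: sum of squared differences along the last axis.
        distance = sum((s - t) ** 2 for s, t in zip(src_row, tgt_row))
        weighted_sum += w * distance
        weight_total += w
    # Safe division: return 0 when all weights are zero.
    return weighted_sum / weight_total if weight_total else 0.0


sources = [[1, 1], [2, 2], [0, 2], [5, 5]]
targets = [[1, 1], [0, 2], [4, 4], [1, 4]]
weights = [[1.0], [0.0], [1.0], [2.0]]  # the second pair is ignored
print(weighted_mean_over_samples(sources, targets, weights))  # 13.5
```

Here the per-sample distances are 0, 4, 20 and 17, so the weighted sum is 0 + 0 + 20 + 34 = 54 and the sum of weights is 4.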

Returns:
  Weighted distance scalar Tensor. If reduction is tf.compat.v1.losses.Reduction.NONE, this has the same shape as targets.

Raises:
  ValueError: If the shape of targets doesn't match that of sources, or if the shape of weights is invalid.
  TypeError: If the distance function gets an unexpected keyword argument.