tfa.losses.GIoULoss

Class GIoULoss

Implements the GIoU loss function.

GIoU loss was first introduced in the paper "Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression". GIoU extends IoU for object detection models: unlike IoU, it still distinguishes between predictions when the predicted and ground-truth boxes do not overlap.

Usage:

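import tensorflow as tf
import tensorflow_addons as tfa

gl = tfa.losses.GIoULoss()
boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]])
loss = gl(boxes1, boxes2)
# The per-box losses are about [1.075, 1.9333]. With the default reduction the
# call returns their mean (about 1.5042); pass reduction=NONE to get both values.
print('Loss: ', loss.numpy())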

Usage with tf.keras API:

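inputs = tf.keras.Input(shape=(16,))  # placeholder feature size, for illustration only
outputs = tf.keras.layers.Dense(4)(inputs)  # one box (4 coordinates) per sample
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', loss=tfa.losses.GIoULoss())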

Args:

  • mode: one of ['giou', 'iou']; selects whether the GIoU loss or the IoU loss is computed.
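
To make the two modes concrete, here is a minimal pure-Python sketch of the computation for a single pair of boxes. It is not the library's implementation: the helper names box_area and giou_loss_sketch are hypothetical, and the boxes are assumed to be in corner format (y_min, x_min, y_max, x_max). It is applied to the two box pairs from the usage example above.

```python
def box_area(box):
    """Area of a corner-format box; zero if the box is empty or inverted."""
    y_min, x_min, y_max, x_max = box
    return max(y_max - y_min, 0.0) * max(x_max - x_min, 0.0)


def giou_loss_sketch(box_a, box_b, mode="giou"):
    """Return 1 - IoU or 1 - GIoU for two boxes (illustrative helper only)."""
    if mode not in ("giou", "iou"):
        raise ValueError("mode must be 'giou' or 'iou'")

    # Intersection rectangle; its area is zero when the boxes do not overlap.
    inter = box_area((
        max(box_a[0], box_b[0]), max(box_a[1], box_b[1]),
        min(box_a[2], box_b[2]), min(box_a[3], box_b[3]),
    ))
    union = box_area(box_a) + box_area(box_b) - inter
    iou = inter / union if union > 0 else 0.0
    if mode == "iou":
        return 1.0 - iou

    # Smallest axis-aligned box that encloses both inputs.
    enclose = box_area((
        min(box_a[0], box_b[0]), min(box_a[1], box_b[1]),
        max(box_a[2], box_b[2]), max(box_a[3], box_b[3]),
    ))
    giou = iou
    if enclose > 0:
        # Penalize the part of the enclosing box not covered by the union.
        giou = iou - (enclose - union) / enclose
    return 1.0 - giou


pairs = [
    ((4.0, 3.0, 7.0, 5.0), (3.0, 4.0, 6.0, 8.0)),
    ((5.0, 6.0, 10.0, 7.0), (14.0, 14.0, 15.0, 15.0)),
]
giou_losses = [giou_loss_sketch(a, b) for a, b in pairs]
iou_losses = [giou_loss_sketch(a, b, mode="iou") for a, b in pairs]
print(giou_losses)  # approximately [1.075, 1.9333]
print(iou_losses)   # [0.875, 1.0]
```

For the second pair the boxes do not overlap, so the IoU loss saturates at 1.0 regardless of how far apart they are, while the GIoU loss continues to grow toward 2.0 with the distance between them.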

__init__

__init__(
    mode='giou',
    reduction=tf.keras.losses.Reduction.AUTO,
    name='giou_loss'
)

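Initializes the loss. mode selects 'giou' or 'iou', reduction is the tf.keras.losses.Reduction applied to the per-box losses, and name is an optional name for the loss instance.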

Methods

__call__

__call__(
    y_true,
    y_pred,
    sample_weight=None
)

Invokes the Loss instance.

Args:

  • y_true: Ground truth values. shape = [batch_size, d0, .. dN]
  • y_pred: The predicted values. shape = [batch_size, d0, .. dN]
  • sample_weight: Optional sample_weight acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If sample_weight is a tensor of size [batch_size], then the total loss for each sample of the batch is rescaled by the corresponding element in the sample_weight vector. If the shape of sample_weight is [batch_size, d0, .. dN-1] (or can be broadcast to this shape), then each loss element of y_pred is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss functions reduce by 1 dimension, usually axis=-1.)

Returns:

Weighted loss float Tensor. If reduction is NONE, this has shape [batch_size, d0, .. dN-1]; otherwise, it is scalar. (Note dN-1 because all loss functions reduce by 1 dimension, usually axis=-1.)

Raises:

  • ValueError: If the shape of sample_weight is invalid.
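
The interaction between sample_weight and reduction described above can be sketched in plain Python for the common case of a 1-D batch of per-box losses. The helper reduce_losses is hypothetical and only mirrors the documented behavior; the reduction strings are assumed to correspond to Reduction.NONE, Reduction.SUM, and Reduction.SUM_OVER_BATCH_SIZE (which AUTO resolves to in most contexts).

```python
def reduce_losses(per_sample, sample_weight=None, reduction="sum_over_batch_size"):
    """Weight per-sample losses, then reduce them (hypothetical helper)."""
    n = len(per_sample)
    if sample_weight is None:
        weights = [1.0] * n
    elif isinstance(sample_weight, (int, float)):
        # A scalar weight scales every loss by the same factor.
        weights = [float(sample_weight)] * n
    else:
        # A [batch_size] vector rescales each sample's loss individually.
        weights = list(sample_weight)
        if len(weights) != n:
            raise ValueError("sample_weight must be a scalar or have length batch_size")

    weighted = [w * loss for w, loss in zip(weights, per_sample)]
    if reduction == "none":
        return weighted          # shape [batch_size]
    if reduction == "sum":
        return sum(weighted)     # scalar
    if reduction == "sum_over_batch_size":
        # Divides by the number of samples, not by the sum of the weights.
        return sum(weighted) / n
    raise ValueError("unknown reduction: " + repr(reduction))


per_box = [1.075, 1.9333333]  # approximate per-box GIoU losses from the usage example
mean_loss = reduce_losses(per_box)                                # about 1.5042
masked = reduce_losses(per_box, sample_weight=[1.0, 0.0])         # about 0.5375
scaled = reduce_losses(per_box, sample_weight=2.0, reduction="none")
print(mean_loss, masked, scaled)
```

Note that a zero weight removes a sample's contribution but the sample still counts toward the batch size in the denominator, which is why the masked result is 1.075 / 2 rather than 1.075.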

from_config

@classmethod
from_config(
    cls,
    config
)

Instantiates a Loss from its config (output of get_config()).

Args:

  • config: Output of get_config().

Returns:

A Loss instance.

get_config

get_config()