tf.estimator.MultiClassHead

Creates a Head for multi-class classification.

Inherits From: Head

Uses sparse_softmax_cross_entropy loss.
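For a single example, sparse softmax cross-entropy is -log softmax(logits)[label]. The NumPy sketch below reproduces that math for illustration only. It is not the head's implementation, and the helper name `sparse_softmax_xent` is hypothetical. It takes labels with the trailing dimension of size 1 described below and returns the loss with that trailing dimension kept.

```python
import numpy as np

def sparse_softmax_xent(labels, logits):
    """Per-example sparse softmax cross-entropy (hypothetical helper).

    labels: integer class ids, shape [D0, ..., DN, 1]
    logits: float scores, shape [D0, ..., DN, n_classes]
    returns: unreduced loss, shape [D0, ..., DN, 1]
    """
    # Subtract the per-row max for numerical stability; softmax is unchanged.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Select the log-probability of the true class for each example.
    return -np.take_along_axis(log_probs, labels, axis=-1)

logits = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0]])
labels = np.array([[1], [1]], dtype=np.int64)
losses = sparse_softmax_xent(labels, logits)
print(losses.shape)  # (2, 1)
```

The first example puts almost all probability on class 0 while its label is 1, so its loss is about 10; the second example is predicted correctly, so its loss is close to 0.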

The head expects logits with shape [D0, D1, ... DN, n_classes]. In many applications, the shape is [batch_size, n_classes].
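Only the last axis of the logits is treated as the class axis; every leading dimension behaves like a batch dimension. The sketch below, assuming NumPy and a hypothetical helper `probabilities_and_class_ids`, applies softmax and argmax over that last axis for a rank-3 input. As far as we know this corresponds to the `probabilities` and `class_ids` outputs of the head's predictions, but the code is an illustration, not the head's implementation.

```python
import numpy as np

def probabilities_and_class_ids(logits):
    """Softmax over the class axis plus the index of the most likely class.

    logits: shape [D0, ..., DN, n_classes]
    returns: probabilities (same shape as logits) and
             class_ids (shape [D0, ..., DN, 1])
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    probabilities = exp / exp.sum(axis=-1, keepdims=True)
    class_ids = np.argmax(logits, axis=-1)[..., np.newaxis]
    return probabilities, class_ids

# Rank-3 logits: D0 = 2 sequences, D1 = 4 time steps, n_classes = 3.
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 4, 3))
probs, class_ids = probabilities_and_class_ids(logits)
print(probs.shape, class_ids.shape)  # (2, 4, 3) (2, 4, 1)
```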

labels must be a dense Tensor whose shape matches logits in every dimension except the last, namely [D0, D1, ... DN, 1]. If label_vocabulary is given, labels must be a string Tensor with values from the vocabulary. If label_vocabulary is not given, labels must be an integer Tensor with values specifying the class index.
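The two label conventions can be illustrated with a small NumPy sketch. The helper `labels_to_class_ids` is hypothetical: it maps string labels to their index in a vocabulary list, or checks that labels are already integer class ids when no vocabulary is given. It mimics the idea, not the head's actual lookup code.

```python
import numpy as np

def labels_to_class_ids(labels, label_vocabulary=None):
    """Convert labels of shape [D0, ..., DN, 1] to integer class ids (hypothetical helper)."""
    labels = np.asarray(labels)
    if labels.shape[-1:] != (1,):
        raise ValueError(f"labels must end in a dimension of size 1, got {labels.shape}")
    if label_vocabulary is None:
        # Without a vocabulary, labels must already be integer class indices.
        if not np.issubdtype(labels.dtype, np.integer):
            raise TypeError("labels must be integers when label_vocabulary is None")
        return labels.astype(np.int64)
    # With a vocabulary, each string is replaced by its position in the list.
    # In this sketch an out-of-vocabulary string raises KeyError.
    index = {name: i for i, name in enumerate(label_vocabulary)}
    return np.vectorize(index.__getitem__, otypes=[np.int64])(labels)

vocab = ['cat', 'dog', 'bird']
print(labels_to_class_ids(np.array([['dog'], ['bird']]), vocab).tolist())  # [[1], [2]]
print(labels_to_class_ids(np.array([[0], [2]])).tolist())                  # [[0], [2]]
```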

If weight_column is specified, weights must be of shape [D0, D1, ... DN], or [D0, D1, ... DN, 1].
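Because both weight shapes are accepted, a consumer typically brings them to one common shape before multiplying with the per-example loss. The NumPy sketch below uses a hypothetical helper `normalize_weights` that adds the trailing size-1 dimension when it is missing and rejects any other shape; it illustrates the shape rule, not the head's internal code.

```python
import numpy as np

def normalize_weights(weights, label_shape):
    """Return weights with shape [D0, ..., DN, 1] (hypothetical helper).

    weights: shape [D0, ..., DN] or [D0, ..., DN, 1]
    label_shape: the labels' shape, [D0, ..., DN, 1]
    """
    weights = np.asarray(weights, dtype=np.float64)
    label_shape = tuple(label_shape)
    if weights.shape == label_shape:
        return weights
    if weights.shape == label_shape[:-1]:
        # Add the trailing size-1 dimension so weights line up with the loss.
        return weights[..., np.newaxis]
    raise ValueError(f"weights shape {weights.shape} is incompatible with labels {label_shape}")

print(normalize_weights([1.0, 2.0], (2, 1)).shape)      # (2, 1)
print(normalize_weights([[1.0], [2.0]], (2, 1)).shape)  # (2, 1)
```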

The loss is the weighted sum over the input dimensions. For example, if the input labels have shape [batch_size, 1], the loss is the weighted sum over the batch_size dimension.
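How that weighted sum is normalized is governed by the head's loss_reduction argument. As far as we recall, its default is SUM_OVER_BATCH_SIZE, which divides the weighted sum by the number of loss elements; treat that as an assumption to check against the constructor reference. The NumPy sketch below shows both the plain weighted sum and that normalized variant. The helper `reduce_loss` and its `divide_by_count` flag are hypothetical.

```python
import numpy as np

def reduce_loss(unreduced_loss, weights, divide_by_count=True):
    """Weighted reduction of an unreduced loss of shape [D0, ..., DN, 1].

    weights must broadcast to unreduced_loss. With divide_by_count=True the
    weighted sum is divided by the number of loss elements (a
    SUM_OVER_BATCH_SIZE-style reduction); otherwise the plain sum is returned.
    """
    weighted = np.asarray(unreduced_loss) * np.asarray(weights)
    total = weighted.sum()
    return total / weighted.size if divide_by_count else total

per_example = np.array([[2.0], [1.0]])
weights = np.array([[1.0], [3.0]])
print(reduce_loss(per_example, weights, divide_by_count=False))  # 5.0
print(reduce_loss(per_example, weights))                         # 2.5
```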

Also supports a custom loss_fn. loss_fn takes (labels, logits) or (labels, logits, features, loss_reduction) as arguments and returns an unreduced loss with shape [D0, D1, ... DN, 1]. loss_fn must support integer labels with shape [D0, D1, ... DN, 1], because the head applies label_vocabulary to the input labels before passing them to loss_fn.
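The loss_fn contract is mostly about shapes: integer class ids of shape [D0, ..., DN, 1] and logits of shape [D0, ..., DN, n_classes] in, an unreduced loss of shape [D0, ..., DN, 1] out. The sketch below follows the two-argument (labels, logits) form with a Brier-style squared-error loss chosen purely as an example; `my_loss_fn` is a hypothetical name. It is written in NumPy only to make the contract concrete. A loss_fn actually passed to the head would need to be built from TensorFlow ops so that it works on tensors and gradients can flow.

```python
import numpy as np

def my_loss_fn(labels, logits):
    """Illustrative custom loss following the (labels, logits) contract.

    labels: integer class ids, shape [D0, ..., DN, 1] (vocabulary already applied)
    logits: shape [D0, ..., DN, n_classes]
    returns: unreduced loss, shape [D0, ..., DN, 1]
    """
    n_classes = logits.shape[-1]
    # One-hot encode the class ids: [D0, ..., DN, 1] -> [D0, ..., DN, n_classes].
    one_hot = np.eye(n_classes)[labels[..., 0]]
    # Softmax probabilities over the class axis (max-shifted for stability).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    # Squared error between probabilities and one-hot targets, summed over classes
    # but kept unreduced per example via a trailing axis of size 1.
    return ((probs - one_hot) ** 2).sum(axis=-1, keepdims=True)

logits = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0]])
labels = np.array([[1], [1]], dtype=np.int64)
out = my_loss_fn(labels, logits)
print(out.shape)  # (2, 1)
```

The head, not loss_fn, then applies weights and the configured reduction to this unreduced result.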

Usage:

import numpy as np
import tensorflow as tf

n_classes = 3
head = tf.estimator.MultiClassHead(n_classes)
logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)
labels = np.array(((1,), (1,)), dtype=np.int64)
features = {'x': np