

Creates a Head for single label binary classification.

Inherits From: Head

Uses sigmoid_cross_entropy_with_logits loss.
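As an illustration of what this loss computes per element, here is a minimal pure-Python sketch. It uses the numerically stable rewrite max(x, 0) - x*z + log(1 + exp(-|x|)) of the sigmoid cross-entropy -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x)), with logit x and label z. The function name mirrors the TensorFlow op only for readability; this is a scalar sketch, not the TensorFlow implementation.

```python
import math


def sigmoid_cross_entropy_with_logits(label, logit):
    """Numerically stable sigmoid cross-entropy for a single element.

    Equivalent to -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x)) with
    z = label and x = logit, rearranged so that large |x| cannot overflow.
    """
    x, z = float(logit), float(label)
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


# The two (label, logit) pairs used in the example at the end of this page.
print(round(sigmoid_cross_entropy_with_logits(1, 45), 6))   # 0.0
print(round(sigmoid_cross_entropy_with_logits(1, -41), 6))  # 41.0
```

A confidently correct logit (45 for label 1) costs almost nothing, while a confidently wrong one (-41) costs roughly its magnitude.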

The head expects logits with shape [D0, D1, ... DN, 1]. In many applications, the shape is [batch_size, 1].

labels must be a dense Tensor with shape matching logits, namely [D0, D1, ... DN, 1]. If label_vocabulary is given, labels must be a string Tensor with values from the vocabulary. If label_vocabulary is not given, labels must be a float Tensor with values in the interval [0, 1].
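The two label conventions can be sketched in plain Python. labels_to_float is a hypothetical helper working on a flat list (the trailing dimension of size 1 is dropped). It assumes string labels are mapped to their index in a two-entry label_vocabulary, so vocabulary[0] becomes 0.0 and vocabulary[1] becomes 1.0. Numeric labels are only range-checked.

```python
def labels_to_float(labels, label_vocabulary=None):
    """Convert a flat list of labels to floats in [0, 1] (hypothetical helper).

    With a vocabulary, each string label becomes its index in the vocabulary.
    Without one, labels must already be numbers in the interval [0, 1].
    """
    if label_vocabulary is not None:
        # Assumption: binary classification uses exactly two class names.
        if len(label_vocabulary) != 2:
            raise ValueError('binary classification needs exactly 2 classes')
        index = {name: float(i) for i, name in enumerate(label_vocabulary)}
        try:
            return [index[label] for label in labels]
        except KeyError as err:
            raise ValueError('label not in vocabulary: {}'.format(err)) from err
    result = [float(label) for label in labels]
    if any(not 0.0 <= value <= 1.0 for value in result):
        raise ValueError('labels must be in the interval [0, 1]')
    return result


print(labels_to_float(['no', 'yes', 'yes'], ['no', 'yes']))  # [0.0, 1.0, 1.0]
print(labels_to_float([1, 0, 0.25]))                          # [1.0, 0.0, 0.25]
```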

If weight_column is specified, weights must be of shape [D0, D1, ... DN], or [D0, D1, ... DN, 1].

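The loss is the weighted sum over the input dimensions, reduced according to loss_reduction. With the default reduction, SUM_OVER_BATCH_SIZE, the weighted sum is divided by the number of elements. Namely, if the input labels have shape [batch_size, 1], the loss is the weighted sum over batch_size, divided by batch_size.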
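The following pure-Python sketch shows how per-example losses, optional weights, and the reduction combine for inputs of shape [batch_size, 1]. The reduction names 'sum' and 'sum_over_batch_size' are hypothetical stand-ins for tf.losses.Reduction.SUM and SUM_OVER_BATCH_SIZE. Weights may be given as [batch_size] or [batch_size, 1], matching the shapes allowed for weight_column. It reproduces the 20.50 loss from the example at the end of this page.

```python
import math


def per_example_loss(label, logit):
    # Stable sigmoid cross-entropy: max(x, 0) - x*z + log(1 + exp(-|x|)).
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def reduce_loss(labels, logits, weights=None, reduction='sum_over_batch_size'):
    """Reduce per-example losses for labels and logits of shape [batch_size, 1].

    weights may be shaped [batch_size] or [batch_size, 1]; both are
    flattened to one weight per example.
    """
    losses = [per_example_loss(float(y[0]), float(x[0]))
              for y, x in zip(labels, logits)]
    if weights is None:
        weights = [1.0] * len(losses)
    flat_w = [float(w[0]) if isinstance(w, (list, tuple)) else float(w)
              for w in weights]
    if len(flat_w) != len(losses):
        raise ValueError('weights must have one entry per example')
    weighted_sum = sum(w * loss for w, loss in zip(flat_w, losses))
    if reduction == 'sum':
        return weighted_sum
    if reduction == 'sum_over_batch_size':
        # Divide by the number of elements, not by the sum of the weights.
        return weighted_sum / len(losses)
    raise ValueError('unknown reduction: {}'.format(reduction))


labels = [[1], [1]]
logits = [[45.0], [-41.0]]
print(round(reduce_loss(labels, logits), 2))                          # 20.5
print(round(reduce_loss(labels, logits, reduction='sum'), 2))         # 41.0
print(round(reduce_loss(labels, logits, weights=[[1.0], [0.5]]), 2))  # 10.25
```

Down-weighting the misclassified second example to 0.5 halves its contribution, so the averaged loss drops from 20.5 to 10.25.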

The head also supports a custom loss_fn. loss_fn takes (labels, logits) or (labels, logits, features, loss_reduction) as arguments and returns the unreduced loss with shape [D0, D1, ... DN, 1]. Because the head applies label_vocabulary to the input labels before passing them to loss_fn, loss_fn must support float labels with shape [D0, D1, ... DN, 1].
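To make the loss_fn contract concrete, here is a toy, framework-free sketch. my_loss_fn (a squared error on sigmoid probabilities) and call_loss_fn are hypothetical. The dispatch, which passes features and loss_reduction only when the function declares them, is an assumption loosely modeled on the two signatures described above, not TensorFlow's code. Note that string labels are converted to floats before loss_fn ever sees them.

```python
import inspect
import math


def my_loss_fn(labels, logits):
    """Hypothetical custom loss: squared error on sigmoid probabilities.

    Receives float labels shaped like logits ([batch_size, 1] here) and
    returns an unreduced loss of the same shape.
    """
    return [[(1.0 / (1.0 + math.exp(-x[0])) - y[0]) ** 2]
            for y, x in zip(labels, logits)]


def call_loss_fn(loss_fn, labels, logits, features, loss_reduction):
    """Toy dispatcher: pass features/loss_reduction only if loss_fn accepts them."""
    params = inspect.signature(loss_fn).parameters
    kwargs = {}
    if 'features' in params:
        kwargs['features'] = features
    if 'loss_reduction' in params:
        kwargs['loss_reduction'] = loss_reduction
    return loss_fn(labels=labels, logits=logits, **kwargs)


# String labels are mapped to floats via the vocabulary before calling loss_fn.
vocabulary = ['neg', 'pos']
raw_labels = [['pos'], ['neg']]
float_labels = [[float(vocabulary.index(y[0]))] for y in raw_labels]
loss = call_loss_fn(my_loss_fn, float_labels, [[0.0], [0.0]],
                    features={}, loss_reduction='sum')
print(loss)  # [[0.25], [0.25]]
```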


```python
import numpy as np
import tensorflow as tf

head = tf.estimator.BinaryClassHead()
logits = np.array(((45,), (-41,),), dtype=np.float32)
labels = np.array(((1,), (1,),), dtype=np.int32)
features = {'x': np.array(((42,),), dtype=np.float32)}
# expected_loss = sum(cross_entropy(labels, logits)) / batch_size
#               = sum(0, 41) / 2 = 41 / 2 = 20.50
loss = head.loss(labels, logits, features=features)
eval_metrics = head.metrics()
updated_metrics = head.update_metrics(
  eval_metrics, features, logits, labels)
for k in sorted(updated_metrics):
  print('{} : {:.2f}'.format(k, updated_metrics[k].result().numpy()))
# Output includes:
#   accuracy : 0.50
#   accuracy_baseline : 1.00
```
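The two metric values shown above can be checked by hand. The sketch below assumes a logit is classified as positive when its sigmoid exceeds 0.5 (equivalently, logit > 0), and that accuracy_baseline is the accuracy of always predicting the majority class, max(label_mean, 1 - label_mean). It is a plain-Python recomputation, not the head's metric code.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


logits = [45.0, -41.0]
labels = [1.0, 1.0]

# Assumed decision rule: threshold the predicted probability at 0.5.
predictions = [1.0 if sigmoid(x) > 0.5 else 0.0 for x in logits]
accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)

# Baseline: the accuracy of always predicting the majority class.
label_mean = sum(labels) / len(labels)
accuracy_baseline = max(label_mean, 1.0 - label_mean)

print('accuracy : {:.2f}'.format(accuracy))                    # accuracy : 0.50
print('accuracy_baseline : {:.2f}'.format(accuracy_baseline))  # accuracy_baseline : 1.00
```

Since both labels are 1, always predicting the positive class would be perfectly accurate, which is why the baseline (1.00) exceeds the model's accuracy (0.50).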