tf.confusion_matrix(labels, predictions, num_classes=None, dtype=tf.int32, name=None, weights=None)

Computes the confusion matrix from predictions and labels.

Calculates the confusion matrix for a pair of 1-D integer arrays of predictions and labels.

Each row of the matrix corresponds to a real label and each column to a predicted label, so cell [i, j] tallies the samples whose real label is i and whose predicted label is j. The confusion matrix is always a 2-D array of shape [n, n], where n is the number of valid labels for a given classification task. Both predictions and labels must be 1-D arrays of the same shape.
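
To make the row/column convention concrete, here is a minimal plain-Python sketch that runs without TensorFlow. The helper `count_matrix` is hypothetical, written only for illustration, and is not TensorFlow's implementation; it assumes the number of classes `n` is already known.

```python
def count_matrix(labels, predictions, n):
    """Hypothetical helper: build an n x n confusion matrix as nested lists.

    Cell [i][j] counts the pairs whose real label is i and whose
    predicted label is j (rows = real labels, columns = predictions).
    """
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    matrix = [[0] * n for _ in range(n)]
    for real, predicted in zip(labels, predictions):
        matrix[real][predicted] += 1
    return matrix


print(count_matrix([0, 1, 1, 2], [0, 1, 2, 2], 3))
# [[1, 0, 0], [0, 1, 1], [0, 0, 1]]
```

The off-diagonal 1 in row 1, column 2 records the one sample whose real label is 1 but which was predicted as 2.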

If num_classes is None, it is set to one plus the maximum value in either predictions or labels. Class labels are expected to start at 0. For example, if num_classes is 3, the possible labels are [0, 1, 2].
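
The rule for inferring num_classes fits in a one-line function. `infer_num_classes` is a hypothetical helper for illustration; like the rule it mirrors, it assumes non-negative integer labels and non-empty inputs (Python's `max` raises `ValueError` on an empty sequence).

```python
def infer_num_classes(labels, predictions):
    """Hypothetical helper: one plus the largest value in either input."""
    return 1 + max(max(labels), max(predictions))


n = infer_num_classes([0, 1], [0, 2])
print(n)               # 3
print(list(range(n)))  # [0, 1, 2]: the possible labels, starting at 0
```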

If weights is not None, each prediction contributes its corresponding weight, rather than 1, to the value of its confusion matrix cell.
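
With weights, each (label, prediction) pair adds its weight to cell [label][prediction] instead of adding 1. The sketch below is hypothetical plain Python, not TensorFlow's implementation; it assumes `weights` has one entry per prediction and that the number of classes `n` is known. Passing all-ones weights reproduces the unweighted counts.

```python
def weighted_confusion_matrix(labels, predictions, weights, n):
    """Hypothetical helper: sum each pair's weight into its matrix cell."""
    if len(predictions) != len(labels) or len(weights) != len(predictions):
        raise ValueError("labels, predictions and weights must have the same length")
    matrix = [[0.0] * n for _ in range(n)]
    for real, predicted, weight in zip(labels, predictions, weights):
        matrix[real][predicted] += weight
    return matrix


print(weighted_confusion_matrix([0, 1, 1], [0, 1, 1], [0.5, 2.0, 1.5], 2))
# [[0.5, 0.0], [0.0, 3.5]]
```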

For example:

  tf.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
      [[0 0 0 0 0]
       [0 0 1 0 0]
       [0 0 1 0 0]
       [0 0 0 0 0]
       [0 0 0 0 1]]

Because num_classes is not given and the largest value in either input is 4, the possible labels are assumed to be [0, 1, 2, 3, 4], resulting in a 5x5 confusion matrix.
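
The example can be reproduced without TensorFlow by combining the inference rule with the counting step. `confusion_matrix_sketch` is a hypothetical pure-Python stand-in that follows the documented semantics (rows are real labels, columns are predictions, num_classes defaults to one plus the largest value); it is not the TensorFlow API and returns nested lists rather than a Tensor.

```python
def confusion_matrix_sketch(labels, predictions, num_classes=None):
    """Hypothetical pure-Python stand-in for the documented behaviour."""
    if num_classes is None:
        # One plus the largest value in either input.
        num_classes = 1 + max(max(labels), max(predictions))
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for real, predicted in zip(labels, predictions):
        matrix[real][predicted] += 1
    return matrix


for row in confusion_matrix_sketch([1, 2, 4], [2, 2, 4]):
    print(row)
# [0, 0, 0, 0, 0]
# [0, 0, 1, 0, 0]
# [0, 0, 1, 0, 0]
# [0, 0, 0, 0, 0]
# [0, 0, 0, 0, 1]

# Passing num_classes explicitly fixes the size, e.g. 6 x 6 here.
print(len(confusion_matrix_sketch([1, 2, 4], [2, 2, 4], num_classes=6)))  # 6
```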


Args:

  • labels: 1-D Tensor of real labels for the classification task.
  • predictions: 1-D Tensor of predictions for a given classification.
  • num_classes: The number of possible labels the classification task can have. If this value is not provided, it is computed from both the predictions and labels arrays.
  • dtype: Data type of the confusion matrix.
  • name: Scope name.
  • weights: An optional Tensor whose shape matches predictions.


Returns:

  A Tensor of type dtype with shape [n, n] representing the confusion matrix, where n is the number of possible labels in the classification task.


Raises:

  • ValueError: If predictions and labels are not 1-D vectors with matching shapes, or if weights is not None and its shape doesn't match predictions.
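
The documented error conditions can be expressed as a shape check. `validate_shapes` is hypothetical and works on shape tuples such as (3,); it encodes a strict reading of the rules above, and TensorFlow's own checks may differ in detail (for example, in how size-1 dimensions are treated).

```python
def validate_shapes(labels_shape, predictions_shape, weights_shape=None):
    """Hypothetical check mirroring the documented ValueError conditions."""
    if len(labels_shape) != 1 or len(predictions_shape) != 1:
        raise ValueError("labels and predictions must be 1-D")
    if labels_shape != predictions_shape:
        raise ValueError("labels and predictions must have the same shape")
    if weights_shape is not None and weights_shape != predictions_shape:
        raise ValueError("weights must have the same shape as predictions")


validate_shapes((3,), (3,))        # passes silently
validate_shapes((3,), (3,), (3,))  # passes silently
try:
    validate_shapes((3,), (3,), (4,))
except ValueError as err:
    print(err)  # weights must have the same shape as predictions
```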

Defined in tensorflow/python/ops/