Encapsulates metric logic and state.

Inherits From: Layer, Module

Args:

  • name: (Optional) String name of the metric instance.
  • dtype: (Optional) Data type of the metric result.
  • **kwargs: Additional layer keyword arguments.
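A minimal sketch of how the constructor arguments surface on a metric instance. `tf.keras.metrics.Sum` is used here only as a convenient concrete subclass, and the name `total_count` is an arbitrary choice for illustration.

```python
import tensorflow as tf

# name and dtype are forwarded to the Metric base-class constructor.
m = tf.keras.metrics.Sum(name='total_count', dtype='float64')
m.update_state([1.0, 2.0, 3.0])

print(m.name)              # total_count
print(m.result().dtype)    # the result tensor follows the dtype argument (float64)
print(m.result().numpy())  # 6.0
```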

Standalone usage:

m = SomeMetric(...)
for input in ...:
  m.update_state(input)
print('Final result: ', m.result().numpy())
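The standalone pattern can be made runnable by substituting the built-in `tf.keras.metrics.Mean` for `SomeMetric` and a small hard-coded list for the input stream; both substitutions are assumptions for illustration. `reset_state()` clears the accumulated state so that the same instance can be reused, for example once per epoch.

```python
import tensorflow as tf

m = tf.keras.metrics.Mean()
batches = [[1.0, 2.0], [3.0, 4.0], [5.0]]  # hypothetical input stream

for batch in batches:
  m.update_state(batch)  # accumulates the running total and count

final = m.result().numpy()
print('Final result: ', final)  # 3.0, the mean of all five values

m.reset_state()  # zero the state variables before reuse
print(m.result().numpy())  # 0.0 (Mean reports 0 / 0 as 0)
```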

Usage with compile() API:

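import numpy as np
import tensorflow as tf

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=[tf.keras.metrics.CategoricalAccuracy()])

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.batch(32)

model.fit(dataset, epochs=10)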
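Metrics passed to compile() are reported under their names. A minimal sketch of reading them back with `model.evaluate(..., return_dict=True)`, assuming a smaller one-layer model and random one-hot labels purely for illustration (the resulting numbers carry no meaning).

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
  tf.keras.Input(shape=(32,)),
  tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=[tf.keras.metrics.CategoricalAccuracy()])

data = np.random.random((64, 32)).astype('float32')
# Random integer classes converted to one-hot rows (illustration only).
labels = tf.keras.utils.to_categorical(np.random.randint(10, size=64), 10)

# Roughly: evaluate() resets the metrics, calls update_state() on every batch,
# then collects result() into a dict keyed by metric name.
results = model.evaluate(data, labels, batch_size=32, verbose=0, return_dict=True)
print(results)
```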

To be implemented by subclasses:

  • __init__(): All state variables should be created in this method by calling self.add_weight(), e.g. self.var = self.add_weight(...).
  • update_state(): Applies all updates to the state variables, e.g. self.var.assign_add(...).
  • result(): Computes and returns a value for the metric from the state variables.
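A minimal sketch of the three-method contract, using a hypothetical `SimpleMean` metric (not part of Keras) that keeps two state variables, a running total and a running count, and divides them in `result()`. To stay short, it accepts but ignores `sample_weight`; the built-in `tf.keras.metrics.Mean` handles the weighted case.

```python
import tensorflow as tf

class SimpleMean(tf.keras.metrics.Metric):
  """Hypothetical unweighted mean metric, for illustration only."""

  def __init__(self, name='simple_mean', **kwargs):
    super().__init__(name=name, **kwargs)
    # __init__(): create every state variable with add_weight().
    self.total = self.add_weight(name='total', initializer='zeros')
    self.count = self.add_weight(name='count', initializer='zeros')

  def update_state(self, values, sample_weight=None):
    # update_state(): only assign/assign_add on the state variables.
    # sample_weight is accepted but ignored in this sketch.
    values = tf.cast(values, self.dtype)
    self.total.assign_add(tf.reduce_sum(values))
    self.count.assign_add(tf.cast(tf.size(values), self.dtype))

  def result(self):
    # result(): derive the reported value from the state variables.
    return tf.math.divide_no_nan(self.total, self.count)


m = SimpleMean()
m.update_state([1.0, 2.0, 3.0])
m.update_state([6.0])
print(m.result().numpy())  # 3.0 (12 / 4)
```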

Example subclass implementation:

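import tensorflow as tf

class BinaryTruePositives(tf.keras.metrics.Metric):

  def __init__(self, name='binary_true_positives', **kwargs):
    super(BinaryTruePositives, self).__init__(name=name, **kwargs)
    # Scalar state variable holding the running (weighted) true-positive count.
    self.true_positives = self.add_weight(name='tp', initializer='zeros')

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.cast(y_true, tf.bool)
    y_pred = tf.cast(y_pred, tf.bool)

    # 1.0 where both the label and the prediction are positive, else 0.0.
    values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True))
    values = tf.cast(values, self.dtype)
    if sample_weight is not None:
      sample_weight = tf.cast(sample_weight, self.dtype)
      # tf.shape() also works when the batch dimension is unknown (graph mode).
      sample_weight = tf.broadcast_to(sample_weight, tf.shape(values))
      values = tf.multiply(values, sample_weight)
    # Accumulate this batch's count into the state variable.
    self.true_positives.assign_add(tf.reduce_sum(values))

  def result(self):
    return self.true_positives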
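A worked check of the arithmetic that `update_state()` performs for a single batch, using plain TensorFlow ops on made-up labels, predictions, and weights. Positions 0 and 3 are true positives, so the weighted count is 0.5 + 2.0 = 2.5. The built-in `tf.keras.metrics.TruePositives` (default threshold 0.5) serves only as a cross-check and should agree on this input.

```python
import tensorflow as tf

y_true = tf.constant([1, 0, 1, 1])
y_pred = tf.constant([1.0, 1.0, 0.0, 1.0])
sample_weight = tf.constant([0.5, 1.0, 1.0, 2.0])

# Same steps as BinaryTruePositives.update_state(), for one batch.
tp_mask = tf.logical_and(tf.cast(y_true, tf.bool), tf.cast(y_pred, tf.bool))
weighted = tf.cast(tp_mask, tf.float32) * sample_weight
manual = tf.reduce_sum(weighted)
print(manual.numpy())  # 2.5

# Cross-check against the built-in metric.
builtin = tf.keras.metrics.TruePositives()
builtin.update_state(y_true, y_pred, sample_weight=sample_weight)
print(builtin.result().numpy())  # 2.5
```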