tf.keras.metrics.MeanMetricWrapper

Wraps a stateless metric function with the Mean metric.

Inherits From: Mean, Metric, Layer, Module

You can use this class to quickly build a mean metric from a function. The function must have the signature fn(y_true, y_pred) and return an array with one value per sample (for example, a per-sample loss or score). MeanMetricWrapper.result() then returns the average of those values across all samples seen so far.
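The averaging described above can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not TensorFlow's implementation. The class RunningMeanWrapper and the function exact_match are hypothetical names. The sketch keeps a running sum of the per-sample values and a running count of samples, and result() divides one by the other.

```python
from typing import Callable, List, Sequence


class RunningMeanWrapper:
    """Hypothetical stand-in that mirrors MeanMetricWrapper's averaging.

    Not TensorFlow code: it only reproduces the documented behavior of
    averaging per-sample values over every sample seen so far.
    """

    def __init__(self, fn: Callable[[Sequence, Sequence], List[float]]):
        self.fn = fn
        self.total = 0.0  # running sum of per-sample values
        self.count = 0    # number of samples seen so far

    def update_state(self, y_true: Sequence, y_pred: Sequence) -> None:
        values = self.fn(y_true, y_pred)  # one value per sample
        self.total += sum(values)
        self.count += len(values)

    def result(self) -> float:
        # Return 0.0 before any update instead of dividing by zero.
        return self.total / self.count if self.count else 0.0

    def reset_state(self) -> None:
        self.total = 0.0
        self.count = 0


def exact_match(y_true, y_pred):
    # Hypothetical per-sample function: 1.0 on an exact match, else 0.0.
    return [1.0 if t == p else 0.0 for t, p in zip(y_true, y_pred)]


metric = RunningMeanWrapper(exact_match)
metric.update_state([1, 2, 3], [1, 2, 4])  # 2 of 3 samples match
metric.update_state([5], [5])              # 1 of 1 samples match
print(metric.result())  # 0.75
```

Each sample carries equal weight. The result is 3 / 4 = 0.75, not the mean of the two batch averages, which would be about 0.833.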

For example:

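import tensorflow as tf

def accuracy(y_true, y_pred):
  # Per-sample value: 1.0 where the prediction equals the label exactly, else 0.0.
  return tf.cast(tf.math.equal(y_true, y_pred), tf.float32)

accuracy_metric = tf.keras.metrics.MeanMetricWrapper(fn=accuracy)

# keras_model is an existing tf.keras.Model; metrics is passed as a list.
keras_model.compile(..., metrics=[accuracy_metric])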
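The wrapped metric can also be driven directly, outside compile(), through update_state() and result(). The following is a minimal sketch that assumes eager execution (the TF 2.x default) and TF 2.5 or later, where reset_state() is available (older releases name it reset_states()). The function exact_match, the metric name, and the batch values are illustrative choices.

```python
import tensorflow as tf


def exact_match(y_true, y_pred):
    # One value per sample: 1.0 where prediction equals label, else 0.0.
    return tf.cast(tf.math.equal(y_true, y_pred), tf.float32)


metric = tf.keras.metrics.MeanMetricWrapper(fn=exact_match, name="exact_match")

# Batch 1: 3 of 4 samples match. Batch 2: 0 of 2 samples match.
metric.update_state(tf.constant([0, 1, 1, 0]), tf.constant([0, 1, 0, 0]))
metric.update_state(tf.constant([1, 1]), tf.constant([0, 0]))

# 3 matching samples out of the 6 seen so far.
value = float(metric.result())
print(value)  # 0.5

# Clear the accumulated state, for example at the start of a new epoch.
metric.reset_state()
after_reset = float(metric.result())
print(after_reset)  # 0.0
```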

fn The metric function to wrap, with signature