class tf.contrib.learn.MetricSpec

See the guide: Learn (contrib) > Estimators

MetricSpec connects a model to metric functions.

The MetricSpec class contains all information necessary to connect the output of a model_fn to the metrics (usually, streaming metrics) that are used in evaluation.

It is passed in the metrics argument of Estimator.evaluate. The Estimator then knows which predictions, labels, and weights to use when calling a given metric function.

When building the ops to run in evaluation, Estimator will call create_metric_ops, which will connect the given metric_fn to the model as detailed in the docstring for create_metric_ops, and return the metric.


As an example, assume a model has an input function that returns inputs containing (among other things) a tensor with key "input_key", and a labels dictionary containing "label_key". Assume also that the model_fn for this model returns a prediction with key "prediction_key".

To compute the accuracy of the "prediction_key" prediction, we would add

"prediction accuracy": MetricSpec(metric_fn=prediction_accuracy_fn,
                                  prediction_key="prediction_key", label_key="label_key")

to the metrics argument to evaluate. prediction_accuracy_fn can be either a predefined function in metric_ops (e.g., streaming_accuracy) or a custom function you define.
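A custom function of this kind only needs to accept predictions, labels and, optionally, weights. The following is a minimal plain-Python sketch of that shape, not TensorFlow code: the name weighted_accuracy and the use of lists in place of tensors are illustrative assumptions, and a real metric such as streaming_accuracy operates on tensors and accumulates across batches.

```python
def weighted_accuracy(predictions, labels, weights=None):
    """Hypothetical stand-in for a metric_fn, using lists instead of tensors."""
    if weights is None:
        # Unweighted case: every example counts equally.
        weights = [1.0] * len(labels)
    # Sum the weights of the examples that were predicted correctly.
    correct = sum(w for p, y, w in zip(predictions, labels, weights) if p == y)
    total = sum(weights)
    return correct / total if total else 0.0


# Toy batch: three of the four predictions match the labels.
preds = [1, 0, 1, 1]
labels = [1, 0, 0, 1]
unweighted = weighted_accuracy(preds, labels)
# Giving the misclassified example a larger weight lowers the score.
weighted = weighted_accuracy(preds, labels, weights=[1.0, 1.0, 2.0, 1.0])
```

With weights, each example contributes its weight rather than 1 to both the numerator and the denominator, so heavily weighted mistakes pull the result down more.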

If we would like the accuracy to be weighted by "input_key", we can pass that key as the weight_key argument.

"prediction accuracy": MetricSpec(metric_fn=prediction_accuracy_fn, prediction_key="prediction_key",
                                  label_key="label_key", weight_key="input_key")

An end-to-end example is as follows:

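estimator = tf.contrib.learn.Estimator(...)
_ = estimator.evaluate(
    input_fn=input_fn,  # the input function described above
    metrics={
        'prediction accuracy':
            MetricSpec(metric_fn=prediction_accuracy_fn,
                       prediction_key="prediction_key", label_key="label_key")
    })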






__init__(metric_fn, prediction_key=None, label_key=None, weight_key=None)


Creates a MetricSpec.


  • metric_fn: A function to use as a metric. Must accept predictions, labels, and, optionally, weights tensors as inputs. Must return either a single tensor, which is interpreted as the value of this metric, or a pair (value_op, update_op), where value_op is the op to call to obtain the value of the metric and update_op is evaluated for each batch to update internal state.
  • prediction_key: The key for a tensor in the predictions dict (output from the model_fn) to use as the predictions input to the metric_fn. Optional. If None, the model_fn must return a single tensor or a dict with only a single entry as predictions.
  • label_key: The key for a tensor in the labels dict (output from the input_fn) to use as the labels input to the metric_fn. Optional. If None, the input_fn must return a single tensor or a dict with only a single entry as labels.
  • weight_key: The key for a tensor in the inputs dict (output from the input_fn) to use as the weights input to the metric_fn. Optional. If None, no weights will be passed to the metric_fn.
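The (value_op, update_op) contract described for metric_fn can be illustrated without TensorFlow. The following is a rough sketch that uses plain Python closures in place of ops; make_streaming_accuracy and its internal state dict are hypothetical names introduced here, not part of tf.contrib.learn.

```python
def make_streaming_accuracy():
    """Returns (value_fn, update_fn), loosely mirroring (value_op, update_op)."""
    state = {"correct": 0.0, "total": 0.0}

    def update_fn(predictions, labels, weights=None):
        # Called once per batch to fold that batch into the running totals.
        if weights is None:
            weights = [1.0] * len(labels)
        for p, y, w in zip(predictions, labels, weights):
            state["total"] += w
            if p == y:
                state["correct"] += w

    def value_fn():
        # Reads the metric from the accumulated state without changing it.
        return state["correct"] / state["total"] if state["total"] else 0.0

    return value_fn, update_fn


value_fn, update_fn = make_streaming_accuracy()
update_fn([1, 0], [1, 1])        # batch 1: one of two correct
update_fn([1, 1, 0], [1, 1, 0])  # batch 2: three of three correct
streaming_result = value_fn()    # accuracy over all five examples
```

Separating the update step from the read step is what lets a metric be computed over an evaluation set that does not fit in a single batch.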

create_metric_ops(inputs, labels, predictions)

Connect our metric_fn to the specified members of the given dicts.

This function calls the metric_fn given in the constructor as follows:

metric_fn(predictions[self.prediction_key], labels[self.label_key], weights=inputs[self.weight_key])

and returns the result. The weights argument is passed only if self.weight_key is not None.

predictions and labels may be single tensors as well as dicts. If predictions is a single tensor, self.prediction_key must be None. If predictions is a single-element dict, self.prediction_key is allowed to be None. Likewise, if labels is a single tensor, self.label_key must be None. If labels is a single-element dict, self.label_key is allowed to be None.


  • inputs: A dict of inputs produced by the input_fn.
  • labels: A dict of labels or a single label tensor produced by the input_fn.
  • predictions: A dict of predictions or a single tensor produced by the model_fn.


Returns the result of calling metric_fn.


  • ValueError: Raised if predictions or labels is a single Tensor and self.prediction_key or self.label_key, respectively, is not None; or if self.label_key is None but labels is a dict with more than one element; or if self.prediction_key is None but predictions is a dict with more than one element.
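The key-resolution rules and error cases above can be sketched in plain Python. This approximates the documented behaviour and is not the library's implementation; resolve, create_metric_ops_sketch, and count_matches are hypothetical helpers, and plain lists stand in for tensors.

```python
def resolve(value, key, name):
    """Picks the metric input from a dict or a single value, per the rules above."""
    if not isinstance(value, dict):
        # A single "tensor": no key may be given.
        if key is not None:
            raise ValueError("%s is a single value but a key was given" % name)
        return value
    if key is None:
        # A dict without a key is only unambiguous if it has exactly one entry.
        if len(value) != 1:
            raise ValueError("%s has %d entries but no key was given"
                             % (name, len(value)))
        return next(iter(value.values()))
    return value[key]


def create_metric_ops_sketch(metric_fn, inputs, labels, predictions,
                             prediction_key=None, label_key=None,
                             weight_key=None):
    preds = resolve(predictions, prediction_key, "predictions")
    labs = resolve(labels, label_key, "labels")
    if weight_key is None:
        # No weights argument is passed to metric_fn at all.
        return metric_fn(preds, labs)
    return metric_fn(preds, labs, weights=inputs[weight_key])


def count_matches(predictions, labels, weights=None):
    # Toy metric_fn: (weighted) number of positions where prediction == label.
    weights = weights or [1] * len(labels)
    return sum(w for p, y, w in zip(predictions, labels, weights) if p == y)


matches = create_metric_ops_sketch(
    count_matches,
    inputs={"input_key": [2, 1, 1]},
    labels={"label_key": [1, 0, 1]},          # single entry, so no label_key needed
    predictions={"prediction_key": [1, 1, 1], "other": [0, 0, 0]},
    prediction_key="prediction_key",
    weight_key="input_key")
```

In this sketch, omitting prediction_key for the two-entry predictions dict, or passing a label_key together with a bare list of labels, raises ValueError, matching the cases listed above.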

Defined in tensorflow/contrib/learn/python/learn/