MetricSpec connects a model to metric functions.
tf.contrib.learn.MetricSpec(
metric_fn, prediction_key=None, label_key=None, weight_key=None
)
THIS CLASS IS DEPRECATED. See contrib/learn/README.md for general migration instructions.
The MetricSpec class contains all the information necessary to connect the output of a model_fn to the metrics (usually streaming metrics) that are used in evaluation. It is passed in the metrics argument of Estimator.evaluate. The Estimator then knows which predictions, labels, and weights to use when calling a given metric function.
When building the ops to run in evaluation, an Estimator will call create_metric_ops, which will connect the given metric_fn to the model as detailed in the docstring for create_metric_ops, and return the metric.
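The dispatch described above can be pictured with a small plain-Python sketch. It assumes the Estimator simply calls create_metric_ops once per entry of the metrics dict and collects the results by name; run_metric_specs and count_spec are hypothetical names, and this does not reproduce Estimator's actual internals or use real Tensors.

```python
from types import SimpleNamespace
from typing import Any, Dict


def run_metric_specs(metrics: Dict[str, Any], inputs: dict, labels: dict,
                     predictions: dict) -> Dict[str, Any]:
    """Hypothetical evaluation step: call create_metric_ops once per named spec."""
    return {
        name: spec.create_metric_ops(inputs, labels, predictions)
        for name, spec in metrics.items()
    }


# Hypothetical stand-in for a spec: only create_metric_ops is modeled.
count_spec = SimpleNamespace(
    create_metric_ops=lambda inputs, labels, predictions: len(predictions["p"])
)

results = run_metric_specs({"num_predictions": count_spec},
                           inputs={}, labels={}, predictions={"p": [0, 1, 2]})
print(results)  # {'num_predictions': 3}
```

The point of the sketch is only that each named entry in metrics yields one result under the same name.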
Example:
Assume a model has an input function that returns inputs containing (among other things) a tensor with key "input_key", and a labels dictionary containing "label_key". Let's also assume that the model_fn for this model returns a prediction with key "prediction_key".
In order to compute the accuracy of the "prediction_key" prediction, we would add

    "prediction accuracy": MetricSpec(metric_fn=prediction_accuracy_fn,
                                      prediction_key="prediction_key",
                                      label_key="label_key")

to the metrics argument of evaluate. prediction_accuracy_fn can be either a predefined function in metric_ops (e.g., streaming_accuracy) or a custom function you define.
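To illustrate the "custom function" case, here is a plain-Python sketch of what a prediction_accuracy_fn might compute. It is an assumption for illustration only: a real metric_fn operates on Tensors and usually returns a (value_op, update_op) pair, whereas this version takes Python lists and returns a single float. The dict contents are hypothetical values using the keys from the example above.

```python
from typing import Dict, List, Sequence


def prediction_accuracy_fn(predictions: Sequence[int], labels: Sequence[int]) -> float:
    """Hypothetical custom metric: fraction of predictions equal to labels."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not predictions:
        return 0.0
    correct = sum(1 for p, l in zip(predictions, labels) if p == l)
    return correct / len(predictions)


# Hypothetical model_fn / input_fn outputs keyed as in the example.
predictions: Dict[str, List[int]] = {"prediction_key": [1, 0, 1, 1]}
labels: Dict[str, List[int]] = {"label_key": [1, 1, 1, 0]}

# prediction_key and label_key select which entries reach the metric function.
accuracy = prediction_accuracy_fn(predictions["prediction_key"], labels["label_key"])
print(accuracy)  # 0.5
```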
If we would like the accuracy to be weighted by "input_key", we can add that as the weight_key argument.

    "prediction accuracy": MetricSpec(metric_fn=prediction_accuracy_fn,
                                      prediction_key="prediction_key",
                                      label_key="label_key",
                                      weight_key="input_key")
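As a worked example of what weighting changes, the sketch below assumes one common definition of weighted accuracy: the total weight of correctly predicted rows divided by the total weight. weighted_accuracy and input_weights are hypothetical names, and the numbers are made up; the function uses plain Python rather than Tensors.

```python
from typing import Optional, Sequence


def weighted_accuracy(predictions: Sequence[int], labels: Sequence[int],
                      weights: Optional[Sequence[float]] = None) -> float:
    """Hypothetical weighted accuracy: weight on correct rows / total weight."""
    if weights is None:
        weights = [1.0] * len(predictions)  # unweighted: every row weighs 1
    total = sum(weights)
    if total == 0:
        return 0.0
    correct = sum(w for p, l, w in zip(predictions, labels, weights) if p == l)
    return correct / total


preds = [1, 0, 1, 1]
labels = [1, 1, 1, 0]
input_weights = [3.0, 1.0, 1.0, 1.0]  # hypothetical values of inputs["input_key"]

print(weighted_accuracy(preds, labels))                 # 0.5
print(weighted_accuracy(preds, labels, input_weights))  # 0.6666666666666666
```

Rows 0 and 2 are correct; with weights they contribute 3 + 1 = 4 out of a total weight of 6, so the heavily weighted correct row pulls the score above the unweighted 0.5.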
An end-to-end example is as follows:

    import tensorflow as tf  # requires TensorFlow 1.x; tf.contrib is not in 2.x

    estimator = tf.contrib.learn.Estimator(...)
    estimator.fit(...)
    _ = estimator.evaluate(
        input_fn=input_fn,
        steps=1,
        metrics={
            'prediction accuracy':
                tf.contrib.learn.MetricSpec(
                    metric_fn=prediction_accuracy_fn,
                    prediction_key="prediction_key",
                    label_key="label_key")
        })
Args | |
---|---|
metric_fn | A function to use as a metric. See _adapt_metric_fn for rules on how predictions, labels, and weights are passed to this function. This must return either a single Tensor, which is interpreted as a value of this metric, or a pair (value_op, update_op), where value_op is the op to call to obtain the value of the metric, and update_op should be run for each batch to update internal state. |
prediction_key | The key for a tensor in the predictions dict (output from the model_fn) to use as the predictions input to the metric_fn. Optional. If None, the model_fn must return a single tensor or a dict with only a single entry as predictions. |
label_key | The key for a tensor in the labels dict (output from the input_fn) to use as the labels input to the metric_fn. Optional. If None, the input_fn must return a single tensor or a dict with only a single entry as labels. |
weight_key | The key for a tensor in the inputs dict (output from the input_fn) to use as the weights input to the metric_fn. Optional. If None, no weights will be passed to the metric_fn. |
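The (value_op, update_op) contract for metric_fn can be modeled in plain Python with two closures over shared state. This is an analogy, not TensorFlow code: streaming_accuracy_sketch is a hypothetical name, the callables stand in for graph ops, and, as an assumption of this sketch, update also returns the new value.

```python
from typing import Callable, Sequence, Tuple


def streaming_accuracy_sketch() -> Tuple[Callable[[], float],
                                         Callable[[Sequence[int], Sequence[int]], float]]:
    """Hypothetical analogue of a (value_op, update_op) pair using closures."""
    state = {"correct": 0, "count": 0}  # internal state shared by both callables

    def value() -> float:
        # Like value_op: read the metric from the accumulated state only.
        return state["correct"] / state["count"] if state["count"] else 0.0

    def update(predictions: Sequence[int], labels: Sequence[int]) -> float:
        # Like update_op: fold one batch into the state, then report the new value.
        state["correct"] += sum(1 for p, l in zip(predictions, labels) if p == l)
        state["count"] += len(predictions)
        return value()

    return value, update


value, update = streaming_accuracy_sketch()
update([1, 0], [1, 1])  # batch 1: 1 of 2 correct
update([1, 1], [1, 1])  # batch 2: 2 of 2 correct
print(value())  # 0.75
```

Keeping the state across calls is what makes the metric "streaming": the final value reflects all batches seen, not just the last one.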
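Attributes | |
---|---|
label_key | The label_key passed to the constructor. |
metric_fn | Metric function. This function accepts named args: predictions, labels, weights. |
prediction_key | The prediction_key passed to the constructor. |
weight_key | The weight_key passed to the constructor. |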
Methods
create_metric_ops
create_metric_ops(
inputs, labels, predictions
)
Connect our metric_fn to the specified members of the given dicts.
This function will call the metric_fn given in our constructor as follows:

    metric_fn(predictions[self.prediction_key],
              labels[self.label_key],
              weights=inputs[self.weight_key])

and return the result. The weights argument is only passed if self.weight_key is not None.
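The call shape above can be sketched in plain Python for the all-dicts case. call_metric_fn and describe are hypothetical helpers; the sketch takes the weights from the inputs dict, as the weight_key constructor argument describes, and omits the weights keyword entirely when weight_key is None.

```python
from typing import Any, Callable, Dict, Optional


def call_metric_fn(metric_fn: Callable[..., Any], inputs: Dict[str, Any],
                   labels: Dict[str, Any], predictions: Dict[str, Any],
                   prediction_key: str, label_key: str,
                   weight_key: Optional[str] = None) -> Any:
    """Hypothetical sketch of the documented call; dict arguments only."""
    args = (predictions[prediction_key], labels[label_key])
    if weight_key is None:
        return metric_fn(*args)  # no weights argument at all
    # Weights are looked up in the inputs dict produced by the input_fn.
    return metric_fn(*args, weights=inputs[weight_key])


def describe(predictions, labels, weights=None):
    """Hypothetical metric_fn that just echoes what it received."""
    return (predictions, labels, weights)


print(call_metric_fn(describe, {"w": 2}, {"y": 1}, {"p": 0}, "p", "y"))       # (0, 1, None)
print(call_metric_fn(describe, {"w": 2}, {"y": 1}, {"p": 0}, "p", "y", "w"))  # (0, 1, 2)
```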
predictions and labels may be single tensors as well as dicts. If predictions is a single tensor, self.prediction_key must be None. If predictions is a single-element dict, self.prediction_key is allowed to be None. Similarly, if labels is a single tensor, self.label_key must be None. If labels is a single-element dict, self.label_key is allowed to be None.
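These resolution rules, together with the ValueError cases listed under Raises, can be expressed as a small plain-Python helper. resolve is a hypothetical name; any non-dict value stands in for a single tensor, and, as an assumption of this sketch, an empty dict with no key is also rejected, a case the text does not cover.

```python
from typing import Any, Dict, Optional, Union


def resolve(value: Union[Dict[str, Any], Any], key: Optional[str], name: str) -> Any:
    """Hypothetical helper applying the single-tensor / dict rules."""
    if not isinstance(value, dict):
        # A single tensor: the key must be None.
        if key is not None:
            raise ValueError(f"{name} is a single value, but key {key!r} was given")
        return value
    if key is None:
        # A dict without a key is only allowed if it has exactly one entry.
        if len(value) != 1:
            raise ValueError(f"{name} has {len(value)} entries; a key is required")
        return next(iter(value.values()))
    return value[key]


print(resolve([1, 0], None, "predictions"))            # [1, 0]
print(resolve({"only": [1, 0]}, None, "predictions"))  # [1, 0]
print(resolve({"a": 1, "b": 2}, "b", "labels"))        # 2
try:
    resolve({"a": 1, "b": 2}, None, "labels")
except ValueError as err:
    print(err)  # labels has 2 entries; a key is required
```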
Args | |
---|---|
inputs | A dict of inputs produced by the input_fn. |
labels | A dict of labels or a single label tensor produced by the input_fn. |
predictions | A dict of predictions or a single tensor produced by the model_fn. |
Returns | |
---|---|
The result of calling metric_fn. |
Raises | |
---|---|
ValueError | If predictions or labels is a single Tensor and self.prediction_key or self.label_key is not None; or if self.label_key is None but labels is a dict with more than one element; or if self.prediction_key is None but predictions is a dict with more than one element. |