Post Export Metrics
As the name suggests, a post export metric is a metric that is added after the model has been exported, before evaluation is run.
TFMA is packaged with several pre-defined evaluation metrics, such as example_count, auc, confusion_matrix_at_thresholds, precision_recall_at_k, mse, and mae, to name a few. (Complete list here.)
If you don’t find an existing metric relevant to your use case, or want to customize a metric, you can define your own custom metric. Read on for the details!
Adding Custom Metrics in TFMA
Defining Custom Metrics in TFMA 1.x
Extend Abstract Base Class
To add a custom metric, create a new class that extends the _PostExportMetric abstract class, define its constructor, and implement its abstract / unimplemented methods.
Define Constructor
In the constructor, take as parameters all the information the custom metric requires, such as label_key, prediction_key, example_weight_key, and metric_tag.
Implement Abstract / Unimplemented Methods
check_compatibility
Implement this method to check the compatibility of the metric with the model being evaluated, i.e. to verify that all required features and the expected label and prediction keys are present in the model with the appropriate data types. It takes three arguments:
- features_dict
- predictions_dict
- labels_dict
These dictionaries contain references to the model’s Tensors.
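For intuition, here is a minimal, framework-free sketch of the kind of validation check_compatibility might perform. The function name, key names, and error messages are illustrative assumptions, not part of the TFMA API:

```python
# Hypothetical sketch: verify that the keys a metric needs are present
# in the dictionaries TFMA passes in, raising a clear error otherwise.
def check_label_and_prediction_keys(predictions_dict, labels_dict,
                                    labels_key='label',
                                    prediction_key='prediction'):
    if labels_key not in labels_dict:
        raise ValueError('Metric requires label key %r in labels_dict.'
                         % labels_key)
    if prediction_key not in predictions_dict:
        raise ValueError('Metric requires prediction key %r in '
                         'predictions_dict.' % prediction_key)
```

In a real implementation the checks would run against the Tensor dictionaries TFMA provides and could also validate dtypes.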
get_metric_ops
Implement this method to provide the metric ops (value and update ops) that compute the metric. Like check_compatibility, it takes three arguments:
- features_dict
- predictions_dict
- labels_dict
Define your metric computation logic using these references to the model’s Tensors.
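To see what “value and update ops” mean, here is a plain-Python analogue of the pattern, using a simple mean-absolute-error metric as an assumed example (this is not TFMA code): the update step folds each batch into accumulated state, and the value step reads the final result from that state.

```python
# A framework-free sketch of the value/update-op pattern used by
# get_metric_ops. The update step accumulates state per batch; the
# value step computes the metric from the accumulated state.

class MeanAbsoluteError:
    """Accumulates |label - prediction| across batches."""

    def __init__(self):
        self._total = 0.0   # running sum of absolute errors
        self._count = 0     # number of examples seen so far

    def update(self, labels, predictions):
        # Analogue of the update op: fold one batch into the state.
        for y, p in zip(labels, predictions):
            self._total += abs(y - p)
            self._count += 1

    def value(self):
        # Analogue of the value op: read the metric from the state.
        return self._total / self._count if self._count else 0.0
```

In TFMA 1.x the same roles are played by TensorFlow ops: the update op is run over each batch of examples, and the value op is evaluated once at the end to produce the metric.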
populate_stats_and_pop and populate_plots_and_pop
Implement these methods to convert raw metric results into the MetricValue and PlotData proto formats. They take three arguments:
- slice_key: Name of the slice the metric belongs to.
- combined_metrics: Dictionary containing the raw results.
- output_metrics: Output dictionary containing the metric in the desired proto format.
@_export('my_metric')
class _MyMetric(_PostExportMetric):

  def __init__(self,
               target_prediction_keys: Optional[List[Text]] = None,
               labels_key: Optional[Text] = None,
               metric_tag: Optional[Text] = None):
    self._target_prediction_keys = target_prediction_keys
    self._labels_key = labels_key
    self._metric_tag = metric_tag
    self._metric_key = 'my_metric_key'

  def check_compatibility(self, features_dict: types.TensorTypeMaybeDict,
                          predictions_dict: types.TensorTypeMaybeDict,
                          labels_dict: types.TensorTypeMaybeDict) -> None:
    # Add any compatibility checks needed for the metric here.
    pass

  def get_metric_ops(self, features_dict: types.TensorTypeMaybeDict,
                     predictions_dict: types.TensorTypeMaybeDict,
                     labels_dict: types.TensorTypeMaybeDict
                    ) -> Dict[bytes, Tuple[types.TensorType,
                                           types.TensorType]]:
    # Metric computation logic here: define the value and update ops.
    value_op = compute_metric_value(...)
    update_op = create_update_op(...)
    return {self._metric_key: (value_op, update_op)}

  def populate_stats_and_pop(
      self, slice_key: slicer.SliceKeyType, combined_metrics: Dict[Text, Any],
      output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
    # Parse the raw metric result and convert it into the required proto format.
    metric_result = combined_metrics[self._metric_key]
    output_metrics[self._metric_key].double_value.value = metric_result
Usage
# Custom metric callback
custom_metric_callback = my_metric(
    labels_key='label',
    target_prediction_keys=['prediction'])

# Built-in fairness indicators callback
fairness_indicators_callback = post_export_metrics.fairness_indicators(
    thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key='label')

add_metrics_callbacks = [custom_metric_callback,
                         fairness_indicators_callback]

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=eval_saved_model_path,
    add_metrics_callbacks=add_metrics_callbacks)

eval_config = tfma.EvalConfig(...)

# Run evaluation
tfma.run_model_analysis(
    eval_config=eval_config, eval_shared_model=eval_shared_model)