Métricas posteriores a la exportación

Como sugiere el nombre, esta es una métrica que se agrega después de la exportación, antes de la evaluación.

TFMA incluye varias métricas de evaluación predefinidas, como example_count, auc, confusion_matrix_at_thresholds, precision_recall_at_k, mse, mae, por nombrar algunas. (Lista completa aquí .)

Si no encuentra una métrica existente relevante para su caso de uso, o desea personalizar una métrica, puede definir su propia métrica personalizada. ¡Siga leyendo para conocer los detalles!

Adición de métricas personalizadas en TFMA

Definición de métricas personalizadas en TFMA 1.x

Ampliar clase base abstracta

Para agregar una métrica personalizada, cree una nueva clase que extienda la clase abstracta _PostExportMetric y defina su constructor e implemente métodos abstractos/no implementados.

Definir constructor

En el constructor, tome como parámetros toda la información relevante como etiqueta_clave, predicción_clave, ejemplo_peso_clave, métrica_etiqueta, etc. requerida para la métrica personalizada.

Implementar métodos abstractos/no implementados
  • verificar_compatibilidad

    Implemente este método para verificar la compatibilidad de la métrica con el modelo que se está evaluando, es decir, verificar si todas las características requeridas, la etiqueta esperada y la clave de predicción están presentes en el modelo en el tipo de datos apropiado. Se necesitan tres argumentos:

    • características_dict
    • predicciones_dict
    • etiquetas_dict

    Estos diccionarios contienen referencias a Tensores para el modelo.

  • get_metric_ops

    Implemente este método para proporcionar operaciones de métricas (operaciones de valor y actualización) para calcular la métrica. Similar al método check_compatibility, también toma tres argumentos:

    • características_dict
    • predicciones_dict
    • etiquetas_dict

    Defina la lógica de cálculo de su métrica utilizando estas referencias a Tensores para el modelo.

  • populate_stats_and_pop y populate_plots_and_pop

    Implemente esta métrica para convertir los resultados de la métrica sin procesar al formato de prototipo MetricValue y PlotData . Esto toma tres argumentos:

    • slice_key: nombre de la métrica de segmento a la que pertenece.
    • métricas_combinadas: diccionario que contiene resultados sin procesar.
    • output_metrics: diccionario de salida que contiene la métrica en el formato de prototipo deseado.
@_export('my_metric')
class _MyMetric(_PostExportMetric):
   def __init__(self,
                target_prediction_keys: Optional[List[Text]] = None,
                labels_key: Optional[Text] = None,
                metric_tag: Optional[Text] = None):
      self._target_prediction_keys = target_prediction_keys
      self._label_keys = label_keys
      self._metric_tag = metric_tag
      self._metric_key = 'my_metric_key'

   def check_compatibility(self, features_dict:types.TensorTypeMaybeDict,
                           predictions_dict: types.TensorTypeMaybeDict,
                           labels_dict: types.TensorTypeMaybeDict) -> None:
       # Add compatibility check needed for the metric here.

   def get_metric_ops(self, features_dict: types.TensorTypeMaybeDict,
                      predictions_dict: types.TensorTypeMaybeDict,
                      labels_dict: types.TensorTypeMaybeDict
                     ) -> Dict[bytes, Tuple[types.TensorType,
                     types.TensorType]]:
        # Metric computation logic here.
        # Define value and update ops.
        value_op = compute_metric_value(...)
        update_op = create_update_op(... )
        return {self._metric_key: (value_op, update_op)}

   def populate_stats_and_pop(
       self, slice_key: slicer.SliceKeyType, combined_metrics: Dict[Text, Any],
       output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
       # Parses the metric and converts it into required metric format.
       metric_result = combined_metrics[self._metric_key]
       output_metrics[self._metric_key].double_value.value = metric_result

Uso

# Custom metric callback
custom_metric_callback = my_metric(
    labels_key='label',
    target_prediction_keys=['prediction'])

fairness_indicators_callback =
   post_export_metrics.fairness_indicators(
        thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)

add_metrics_callbacks = [custom_metric_callback,
   fairness_indicators_callback]

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=eval_saved_model_path,
    add_metrics_callbacks=add_metrics_callbacks)

eval_config = tfma.EvalConfig(...)

# Run evaluation
tfma.run_model_analysis(
    eval_config=eval_config, eval_shared_model=eval_shared_model)