Métricas pós-exportação

Como o nome sugere, esta é uma métrica adicionada pós-exportação, antes da avaliação.

O TFMA vem com várias métricas de avaliação predefinidas, como example_count, auc, confusão_matrix_at_thresholds, Precision_recall_at_k, mse, mae, para citar alguns. (Lista completa aqui .)

Se você não encontrar uma métrica existente relevante para seu caso de uso ou quiser personalizar uma métrica, poderá definir sua própria métrica personalizada. Continue lendo para saber os detalhes!

Adicionando métricas personalizadas no TFMA

Definindo Métricas Personalizadas no TFMA 1.x

Estender classe base abstrata

Para adicionar uma métrica personalizada, crie uma nova classe estendendo a classe abstrata _PostExportMetric e defina seu construtor e implemente métodos abstratos/não implementados.

Definir Construtor

No construtor, tome como parâmetros todas as informações relevantes como label_key, Prediction_key, example_weight_key, metric_tag, etc., necessárias para a métrica customizada.

Implementar métodos abstratos/não implementados
  • verificação_compatibilidade

    Implemente este método para verificar a compatibilidade da métrica com o modelo que está sendo avaliado, ou seja, verificando se todos os recursos necessários, rótulo esperado e chave de previsão estão presentes no modelo no tipo de dados apropriado. São necessários três argumentos:

    • recursos_dict
    • previsões_dict
    • rótulos_dict

    Esses dicionários contêm referências a tensores para o modelo.

  • get_metric_ops

    Implemente este método para fornecer operações métricas (operações de valor e atualização) para calcular a métrica. Semelhante ao método check_compatibility, também leva três argumentos:

    • recursos_dict
    • previsões_dict
    • rótulos_dict

    Defina sua lógica de cálculo métrico usando essas referências aos Tensores do modelo.

  • populate_stats_and_pop e populate_plots_and_pop

    Implemente esta métrica para converter resultados de métricas brutas para o formato proto MetricValue e PlotData . Isso leva três argumentos:

    • slice_key: nome ao qual a métrica de fatia pertence.
    • combinado_metrics: Dicionário contendo resultados brutos.
    • output_metrics: Dicionário de saída contendo a métrica no formato proto desejado.
@_export('my_metric')
class _MyMetric(_PostExportMetric):
   def __init__(self,
                target_prediction_keys: Optional[List[Text]] = None,
                labels_key: Optional[Text] = None,
                metric_tag: Optional[Text] = None):
      self._target_prediction_keys = target_prediction_keys
      self._label_keys = label_keys
      self._metric_tag = metric_tag
      self._metric_key = 'my_metric_key'

   def check_compatibility(self, features_dict:types.TensorTypeMaybeDict,
                           predictions_dict: types.TensorTypeMaybeDict,
                           labels_dict: types.TensorTypeMaybeDict) -> None:
       # Add compatibility check needed for the metric here.

   def get_metric_ops(self, features_dict: types.TensorTypeMaybeDict,
                      predictions_dict: types.TensorTypeMaybeDict,
                      labels_dict: types.TensorTypeMaybeDict
                     ) -> Dict[bytes, Tuple[types.TensorType,
                     types.TensorType]]:
        # Metric computation logic here.
        # Define value and update ops.
        value_op = compute_metric_value(...)
        update_op = create_update_op(... )
        return {self._metric_key: (value_op, update_op)}

   def populate_stats_and_pop(
       self, slice_key: slicer.SliceKeyType, combined_metrics: Dict[Text, Any],
       output_metrics: Dict[Text, metrics_pb2.MetricValue]) -> None:
       # Parses the metric and converts it into required metric format.
       metric_result = combined_metrics[self._metric_key]
       output_metrics[self._metric_key].double_value.value = metric_result

Uso

# Custom metric callback
custom_metric_callback = my_metric(
    labels_key='label',
    target_prediction_keys=['prediction'])

fairness_indicators_callback =
   post_export_metrics.fairness_indicators(
        thresholds=[0.1, 0.3, 0.5, 0.7, 0.9], labels_key=label)

add_metrics_callbacks = [custom_metric_callback,
   fairness_indicators_callback]

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=eval_saved_model_path,
    add_metrics_callbacks=add_metrics_callbacks)

eval_config = tfma.EvalConfig(...)

# Run evaluation
tfma.run_model_analysis(
    eval_config=eval_config, eval_shared_model=eval_shared_model)