tfma.types.EvalSharedModel

Shared model used during extraction and evaluation.

More details on add_metrics_callbacks:

Each add_metrics_callback should have the following prototype:

    def add_metrics_callback(features_dict, predictions_dict, labels_dict):

Note that features_dict, predictions_dict and labels_dict are not necessarily dictionaries - they might also be Tensors, depending on what the model's eval_input_receiver_fn returns.

It should create and return a metric_ops dictionary, such that metric_ops['metric_name'] = (value_op, update_op), just as in the Trainer.
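The (value_op, update_op) pair follows the usual streaming-metric convention: the update op folds one batch into accumulated state, and the value op reads the current result. As a rough illustration only, here is a pure-Python analogue that needs no TensorFlow. The class name StreamingMean is hypothetical and is not part of TFMA or TensorFlow.

```python
class StreamingMean:
    """Hypothetical pure-Python analogue of a (value_op, update_op) pair."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, batch):
        # Plays the role of update_op: fold one batch into the running state.
        self.total += sum(batch)
        self.count += len(batch)
        return self.value()

    def value(self):
        # Plays the role of value_op: read the result without changing state.
        return self.total / self.count if self.count else 0.0


mean_label = StreamingMean()
mean_label.update([1.0, 0.0, 1.0])  # first batch
mean_label.update([0.0])            # second batch
result = mean_label.value()         # mean over all four values seen so far
```

Reading the value between updates is safe because it does not modify the accumulated totals.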

Short example:

    import tensorflow as tf

    def add_metrics_callback(features_dict, predictions_dict, labels):
      metric_ops = {}
      # Streaming mean of the labels.
      metric_ops['mean_label'] = tf.metrics.mean(labels)
      # Streaming mean of column 1 of the probabilities, over the first two rows.
      metric_ops['mean_probability'] = tf.metrics.mean(tf.slice(
          predictions_dict['probabilities'], [0, 1], [2, 1]))
      return metric_ops
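Because features_dict, predictions_dict and labels_dict may each arrive as either a dictionary or a single Tensor, a callback that should work in both cases can normalize its inputs first. The following is a minimal sketch under that assumption. The helper name select and the key 'label' are hypothetical, and plain lists stand in for Tensors so the snippet runs without TensorFlow.

```python
def select(value, key):
    """Return value[key] when value is a dict, otherwise value itself.

    Hypothetical helper for illustration; not part of TFMA.
    """
    if isinstance(value, dict):
        return value[key]
    return value


# Plain lists stand in for Tensors so the sketch runs without TensorFlow.
labels_as_dict = {'label': [1.0, 0.0]}
labels_as_tensor = [1.0, 0.0]

from_dict = select(labels_as_dict, 'label')      # looks up the 'label' entry
from_tensor = select(labels_as_tensor, 'label')  # passes the value through
```

Inside a callback, the same helper could be applied to each argument before building metric_ops.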

model_path Path to EvalSavedModel (containing the saved_model.pb file).
add_metrics_callbacks Optional list of callbacks for adding additional metrics to the graph. The names of the metrics added by the callbacks should not conflict with existing metrics. See "More details on add_metrics_callbacks" for what each callback should do. The callbacks are only used during evaluation.
include_default_metrics True to include the default metrics that are part of the saved model graph during evaluation.
example_weight_key Example weight key (single-output model) or dict of example weight keys (multi-output model) keyed by output_name.
additional_fetches Prefixes of additional tensors stored in signature_def.inputs that should be fetched at prediction time. The "features" and "labels" tensors are handled automatically and should not be included in this list.
model_loader Model loader.
model_name Model name (should align with ModelSpecs.name).
model_type Model type (tfma.TF_KERAS, tfma.TF_LITE, tfma.TF_ESTIMATOR, etc.).
rubber_stamp True if this model is being rubber stamped. When a model is rubber stamped, diff thresholds are ignored if an associated baseline model is not passed.
is_baseline True if this model is the baseline model for comparison.
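The add_metrics_callbacks description requires that callback metric names not conflict with existing metrics. One way to catch such conflicts early in your own code is to merge the dictionaries explicitly and fail on duplicates. This sketch does not show how TFMA performs the check internally. The function merge_metric_ops is hypothetical, and string placeholders stand in for the real (value_op, update_op) tensors.

```python
def merge_metric_ops(existing, callback_outputs):
    """Merge callback metric dicts into existing, rejecting duplicate names."""
    merged = dict(existing)
    for metric_ops in callback_outputs:
        for name, ops in metric_ops.items():
            if name in merged:
                raise ValueError('duplicate metric name: %s' % name)
            merged[name] = ops
    return merged


# String placeholders stand in for real (value_op, update_op) pairs.
default_metrics = {'accuracy': ('value_op', 'update_op')}
callback_metrics = [
    {'mean_label': ('value_op', 'update_op')},
    {'mean_probability': ('value_op', 'update_op')},
]

merged = merge_metric_ops(default_metrics, callback_metrics)
metric_names = sorted(merged)
```

A callback that returned a metric named 'accuracy' here would raise ValueError instead of silently replacing the default metric.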