Better ML Engineering with ML Metadata


Assume a scenario where you set up a production ML pipeline to classify penguins. The pipeline ingests your training data, trains and evaluates a model, and pushes it to production.

However, when you later try using this model with a larger dataset that contains different kinds of penguins, you observe that your model does not behave as expected and starts classifying the species incorrectly.

At this point, you are interested in knowing:

  • What is the most efficient way to debug the model when the only available artifact is the model in production?
  • Which training dataset was used to train the model?
  • Which training run led to this erroneous model?
  • Where are the model evaluation results?
  • Where to begin debugging?

ML Metadata (MLMD) is a library that leverages the metadata associated with ML models to help you answer these questions and more. A helpful analogy is to think of this metadata as the equivalent of logging in software development. MLMD enables you to reliably track the artifacts and lineage associated with the various components of your ML pipeline.

In this tutorial, you set up a TFX Pipeline to create a model that classifies penguins into three species based on their body mass, the length and depth of their culmens, and the length of their flippers. You then use MLMD to track the lineage of pipeline components.

TFX Pipelines in Colab

Colab is a lightweight development environment that differs significantly from a production environment. In production, you may have various pipeline components like data ingestion, transformation, model training, run histories, etc. spread across multiple distributed systems. For this tutorial, you should be aware that significant differences exist in orchestration and metadata storage: both are handled locally within Colab. Learn more about TFX in Colab here.

Setup

First, we install and import the necessary packages, set up paths, and download data.

Upgrade Pip

To avoid upgrading Pip in a system when running locally, check to make sure that we're running in Colab. Local systems can of course be upgraded separately.

try:
  import colab
  !pip install --upgrade pip
except ImportError:
  pass

Install and import TFX

pip install -q tfx

Import packages

Did you restart the runtime?

If you are using Google Colab, the first time that you run the cell above, you must restart the runtime by clicking the "RESTART RUNTIME" button above or by using the "Runtime > Restart runtime ..." menu. This is because of the way that Colab loads packages.

import os
import tempfile
import urllib
import pandas as pd

import tensorflow_model_analysis as tfma
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

Check the TFX and MLMD versions.

from tfx import v1 as tfx
print('TFX version: {}'.format(tfx.__version__))
import ml_metadata as mlmd
print('MLMD version: {}'.format(mlmd.__version__))
TFX version: 1.13.0
MLMD version: 1.13.1

Download the dataset

In this colab, we use the Palmer Penguins dataset, which can be found on GitHub. We processed the dataset by leaving out any incomplete records, dropping the island and sex columns, and converting the labels to int32. The dataset contains 334 records of the body mass, the length and depth of the culmen, and the length of the flippers of each penguin. You use this data to classify penguins into one of three species.
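
As a rough sketch of the preprocessing described above, the snippet below leaves out incomplete records, removes the island and sex columns, and maps species names to integer labels using only the standard library. The raw column layout, the "NA" missing-value marker, the sample rows, and the label mapping are assumptions made for illustration; the script actually used to produce the processed file is not shown in this tutorial.

```python
import csv
import io

# Hypothetical raw rows; the column layout and the "NA" marker are assumptions.
RAW_CSV = """species,island,culmen_length_mm,culmen_depth_mm,flipper_length_mm,body_mass_g,sex
Adelie,Torgersen,39.1,18.7,181,3750,MALE
Adelie,Torgersen,NA,NA,NA,NA,NA
Chinstrap,Dream,46.5,17.9,192,3500,FEMALE
Gentoo,Biscoe,46.1,13.2,211,4500,FEMALE
"""

# Assumed label encoding; the real dataset may order the species differently.
SPECIES_TO_LABEL = {"Adelie": 0, "Chinstrap": 1, "Gentoo": 2}
DROPPED_COLUMNS = {"island", "sex"}


def preprocess(raw_text):
    """Drop incomplete records and unused columns, and encode the label."""
    processed = []
    for row in csv.DictReader(io.StringIO(raw_text)):
        # Leave out any record that has a missing value in any column.
        if any(value in ("", "NA") for value in row.values()):
            continue
        record = {}
        for key, value in row.items():
            if key in DROPPED_COLUMNS:
                continue
            if key == "species":
                record[key] = SPECIES_TO_LABEL[value]  # integer label
            else:
                record[key] = float(value)  # numeric feature
        processed.append(record)
    return processed


rows = preprocess(RAW_CSV)
print(len(rows), rows[0])
```

Only three of the four sample rows survive, because the second row has missing measurements.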

DATA_PATH = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/penguin/data/labelled/penguins_processed.csv'
_data_root = tempfile.mkdtemp(prefix='tfx-data')
_data_filepath = os.path.join(_data_root, "penguins_processed.csv")
urllib.request.urlretrieve(DATA_PATH, _data_filepath)
('/tmpfs/tmp/tfx-datap4i8w56n/penguins_processed.csv',
 <http.client.HTTPMessage at 0x7f3a76776370>)

Create an InteractiveContext

To run TFX components interactively in this notebook, create an InteractiveContext. The InteractiveContext uses a temporary directory with an ephemeral MLMD database instance. Note that calls to InteractiveContext are no-ops outside the Colab environment.

In general, it is a good practice to group similar pipeline runs under a Context.

interactive_context = InteractiveContext()
WARNING:absl:InteractiveContext pipeline_root argument not provided: using temporary directory /tmpfs/tmp/tfx-interactive-2023-07-28T11_11_10.063419-p4royv0g as root for pipeline outputs.
WARNING:absl:InteractiveContext metadata_connection_config not provided: using SQLite ML Metadata database at /tmpfs/tmp/tfx-interactive-2023-07-28T11_11_10.063419-p4royv0g/metadata.sqlite.

Construct the TFX Pipeline

A TFX pipeline consists of several components that perform different aspects of the ML workflow. In this notebook, you create and run the ExampleGen, StatisticsGen, SchemaGen, and Trainer components, and use the Evaluator and Pusher components to evaluate and push the trained model.

Refer to the components tutorial for more information on TFX pipeline components.

Instantiate and run the ExampleGen Component

example_gen = tfx.components.CsvExampleGen(input_base=_data_root)
interactive_context.run(example_gen)
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.

Instantiate and run the StatisticsGen Component

statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples'])
interactive_context.run(statistics_gen)

Instantiate and run the SchemaGen Component

infer_schema = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)
interactive_context.run(infer_schema)

Instantiate and run the Trainer Component

# Define the module file for the Trainer component
trainer_module_file = 'penguin_trainer.py'

%%writefile {trainer_module_file}

# Define the training algorithm for the Trainer module file
import os
from typing import List, Text

import tensorflow as tf
from tensorflow import keras

from tfx import v1 as tfx
from tfx_bsl.public import tfxio

from tensorflow_metadata.proto.v0 import schema_pb2

# Features used for classification: culmen length and depth, flipper length,
# and body mass. The label to predict is the species.

_LABEL_KEY = 'species'

_FEATURE_KEYS = [
    'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g'
]


def _input_fn(file_pattern: List[Text],
              data_accessor: tfx.components.DataAccessor,
              schema: schema_pb2.Schema, batch_size: int) -> tf.data.Dataset:
  return data_accessor.tf_dataset_factory(
      file_pattern,
      tfxio.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key=_LABEL_KEY), schema).repeat()


def _build_keras_model():
  inputs = [keras.layers.Input(shape=(1,), name=f) for f in _FEATURE_KEYS]
  d = keras.layers.concatenate(inputs)
  d = keras.layers.Dense(8, activation='relu')(d)
  d = keras.layers.Dense(8, activation='relu')(d)
  outputs = keras.layers.Dense(3)(d)
  model = keras.Model(inputs=inputs, outputs=outputs)
  model.compile(
      optimizer=keras.optimizers.Adam(1e-2),
      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[keras.metrics.SparseCategoricalAccuracy()])
  return model


def run_fn(fn_args: tfx.components.FnArgs):
  schema = schema_pb2.Schema()
  tfx.utils.parse_pbtxt_file(fn_args.schema_path, schema)
  train_dataset = _input_fn(
      fn_args.train_files, fn_args.data_accessor, schema, batch_size=10)
  eval_dataset = _input_fn(
      fn_args.eval_files, fn_args.data_accessor, schema, batch_size=10)
  model = _build_keras_model()
  model.fit(
      train_dataset,
      epochs=int(fn_args.train_steps / 20),
      steps_per_epoch=20,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)
  model.save(fn_args.serving_model_dir, save_format='tf')
Writing penguin_trainer.py
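
The final `Dense(3)` layer in `_build_keras_model` has no activation, so the model outputs raw logits, one per species. That is why the loss is constructed with `from_logits=True`. As a minimal sketch of turning such logits into class probabilities and a predicted label, the snippet below uses plain Python rather than TensorFlow; the logit values are made up for illustration.

```python
import math


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    largest = max(logits)
    # Subtract the largest logit first for numerical stability.
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


example_logits = [2.0, 0.5, -1.0]  # hypothetical model output for one penguin
probabilities = softmax(example_logits)
predicted_label = probabilities.index(max(probabilities))
print(probabilities, predicted_label)
```

The predicted label is the index of the largest logit, so applying softmax changes the scale of the scores but not the chosen class.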

Run the Trainer component.

trainer = tfx.components.Trainer(
    module_file=os.path.abspath(trainer_module_file),
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=100),
    eval_args=tfx.proto.EvalArgs(num_steps=50))
interactive_context.run(trainer)
running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying penguin_trainer.py -> build/lib
installing to /tmpfs/tmp/tmpi9la89eh
running install
running install_lib
copying build/lib/penguin_trainer.py -> /tmpfs/tmp/tmpi9la89eh
running install_egg_info
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
running egg_info
creating tfx_user_code_Trainer.egg-info
writing tfx_user_code_Trainer.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Trainer.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Trainer.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
reading manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
Copying tfx_user_code_Trainer.egg-info to /tmpfs/tmp/tmpi9la89eh/tfx_user_code_Trainer-0.0+fef7c4ed90dc336ca26daee59d65660cf8da5fa988b2ca0c89df2f558fda10f4-py3.9.egg-info
running install_scripts
creating /tmpfs/tmp/tmpi9la89eh/tfx_user_code_Trainer-0.0+fef7c4ed90dc336ca26daee59d65660cf8da5fa988b2ca0c89df2f558fda10f4.dist-info/WHEEL
creating '/tmpfs/tmp/tmp7hfjtnjp/tfx_user_code_Trainer-0.0+fef7c4ed90dc336ca26daee59d65660cf8da5fa988b2ca0c89df2f558fda10f4-py3-none-any.whl' and adding '/tmpfs/tmp/tmpi9la89eh' to it
adding 'penguin_trainer.py'
adding 'tfx_user_code_Trainer-0.0+fef7c4ed90dc336ca26daee59d65660cf8da5fa988b2ca0c89df2f558fda10f4.dist-info/METADATA'
adding 'tfx_user_code_Trainer-0.0+fef7c4ed90dc336ca26daee59d65660cf8da5fa988b2ca0c89df2f558fda10f4.dist-info/WHEEL'
adding 'tfx_user_code_Trainer-0.0+fef7c4ed90dc336ca26daee59d65660cf8da5fa988b2ca0c89df2f558fda10f4.dist-info/top_level.txt'
adding 'tfx_user_code_Trainer-0.0+fef7c4ed90dc336ca26daee59d65660cf8da5fa988b2ca0c89df2f558fda10f4.dist-info/RECORD'
removing /tmpfs/tmp/tmpi9la89eh
Processing /tmpfs/tmp/tfx-interactive-2023-07-28T11_11_10.063419-p4royv0g/_wheels/tfx_user_code_Trainer-0.0+fef7c4ed90dc336ca26daee59d65660cf8da5fa988b2ca0c89df2f558fda10f4-py3-none-any.whl
Installing collected packages: tfx-user-code-Trainer
Successfully installed tfx-user-code-Trainer-0.0+fef7c4ed90dc336ca26daee59d65660cf8da5fa988b2ca0c89df2f558fda10f4
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tfx_bsl/tfxio/tf_example_record.py:339: parse_example_dataset (from tensorflow.python.data.experimental.ops.parsing_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.map(tf.io.parse_example(...))` instead.
Epoch 1/5
20/20 [==============================] - 3s 19ms/step - loss: 0.9458 - sparse_categorical_accuracy: 0.7500 - val_loss: 0.8589 - val_sparse_categorical_accuracy: 0.7800
Epoch 2/5
20/20 [==============================] - 0s 11ms/step - loss: 0.6942 - sparse_categorical_accuracy: 0.8000 - val_loss: 0.5478 - val_sparse_categorical_accuracy: 0.7800
Epoch 3/5
20/20 [==============================] - 0s 11ms/step - loss: 0.4146 - sparse_categorical_accuracy: 0.8100 - val_loss: 0.3478 - val_sparse_categorical_accuracy: 0.7800
Epoch 4/5
20/20 [==============================] - 0s 11ms/step - loss: 0.2747 - sparse_categorical_accuracy: 0.9350 - val_loss: 0.2253 - val_sparse_categorical_accuracy: 0.9600
Epoch 5/5
20/20 [==============================] - 0s 11ms/step - loss: 0.1738 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.1330 - val_sparse_categorical_accuracy: 0.9800
INFO:tensorflow:Assets written to: /tmpfs/tmp/tfx-interactive-2023-07-28T11_11_10.063419-p4royv0g/Trainer/model/4/Format-Serving/assets
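
The log above shows five epochs of 20 steps each. That follows from `run_fn`, which fixes `steps_per_epoch` at 20 and derives the epoch count from the `num_steps` value passed in `TrainArgs`. The arithmetic can be sketched as follows; the helper name `epochs_for` is hypothetical.

```python
STEPS_PER_EPOCH = 20  # hard-coded in run_fn


def epochs_for(train_steps, steps_per_epoch=STEPS_PER_EPOCH):
    """Mirror run_fn's int(train_steps / steps_per_epoch) calculation."""
    return int(train_steps / steps_per_epoch)


train_steps = 100  # TrainArgs(num_steps=100)
batch_size = 10    # batch size passed to _input_fn
epochs = epochs_for(train_steps)
examples_seen = epochs * STEPS_PER_EPOCH * batch_size
print(epochs, examples_seen)  # 5 1000
```

Because `int()` truncates, a `num_steps` value that is not a multiple of 20 is rounded down: 110 steps would still give five epochs.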

Evaluate and push the model

Use the Evaluator component to evaluate and 'bless' the model before using the Pusher component to push the model to a serving directory.

_serving_model_dir = os.path.join(tempfile.mkdtemp(),
                                  'serving_model/penguins_classification')
eval_config = tfma.EvalConfig(
    model_specs=[
        tfma.ModelSpec(label_key='species', signature_name='serving_default')
    ],
    metrics_specs=[
        tfma.MetricsSpec(metrics=[
            tfma.MetricConfig(
                class_name='SparseCategoricalAccuracy',
                threshold=tfma.MetricThreshold(
                    value_threshold=tfma.GenericValueThreshold(
                        lower_bound={'value': 0.6})))
        ])
    ],
    slicing_specs=[tfma.SlicingSpec()])
evaluator = tfx.components.Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    schema=infer_schema.outputs['schema'],
    eval_config=eval_config)
interactive_context.run(evaluator)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py:110: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
`tf.data.TFRecordDataset(path)`
pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))
interactive_context.run(pusher)
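
The `eval_config` above sets a lower bound of 0.6 on `SparseCategoricalAccuracy`, and the Pusher receives the Evaluator's blessing as an input. The gating idea can be pictured with the toy functions below. They are a simplified stand-in, not the TFMA or TFX implementation, and they assume the bound is inclusive; the function names and metric values are invented.

```python
def is_blessed(metric_value, lower_bound=0.6):
    """Toy gate: bless the model when the metric meets the lower bound."""
    return metric_value >= lower_bound


def maybe_push(metric_value):
    """Return the action a Pusher-like step would take for this metric."""
    return "push" if is_blessed(metric_value) else "skip"


print(maybe_push(0.98))  # hypothetical accuracy above the bound
print(maybe_push(0.45))  # hypothetical accuracy below the bound
```

Keeping the check separate from the push step mirrors the split between the Evaluator and Pusher components in the pipeline.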

Running the TFX pipeline populates the MLMD Database. In the next section, you use the MLMD API to query this database for metadata information.

Query the MLMD Database

The MLMD database stores three types of metadata:

  • Metadata about the pipeline and lineage information associated with the pipeline components
  • Metadata about artifacts that were generated during the pipeline run
  • Metadata about the executions of the pipeline

A typical production environment pipeline serves multiple models as new data arrives. When you encounter erroneous results in served models, you can query the MLMD database to isolate the erroneous models. You can then trace the lineage of the pipeline components that correspond to these models to debug them.

Set up the metadata (MD) store with the InteractiveContext defined previously to query the MLMD database.

connection_config = interactive_context.metadata_connection_config
store = mlmd.MetadataStore(connection_config)

# All TFX artifacts are stored in the base directory
base_dir = connection_config.sqlite.filename_uri.split('metadata.sqlite')[0]
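
The `split` call strips the database filename from the SQLite URI, leaving the pipeline root directory with a trailing slash. The sketch below shows the result on a made-up path, alongside an `os.path.dirname` variant that yields the same directory, and the kind of URI shortening this prefix makes possible. The paths are hypothetical.

```python
import os

# Hypothetical SQLite URI; the real one points into a temporary directory.
filename_uri = "/tmp/tfx-interactive-example/metadata.sqlite"

# Same approach as the tutorial: keep everything before the filename.
base_dir_split = filename_uri.split("metadata.sqlite")[0]

# Alternative: take the parent directory and add a trailing separator.
base_dir_dirname = os.path.join(os.path.dirname(filename_uri), "")

# Shorten an artifact URI by replacing the common prefix with "./".
artifact_uri = "/tmp/tfx-interactive-example/Trainer/model/4"
short_uri = artifact_uri.replace(base_dir_split, "./")

print(base_dir_split)
print(base_dir_dirname)
print(short_uri)
```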

Create some helper functions to view the data from the MD store.

def display_types(types):
  # Helper function to render dataframes for the artifact and execution types
  table = {'id': [], 'name': []}
  for a_type in types:
    table['id'].append(a_type.id)
    table['name'].append(a_type.name)
  return pd.DataFrame(data=table)


def display_artifacts(store, artifacts):
  # Helper function to render dataframes for the input artifacts
  table = {'artifact id': [], 'type': [], 'uri': []}
  for a in artifacts:
    table['artifact id'].append(a.id)
    artifact_type = store.get_artifact_types_by_id([a.type_id])[0]
    table['type'].append(artifact_type.name)
    table['uri'].append(a.uri.replace(base_dir, './'))
  return pd.DataFrame(data=table)


def display_properties(store, node):
  # Helper function to render dataframes for artifact and execution properties
  table = {'property': [], 'value': []}
  for k, v in node.properties.items():
    table['property'].append(k)
    table['value'].append(
        v.string_value if v.HasField('string_value') else v.int_value)
  for k, v in node.custom_properties.items():
    table['property'].append(k)
    table['value'].append(
        v.string_value if v.HasField('string_value') else v.int_value)
  return pd.DataFrame(data=table)

First, query the MD store for a list of all its stored ArtifactTypes.

display_types(store.get_artifact_types())

Next, query all PushedModel artifacts.

pushed_models = store.get_artifacts_by_type("PushedModel")
display_artifacts(store, pushed_models)

Query the MD store for the latest pushed model. This tutorial has only one pushed model.

pushed_model = pushed_models[-1]
display_properties(store, pushed_model)

One of the first steps in debugging a pushed model is to find out which trained model was pushed and which training data was used to train that model.

MLMD provides traversal APIs to walk through the provenance graph, which you can use to analyze the model provenance.

def get_one_hop_parent_artifacts(store, artifacts):
  # Get a list of artifacts within a 1-hop of the artifacts of interest
  artifact_ids = [artifact.id for artifact in artifacts]
  executions_ids = set(
      event.execution_id
      for event in store.get_events_by_artifact_ids(artifact_ids)
      if event.type == mlmd.proto.Event.OUTPUT)
  artifacts_ids = set(
      event.artifact_id
      for event in store.get_events_by_execution_ids(executions_ids)
      if event.type == mlmd.proto.Event.INPUT)
  return [artifact for artifact in store.get_artifacts_by_id(artifacts_ids)]
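
To see what this helper does without a live store, the same two-step lookup can be replayed on a hand-written event list. The sketch below is a toy model, not the MLMD API: the `Event` tuple, the ids, and the component names in the comments are all invented for illustration.

```python
from collections import namedtuple

# Toy stand-in for MLMD events: which execution read or wrote which artifact.
Event = namedtuple("Event", ["artifact_id", "execution_id", "type"])

EVENTS = [
    Event(artifact_id=1, execution_id=10, type="OUTPUT"),  # ExampleGen writes Examples
    Event(artifact_id=1, execution_id=11, type="INPUT"),   # Trainer reads Examples
    Event(artifact_id=2, execution_id=11, type="INPUT"),   # Trainer reads Schema
    Event(artifact_id=3, execution_id=11, type="OUTPUT"),  # Trainer writes Model
    Event(artifact_id=3, execution_id=12, type="INPUT"),   # Pusher reads Model
    Event(artifact_id=4, execution_id=12, type="OUTPUT"),  # Pusher writes PushedModel
]


def one_hop_parents(events, artifact_ids):
    """Return ids of artifacts consumed by the executions that produced
    the given artifacts."""
    wanted = set(artifact_ids)
    # Step 1: executions that have one of the artifacts as an OUTPUT.
    producers = {e.execution_id for e in events
                 if e.artifact_id in wanted and e.type == "OUTPUT"}
    # Step 2: artifacts that those executions took as INPUT.
    return sorted({e.artifact_id for e in events
                   if e.execution_id in producers and e.type == "INPUT"})


print(one_hop_parents(EVENTS, [4]))  # [3]
print(one_hop_parents(EVENTS, [3]))  # [1, 2]
```

An artifact with no producing execution in the event list, such as artifact 2 here, simply has no parents.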

Query the parent artifacts for the pushed model.

parent_artifacts = get_one_hop_parent_artifacts(store, [pushed_model])
display_artifacts(store, parent_artifacts)

Query the properties for the model.

exported_model = parent_artifacts[0]
display_properties(store, exported_model)

Query the upstream artifacts for the model.

model_parents = get_one_hop_parent_artifacts(store, [exported_model])
display_artifacts(store, model_parents)

Get the training data the model trained with.

used_data = model_parents[0]
display_properties(store, used_data)
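
The steps above walk upstream one hop at a time. When the depth of the graph is not known in advance, the same idea can be repeated until no new parents appear. The sketch below does this on a toy parent map; the artifact names and edges are invented and only loosely mirror this pipeline.

```python
from collections import deque

# Toy lineage: each artifact maps to the artifacts its producer consumed.
PARENTS = {
    "PushedModel": ["Model", "ModelBlessing"],
    "ModelBlessing": ["Model", "Examples"],
    "Model": ["Examples", "Schema"],
    "Schema": ["ExampleStatistics"],
    "ExampleStatistics": ["Examples"],
    "Examples": [],
}


def all_ancestors(parents, start):
    """Breadth-first walk that returns every upstream artifact of `start`."""
    seen = []
    queue = deque(parents.get(start, []))
    while queue:
        name = queue.popleft()
        if name in seen:
            continue  # already visited through another path
        seen.append(name)
        queue.extend(parents.get(name, []))
    return seen


ancestors = all_ancestors(PARENTS, "PushedModel")
print(ancestors)
```

Tracking visited names keeps the walk from revisiting artifacts such as the examples, which several components consume.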

Now that you have the training data that the model trained with, query the database again to find the training step (execution). Query the MD store for a list of the registered execution types.

display_types(store.get_execution_types())

The training step is the ExecutionType named tfx.components.trainer.component.Trainer. Traverse the MD store to get the trainer run that corresponds to the pushed model.

def find_producer_execution(store, artifact):
  executions_ids = set(
      event.execution_id
      for event in store.get_events_by_artifact_ids([artifact.id])
      if event.type == mlmd.proto.Event.OUTPUT)
  return store.get_executions_by_id(executions_ids)[0]

# Use a new name so the Trainer component defined earlier is not overwritten.
trainer_execution = find_producer_execution(store, exported_model)
display_properties(store, trainer_execution)

Summary

In this tutorial, you learned how to use MLMD to trace the lineage of your TFX pipeline components and resolve issues.

To learn more about how to use MLMD, check out these additional resources: