View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook | Run in Google Cloud Vertex AI Workbench |
This notebook-based tutorial will create a simple TFX pipeline and run it using Google Cloud Vertex Pipelines. This notebook is based on the TFX pipeline we built in Simple TFX Pipeline Tutorial. If you are not familiar with TFX and you have not read that tutorial yet, you should read it before proceeding with this notebook.
Google Cloud Vertex Pipelines helps you to automate, monitor, and govern your ML systems by orchestrating your ML workflow in a serverless manner. You can define your ML pipelines using Python with TFX, and then execute your pipelines on Google Cloud. See Vertex Pipelines introduction to learn more about Vertex Pipelines.
This notebook is intended to be run on Google Colab or on AI Platform Notebooks. If you are not using one of these, you can simply click the "Run in Google Colab" button above.
Set up
Before you run this notebook, ensure that you have the following:
- A Google Cloud Platform project.
- A Google Cloud Storage bucket. See the guide for creating buckets.
- The Vertex AI and Cloud Storage APIs enabled.
Please see Vertex documentation to configure your GCP project further.
Install python packages
We will install the required Python packages, including TFX and KFP, to author ML pipelines and submit jobs to Vertex Pipelines.
# Use the latest version of pip.
pip install --upgrade pip
pip install --upgrade "tfx[kfp]<2"
Did you restart the runtime?
If you are using Google Colab, the first time that you run the cell above, you must restart the runtime by clicking the "RESTART RUNTIME" button above or using the "Runtime > Restart runtime ..." menu. This is because of the way that Colab loads packages.
If you are not on Colab, you can restart the runtime with the following cell.
# docs_infra: no_execute
import sys
if not 'google.colab' in sys.modules:
  # Automatically restart kernel after installs
  import IPython
  app = IPython.Application.instance()
  app.kernel.do_shutdown(True)
Log in to Google for this notebook
If you are running this notebook on Colab, authenticate with your user account:
import sys
if 'google.colab' in sys.modules:
  from google.colab import auth
  auth.authenticate_user()
If you are on AI Platform Notebooks, authenticate with Google Cloud before running the next section, by running
gcloud auth login
in the Terminal window (which you can open via File > New in the menu). You only need to do this once per notebook instance.
Check the package versions.
import tensorflow as tf
print('TensorFlow version: {}'.format(tf.__version__))
from tfx import v1 as tfx
print('TFX version: {}'.format(tfx.__version__))
import kfp
print('KFP version: {}'.format(kfp.__version__))
TensorFlow version: 2.15.1
TFX version: 1.15.0
KFP version: 1.8.22
Set up variables
We will set up some variables used to customize the pipelines below. The following information is required:
- GCP Project id. See Identifying your project id.
- GCP Region to run pipelines. For more information about the regions that Vertex Pipelines is available in, see the Vertex AI locations guide.
- Google Cloud Storage Bucket to store pipeline outputs.
Enter required values in the cell below before running it.
GOOGLE_CLOUD_PROJECT = '' # <--- ENTER THIS
GOOGLE_CLOUD_REGION = '' # <--- ENTER THIS
GCS_BUCKET_NAME = '' # <--- ENTER THIS
if not (GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_REGION and GCS_BUCKET_NAME):
  from absl import logging
  logging.error('Please set all required parameters.')
ERROR:absl:Please set all required parameters.
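Optionally, you can sanity-check that the bucket is reachable from this notebook. The following is a minimal sketch, assuming you authenticated above and that the google-cloud-storage client library is available (it is preinstalled on Colab); it is not part of the original tutorial flow.
# docs_infra: no_execute
# Optional sanity check: verify that the Cloud Storage bucket is reachable
# with your credentials. Skipped if the variables above are still empty.
from google.cloud import storage

if GOOGLE_CLOUD_PROJECT and GCS_BUCKET_NAME:
  client = storage.Client(project=GOOGLE_CLOUD_PROJECT)
  print('Bucket exists:', client.bucket(GCS_BUCKET_NAME).exists())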
Set gcloud to use your project.
gcloud config set project {GOOGLE_CLOUD_PROJECT}
PIPELINE_NAME = 'penguin-vertex-pipelines'

# Paths for various pipeline artifacts.
PIPELINE_ROOT = 'gs://{}/pipeline_root/{}'.format(
    GCS_BUCKET_NAME, PIPELINE_NAME)

# Paths for users' Python module.
MODULE_ROOT = 'gs://{}/pipeline_module/{}'.format(
    GCS_BUCKET_NAME, PIPELINE_NAME)

# Paths for input data.
DATA_ROOT = 'gs://{}/data/{}'.format(GCS_BUCKET_NAME, PIPELINE_NAME)

# This is the path where your model will be pushed for serving.
SERVING_MODEL_DIR = 'gs://{}/serving_model/{}'.format(
    GCS_BUCKET_NAME, PIPELINE_NAME)

print('PIPELINE_ROOT: {}'.format(PIPELINE_ROOT))
PIPELINE_ROOT: gs:///pipeline_root/penguin-vertex-pipelines
Prepare example data
We will use the same Palmer Penguins dataset as in the Simple TFX Pipeline Tutorial. There are four numeric features in this dataset, which were already normalized to have range [0,1]. We will build a classification model which predicts the species of penguins.
We need to make our own copy of the dataset. Because TFX ExampleGen reads inputs from a directory, we need to create a directory on GCS and copy the dataset to it.
gsutil cp gs://download.tensorflow.org/data/palmer_penguins/penguins_processed.csv {DATA_ROOT}/
Take a quick look at the CSV file.
gsutil cat {DATA_ROOT}/penguins_processed.csv | head
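If you would rather inspect the data in Python, here is a minimal sketch using pandas; it assumes the copy above succeeded and relies on tf.io.gfile to read the gs:// path directly.
# docs_infra: no_execute
# Optional: load the copied CSV into pandas for a quick look.
# Assumes the `gsutil cp` above succeeded; tf.io.gfile can read gs:// paths.
import pandas as pd
import tensorflow as tf

with tf.io.gfile.GFile(DATA_ROOT + '/penguins_processed.csv', 'r') as f:
  penguins_df = pd.read_csv(f)
print(penguins_df.head())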
Create a pipeline
TFX pipelines are defined using Python APIs. We will define a pipeline which consists of three components: CsvExampleGen, Trainer, and Pusher. The pipeline and model definitions are almost the same as in the Simple TFX Pipeline Tutorial.
The only difference is that we don't need to set metadata_connection_config, which is used to locate the ML Metadata database. Because Vertex Pipelines uses a managed metadata service, users don't need to take care of it, and we don't need to specify the parameter.
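For comparison, the following minimal sketch shows roughly what a local run would pass; the metadata_path below is a hypothetical local SQLite file, and on Vertex Pipelines the argument is simply omitted.
# docs_infra: no_execute
# Sketch for comparison only (not needed on Vertex Pipelines): a local run, as
# in the Simple TFX Pipeline Tutorial, wires up ML Metadata explicitly.
# `metadata_path` is a hypothetical local SQLite file.
from tfx import v1 as tfx

metadata_path = '/tmp/tfx_metadata/penguin/metadata.db'
local_metadata_config = (
    tfx.orchestration.metadata.sqlite_metadata_connection_config(metadata_path))
# A local pipeline would pass `metadata_connection_config=local_metadata_config`
# to `tfx.dsl.Pipeline(...)`; on Vertex Pipelines this argument is omitted.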
Before actually defining the pipeline, we need to write model code for the Trainer component first.
Write model code.
We will use the same model code as in the Simple TFX Pipeline Tutorial.
_trainer_module_file = 'penguin_trainer.py'
%%writefile {_trainer_module_file}
# Copied from https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple
from typing import List
from absl import logging
import tensorflow as tf
from tensorflow import keras
from tensorflow_transform.tf_metadata import schema_utils
from tfx import v1 as tfx
from tfx_bsl.public import tfxio
from tensorflow_metadata.proto.v0 import schema_pb2
_FEATURE_KEYS = [
    'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g'
]
_LABEL_KEY = 'species'

_TRAIN_BATCH_SIZE = 20
_EVAL_BATCH_SIZE = 10

# Since we're not generating or creating a schema, we will instead create
# a feature spec. Since there are a fairly small number of features this is
# manageable for this dataset.
_FEATURE_SPEC = {
    **{
        feature: tf.io.FixedLenFeature(shape=[1], dtype=tf.float32)
        for feature in _FEATURE_KEYS
    },
    _LABEL_KEY: tf.io.FixedLenFeature(shape=[1], dtype=tf.int64)
}


def _input_fn(file_pattern: List[str],
              data_accessor: tfx.components.DataAccessor,
              schema: schema_pb2.Schema,
              batch_size: int) -> tf.data.Dataset:
  """Generates features and label for training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    schema: schema of the input data.
    batch_size: representing the number of consecutive elements of returned
      dataset to combine in a single batch

  Returns:
    A dataset that contains (features, indices) tuple where features is a
      dictionary of Tensors, and indices is a single Tensor of label indices.
  """
  return data_accessor.tf_dataset_factory(
      file_pattern,
      tfxio.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key=_LABEL_KEY),
      schema=schema).repeat()


def _make_keras_model() -> tf.keras.Model:
  """Creates a DNN Keras model for classifying penguin data.

  Returns:
    A Keras Model.
  """
  # The model below is built with Functional API, please refer to
  # https://www.tensorflow.org/guide/keras/overview for all API options.
  inputs = [keras.layers.Input(shape=(1,), name=f) for f in _FEATURE_KEYS]
  d = keras.layers.concatenate(inputs)
  for _ in range(2):
    d = keras.layers.Dense(8, activation='relu')(d)
  outputs = keras.layers.Dense(3)(d)

  model = keras.Model(inputs=inputs, outputs=outputs)
  model.compile(
      optimizer=keras.optimizers.Adam(1e-2),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[keras.metrics.SparseCategoricalAccuracy()])

  model.summary(print_fn=logging.info)
  return model


# TFX Trainer will call this function.
def run_fn(fn_args: tfx.components.FnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """

  # This schema is usually either an output of SchemaGen or a manually-curated
  # version provided by pipeline author. A schema can also be derived from the
  # TFT graph if a Transform component is used. In the case when either is
  # missing, `schema_from_feature_spec` could be used to generate a schema from
  # a very simple feature_spec, but the schema returned would be very primitive.
  schema = schema_utils.schema_from_feature_spec(_FEATURE_SPEC)

  train_dataset = _input_fn(
      fn_args.train_files,
      fn_args.data_accessor,
      schema,
      batch_size=_TRAIN_BATCH_SIZE)
  eval_dataset = _input_fn(
      fn_args.eval_files,
      fn_args.data_accessor,
      schema,
      batch_size=_EVAL_BATCH_SIZE)

  model = _make_keras_model()
  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  # The result of the training should be saved in `fn_args.serving_model_dir`
  # directory.
  model.save(fn_args.serving_model_dir, save_format='tf')
Writing penguin_trainer.py
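Optionally, you can sanity-check that the model builds before wiring it into a pipeline. This is a minimal sketch, assuming the cell above wrote penguin_trainer.py to the current working directory; note that _make_keras_model is a private helper of the module, so this is only a quick local check.
# docs_infra: no_execute
# Optional sanity check: import the module we just wrote and confirm that the
# Keras model builds. `_make_keras_model` is a private helper of the module.
import penguin_trainer

penguin_trainer._make_keras_model().summary()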
Copy the module file to GCS so that it can be accessed from the pipeline components. Because model training happens on Google Cloud, we need to upload this model definition.
Alternatively, you could build a container image that includes the module file and use that image to run the pipeline (see the sketch after the copy command below).
gsutil cp {_trainer_module_file} {MODULE_ROOT}/
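If you take the container-image route mentioned above instead of uploading the module file, the runner configuration (introduced in the runner section below) can point at your image. This is a sketch only; the image URI is hypothetical, and you would need to build and push an image that bundles penguin_trainer.py yourself.
# docs_infra: no_execute
# Sketch only: run the pipeline with a custom container image that already
# bundles penguin_trainer.py. The image URI below is hypothetical.
custom_config = tfx.orchestration.experimental.KubeflowV2DagRunnerConfig(
    default_image='gcr.io/my-project/penguin-tfx:latest')
# You would then pass `config=custom_config` to the KubeflowV2DagRunner below
# instead of the default KubeflowV2DagRunnerConfig().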
Write a pipeline definition
We will define a function to create a TFX pipeline.
# Copied from https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple and
# slightly modified because we don't need `metadata_path` argument.
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     module_file: str, serving_model_dir: str,
                     ) -> tfx.dsl.Pipeline:
  """Creates a three component penguin pipeline with TFX."""
  # Brings data into the pipeline.
  example_gen = tfx.components.CsvExampleGen(input_base=data_root)

  # Uses user-provided Python function that trains a model.
  trainer = tfx.components.Trainer(
      module_file=module_file,
      examples=example_gen.outputs['examples'],
      train_args=tfx.proto.TrainArgs(num_steps=100),
      eval_args=tfx.proto.EvalArgs(num_steps=5))

  # Pushes the model to a filesystem destination.
  pusher = tfx.components.Pusher(
      model=trainer.outputs['model'],
      push_destination=tfx.proto.PushDestination(
          filesystem=tfx.proto.PushDestination.Filesystem(
              base_directory=serving_model_dir)))

  # Following three components will be included in the pipeline.
  components = [
      example_gen,
      trainer,
      pusher,
  ]

  return tfx.dsl.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=components)
Run the pipeline on Vertex Pipelines.
In the Simple TFX Pipeline Tutorial, we used LocalDagRunner, which runs on a local environment. TFX provides multiple orchestrators to run your pipeline. In this tutorial we will use Vertex Pipelines together with the Kubeflow V2 dag runner.
We need to define a runner to actually run the pipeline. You will compile your pipeline into a pipeline definition format using TFX APIs.
# docs_infra: no_execute
import os
PIPELINE_DEFINITION_FILE = PIPELINE_NAME + '_pipeline.json'
runner = tfx.orchestration.experimental.KubeflowV2DagRunner(
    config=tfx.orchestration.experimental.KubeflowV2DagRunnerConfig(),
    output_filename=PIPELINE_DEFINITION_FILE)
# Following function will write the pipeline definition to PIPELINE_DEFINITION_FILE.
_ = runner.run(
    _create_pipeline(
        pipeline_name=PIPELINE_NAME,
        pipeline_root=PIPELINE_ROOT,
        data_root=DATA_ROOT,
        module_file=os.path.join(MODULE_ROOT, _trainer_module_file),
        serving_model_dir=SERVING_MODEL_DIR))
The generated definition file can be submitted using the Google Cloud aiplatform client from the google-cloud-aiplatform package.
# docs_infra: no_execute
from google.cloud import aiplatform
from google.cloud.aiplatform import pipeline_jobs
import logging
logging.getLogger().setLevel(logging.INFO)
aiplatform.init(project=GOOGLE_CLOUD_PROJECT, location=GOOGLE_CLOUD_REGION)
job = pipeline_jobs.PipelineJob(template_path=PIPELINE_DEFINITION_FILE,
                                display_name=PIPELINE_NAME)
job.submit()
Now you can visit the link in the output above, or visit 'Vertex AI > Pipelines' in the Google Cloud Console, to see the progress.
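If you prefer to monitor from the notebook instead, the following minimal sketch blocks until the pipeline finishes; it assumes the job object created in the cell above and that PipelineJob.wait() behaves as in recent google-cloud-aiplatform releases.
# docs_infra: no_execute
# Optional: block until the pipeline job finishes and print its final state.
# Assumes `job` is the PipelineJob submitted in the cell above.
job.wait()
print('Pipeline finished with state:', job.state)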