Vertex AI Training and Serving with TFX and Vertex Pipelines

This notebook-based tutorial will create and run a TFX pipeline which trains an ML model using Vertex AI Training service and publishes it to Vertex AI for serving.

This notebook is based on the TFX pipeline we built in Simple TFX Pipeline for Vertex Pipelines Tutorial. If you have not read that tutorial yet, you should read it before proceeding with this notebook.

You can train models on Vertex AI using AutoML, or use custom training. In custom training, you can select many different machine types to power your training jobs, enable distributed training, use hyperparameter tuning, and accelerate with GPUs.

You can also serve prediction requests by deploying the trained model to Vertex AI Models and creating an endpoint.

In this tutorial, we will use Vertex AI Training with custom jobs to train a model in a TFX pipeline. We will also deploy the model to serve prediction request using Vertex AI.

This notebook is intended to be run on Google Colab or on AI Platform Notebooks. If you are not using one of these, you can simply click "Run in Google Colab" button above.

Set up

If you have completed Simple TFX Pipeline for Vertex Pipelines Tutorial, you will have a working GCP project and a GCS bucket and that is all we need for this tutorial. Please read the preliminary tutorial first if you missed it.

Install python packages

We will install required Python packages including TFX and KFP to author ML pipelines and submit jobs to Vertex Pipelines.

# Use the latest version of pip.
pip install --upgrade pip
pip install --upgrade "tfx[kfp]<2"

Uninstall shapely

TODO(b/263441833) This is a temporal solution to avoid an ImportError. Ultimately, it should be handled by supporting a recent version of Bigquery, instead of uninstalling other extra dependencies.

pip uninstall shapely -y

Did you restart the runtime?

If you are using Google Colab, the first time that you run the cell above, you must restart the runtime by clicking above "RESTART RUNTIME" button or using "Runtime > Restart runtime ..." menu. This is because of the way that Colab loads packages.

If you are not on Colab, you can restart runtime with following cell.

# docs_infra: no_execute
import sys
if not 'google.colab' in sys.modules:
  # Automatically restart kernel after installs
  import IPython
  app = IPython.Application.instance()
  app.kernel.do_shutdown(True)

Login in to Google for this notebook

If you are running this notebook on Colab, authenticate with your user account:

import sys
if 'google.colab' in sys.modules:
  from google.colab import auth
  auth.authenticate_user()

If you are on AI Platform Notebooks, authenticate with Google Cloud before running the next section, by running

gcloud auth login

in the Terminal window (which you can open via File > New in the menu). You only need to do this once per notebook instance.

Check the package versions.

import tensorflow as tf
print('TensorFlow version: {}'.format(tf.__version__))
from tfx import v1 as tfx
print('TFX version: {}'.format(tfx.__version__))
import kfp
print('KFP version: {}'.format(kfp.__version__))
TensorFlow version: 2.12.1
TFX version: 1.13.0
KFP version: 1.8.22

Set up variables

We will set up some variables used to customize the pipelines below. Following information is required:

Enter required values in the cell below before running it.

GOOGLE_CLOUD_PROJECT = ''     # <--- ENTER THIS
GOOGLE_CLOUD_REGION = ''      # <--- ENTER THIS
GCS_BUCKET_NAME = ''          # <--- ENTER THIS

if not (GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_REGION and GCS_BUCKET_NAME):
    from absl import logging
    logging.error('Please set all required parameters.')
ERROR:absl:Please set all required parameters.

Set gcloud to use your project.

gcloud config set project {GOOGLE_CLOUD_PROJECT}
ERROR: (gcloud.config.set) argument VALUE: Must be specified.
Usage: gcloud config set SECTION/PROPERTY VALUE [optional flags]
  optional flags may be  --help | --installation

For detailed information on this command and its flags, run:
  gcloud config set --help
PIPELINE_NAME = 'penguin-vertex-training'

# Path to various pipeline artifact.
PIPELINE_ROOT = 'gs://{}/pipeline_root/{}'.format(GCS_BUCKET_NAME, PIPELINE_NAME)

# Paths for users' Python module.
MODULE_ROOT = 'gs://{}/pipeline_module/{}'.format(GCS_BUCKET_NAME, PIPELINE_NAME)

# Paths for users' data.
DATA_ROOT = 'gs://{}/data/{}'.format(GCS_BUCKET_NAME, PIPELINE_NAME)

# Name of Vertex AI Endpoint.
ENDPOINT_NAME = 'prediction-' + PIPELINE_NAME

print('PIPELINE_ROOT: {}'.format(PIPELINE_ROOT))
PIPELINE_ROOT: gs:///pipeline_root/penguin-vertex-training

Prepare example data

We will use the same Palmer Penguins dataset as Simple TFX Pipeline Tutorial.

There are four numeric features in this dataset which were already normalized to have range [0,1]. We will build a classification model which predicts the species of penguins.

We need to make our own copy of the dataset. Because TFX ExampleGen reads inputs from a directory, we need to create a directory and copy dataset to it on GCS.

gsutil cp gs://download.tensorflow.org/data/palmer_penguins/penguins_processed.csv {DATA_ROOT}/
InvalidUrlError: Cloud URL scheme should be followed by colon and two slashes: "://". Found: "gs:///data/penguin-vertex-training/".

Take a quick look at the CSV file.

gsutil cat {DATA_ROOT}/penguins_processed.csv | head
InvalidUrlError: Cloud URL scheme should be followed by colon and two slashes: "://". Found: "gs:///data/penguin-vertex-training/penguins_processed.csv".

Create a pipeline

Our pipeline will be very similar to the pipeline we created in Simple TFX Pipeline for Vertex Pipelines Tutorial. The pipeline will consists of three components, CsvExampleGen, Trainer and Pusher. But we will use a special Trainer and Pusher component. The Trainer component will move training workloads to Vertex AI, and the Pusher component will publish the trained ML model to Vertex AI instead of a filesystem.

TFX provides a special Trainer to submit training jobs to Vertex AI Training service. All we have to do is use Trainer in the extension module instead of the standard Trainer component along with some required GCP parameters.

In this tutorial, we will run Vertex AI Training jobs only using CPUs first and then with a GPU.

TFX also provides a special Pusher to upload the model to Vertex AI Models. Pusher will create Vertex AI Endpoint resource to serve online perdictions, too. See Vertex AI documentation to learn more about online predictions provided by Vertex AI.

Write model code.

The model itself is almost similar to the model in Simple TFX Pipeline Tutorial.

We will add _get_distribution_strategy() function which creates a TensorFlow distribution strategy and it is used in run_fn to use MirroredStrategy if GPU is available.

_trainer_module_file = 'penguin_trainer.py'
%%writefile {_trainer_module_file}

# Copied from https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple and
# slightly modified run_fn() to add distribution_strategy.

from typing import List
from absl import logging
import tensorflow as tf
from tensorflow import keras
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_transform.tf_metadata import schema_utils

from tfx import v1 as tfx
from tfx_bsl.public import tfxio

_FEATURE_KEYS = [
    'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g'
]
_LABEL_KEY = 'species'

_TRAIN_BATCH_SIZE = 20
_EVAL_BATCH_SIZE = 10

# Since we're not generating or creating a schema, we will instead create
# a feature spec.  Since there are a fairly small number of features this is
# manageable for this dataset.
_FEATURE_SPEC = {
    **{
        feature: tf.io.FixedLenFeature(shape=[1], dtype=tf.float32)
        for feature in _FEATURE_KEYS
    }, _LABEL_KEY: tf.io.FixedLenFeature(shape=[1], dtype=tf.int64)
}


def _input_fn(file_pattern: List[str],
              data_accessor: tfx.components.DataAccessor,
              schema: schema_pb2.Schema,
              batch_size: int) -> tf.data.Dataset:
  """Generates features and label for training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    schema: schema of the input data.
    batch_size: representing the number of consecutive elements of returned
      dataset to combine in a single batch

  Returns:
    A dataset that contains (features, indices) tuple where features is a
      dictionary of Tensors, and indices is a single Tensor of label indices.
  """
  return data_accessor.tf_dataset_factory(
      file_pattern,
      tfxio.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key=_LABEL_KEY),
      schema=schema).repeat()


def _make_keras_model() -> tf.keras.Model:
  """Creates a DNN Keras model for classifying penguin data.

  Returns:
    A Keras Model.
  """
  # The model below is built with Functional API, please refer to
  # https://www.tensorflow.org/guide/keras/overview for all API options.
  inputs = [keras.layers.Input(shape=(1,), name=f) for f in _FEATURE_KEYS]
  d = keras.layers.concatenate(inputs)
  for _ in range(2):
    d = keras.layers.Dense(8, activation='relu')(d)
  outputs = keras.layers.Dense(3)(d)

  model = keras.Model(inputs=inputs, outputs=outputs)
  model.compile(
      optimizer=keras.optimizers.Adam(1e-2),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[keras.metrics.SparseCategoricalAccuracy()])

  model.summary(print_fn=logging.info)
  return model


# NEW: Read `use_gpu` from the custom_config of the Trainer.
#      if it uses GPU, enable MirroredStrategy.
def _get_distribution_strategy(fn_args: tfx.components.FnArgs):
  if fn_args.custom_config.get('use_gpu', False):
    logging.info('Using MirroredStrategy with one GPU.')
    return tf.distribute.MirroredStrategy(devices=['device:GPU:0'])
  return None


# TFX Trainer will call this function.
def run_fn(fn_args: tfx.components.FnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """

  # This schema is usually either an output of SchemaGen or a manually-curated
  # version provided by pipeline author. A schema can also derived from TFT
  # graph if a Transform component is used. In the case when either is missing,
  # `schema_from_feature_spec` could be used to generate schema from very simple
  # feature_spec, but the schema returned would be very primitive.
  schema = schema_utils.schema_from_feature_spec(_FEATURE_SPEC)

  train_dataset = _input_fn(
      fn_args.train_files,
      fn_args.data_accessor,
      schema,
      batch_size=_TRAIN_BATCH_SIZE)
  eval_dataset = _input_fn(
      fn_args.eval_files,
      fn_args.data_accessor,
      schema,
      batch_size=_EVAL_BATCH_SIZE)

  # NEW: If we have a distribution strategy, build a model in a strategy scope.
  strategy = _get_distribution_strategy(fn_args)
  if strategy is None:
    model = _make_keras_model()
  else:
    with strategy.scope():
      model = _make_keras_model()

  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  # The result of the training should be saved in `fn_args.serving_model_dir`
  # directory.
  model.save(fn_args.serving_model_dir, save_format='tf')
Writing penguin_trainer.py

Copy the module file to GCS which can be accessed from the pipeline components.

Otherwise, you might want to build a container image including the module file and use the image to run the pipeline and AI Platform Training jobs.

gsutil cp {_trainer_module_file} {MODULE_ROOT}/
InvalidUrlError: Cloud URL scheme should be followed by colon and two slashes: "://". Found: "gs:///pipeline_module/penguin-vertex-training/".

Write a pipeline definition

We will define a function to create a TFX pipeline. It has the same three Components as in Simple TFX Pipeline Tutorial, but we use a Trainer and Pusher component in the GCP extension module.

tfx.extensions.google_cloud_ai_platform.Trainer behaves like a regular Trainer, but it just moves the computation for the model training to cloud. It launches a custom job in Vertex AI Training service and the trainer component in the orchestration system will just wait until the Vertex AI Training job completes.

tfx.extensions.google_cloud_ai_platform.Pusher creates a Vertex AI Model and a Vertex AI Endpoint using the trained model.

def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     module_file: str, endpoint_name: str, project_id: str,
                     region: str, use_gpu: bool) -> tfx.dsl.Pipeline:
  """Implements the penguin pipeline with TFX."""
  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = tfx.components.CsvExampleGen(input_base=data_root)

  # NEW: Configuration for Vertex AI Training.
  # This dictionary will be passed as `CustomJobSpec`.
  vertex_job_spec = {
      'project': project_id,
      'worker_pool_specs': [{
          'machine_spec': {
              'machine_type': 'n1-standard-4',
          },
          'replica_count': 1,
          'container_spec': {
              'image_uri': 'gcr.io/tfx-oss-public/tfx:{}'.format(tfx.__version__),
          },
      }],
  }
  if use_gpu:
    # See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec#acceleratortype
    # for available machine types.
    vertex_job_spec['worker_pool_specs'][0]['machine_spec'].update({
        'accelerator_type': 'NVIDIA_TESLA_K80',
        'accelerator_count': 1
    })

  # Trains a model using Vertex AI Training.
  # NEW: We need to specify a Trainer for GCP with related configs.
  trainer = tfx.extensions.google_cloud_ai_platform.Trainer(
      module_file=module_file,
      examples=example_gen.outputs['examples'],
      train_args=tfx.proto.TrainArgs(num_steps=100),
      eval_args=tfx.proto.EvalArgs(num_steps=5),
      custom_config={
          tfx.extensions.google_cloud_ai_platform.ENABLE_VERTEX_KEY:
              True,
          tfx.extensions.google_cloud_ai_platform.VERTEX_REGION_KEY:
              region,
          tfx.extensions.google_cloud_ai_platform.TRAINING_ARGS_KEY:
              vertex_job_spec,
          'use_gpu':
              use_gpu,
      })

  # NEW: Configuration for pusher.
  vertex_serving_spec = {
      'project_id': project_id,
      'endpoint_name': endpoint_name,
      # Remaining argument is passed to aiplatform.Model.deploy()
      # See https://cloud.google.com/vertex-ai/docs/predictions/deploy-model-api#deploy_the_model
      # for the detail.
      #
      # Machine type is the compute resource to serve prediction requests.
      # See https://cloud.google.com/vertex-ai/docs/predictions/configure-compute#machine-types
      # for available machine types and acccerators.
      'machine_type': 'n1-standard-4',
  }

  # Vertex AI provides pre-built containers with various configurations for
  # serving.
  # See https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers
  # for available container images.
  serving_image = 'us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-6:latest'
  if use_gpu:
    vertex_serving_spec.update({
        'accelerator_type': 'NVIDIA_TESLA_K80',
        'accelerator_count': 1
    })
    serving_image = 'us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-6:latest'

  # NEW: Pushes the model to Vertex AI.
  pusher = tfx.extensions.google_cloud_ai_platform.Pusher(
      model=trainer.outputs['model'],
      custom_config={
          tfx.extensions.google_cloud_ai_platform.ENABLE_VERTEX_KEY:
              True,
          tfx.extensions.google_cloud_ai_platform.VERTEX_REGION_KEY:
              region,
          tfx.extensions.google_cloud_ai_platform.VERTEX_CONTAINER_IMAGE_URI_KEY:
              serving_image,
          tfx.extensions.google_cloud_ai_platform.SERVING_ARGS_KEY:
            vertex_serving_spec,
      })

  components = [
      example_gen,
      trainer,
      pusher,
  ]

  return tfx.dsl.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=components)

Run the pipeline on Vertex Pipelines.

We will use Vertex Pipelines to run the pipeline as we did in Simple TFX Pipeline for Vertex Pipelines Tutorial.

# docs_infra: no_execute
import os

PIPELINE_DEFINITION_FILE = PIPELINE_NAME + '_pipeline.json'

runner = tfx.orchestration.experimental.KubeflowV2DagRunner(
    config=tfx.orchestration.experimental.KubeflowV2DagRunnerConfig(),
    output_filename=PIPELINE_DEFINITION_FILE)
_ = runner.run(
    _create_pipeline(
        pipeline_name=PIPELINE_NAME,
        pipeline_root=PIPELINE_ROOT,
        data_root=DATA_ROOT,
        module_file=os.path.join(MODULE_ROOT, _trainer_module_file),
        endpoint_name=ENDPOINT_NAME,
        project_id=GOOGLE_CLOUD_PROJECT,
        region=GOOGLE_CLOUD_REGION,
        # We will use CPUs only for now.
        use_gpu=False))

The generated definition file can be submitted using Google Cloud aiplatform client in google-cloud-aiplatform package.

# docs_infra: no_execute
from google.cloud import aiplatform
from google.cloud.aiplatform import pipeline_jobs
import logging
logging.getLogger().setLevel(logging.INFO)

aiplatform.init(project=GOOGLE_CLOUD_PROJECT, location=GOOGLE_CLOUD_REGION)

job = pipeline_jobs.PipelineJob(template_path=PIPELINE_DEFINITION_FILE,
                                display_name=PIPELINE_NAME)
job.submit()

Now you can visit the link in the output above or visit 'Vertex AI > Pipelines' in Google Cloud Console to see the progress.

Test with a prediction request

Once the pipeline completes, you will find a deployed model at the one of the endpoints in 'Vertex AI > Endpoints'. We need to know the id of the endpoint to send a prediction request to the new endpoint. This is different from the endpoint name we entered above. You can find the id at the Endpoints page in Google Cloud Console, it looks like a very long number.

Set ENDPOINT_ID below before running it.

ENDPOINT_ID=''     # <--- ENTER THIS
if not ENDPOINT_ID:
    from absl import logging
    logging.error('Please set the endpoint id.')
ERROR:absl:Please set the endpoint id.

We use the same aiplatform client to send a request to the endpoint. We will send a prediction request for Penguin species classification. The input is the four features that we used, and the model will return three values, because our model outputs one value for each species.

For example, the following specific example has the largest value at index '2' and will print '2'.

# docs_infra: no_execute
import numpy as np

# The AI Platform services require regional API endpoints.
client_options = {
    'api_endpoint': GOOGLE_CLOUD_REGION + '-aiplatform.googleapis.com'
    }
# Initialize client that will be used to create and send requests.
client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)

# Set data values for the prediction request.
# Our model expects 4 feature inputs and produces 3 output values for each
# species. Note that the output is logit value rather than probabilities.
# See the model code to understand input / output structure.
instances = [{
    'culmen_length_mm':[0.71],
    'culmen_depth_mm':[0.38],
    'flipper_length_mm':[0.98],
    'body_mass_g': [0.78],
}]

endpoint = client.endpoint_path(
    project=GOOGLE_CLOUD_PROJECT,
    location=GOOGLE_CLOUD_REGION,
    endpoint=ENDPOINT_ID,
)
# Send a prediction request and get response.
response = client.predict(endpoint=endpoint, instances=instances)

# Uses argmax to find the index of the maximum value.
print('species:', np.argmax(response.predictions[0]))

For detailed information about online prediction, please visit the Endpoints page in Google Cloud Console. you can find a guide on sending sample requests and links to more resources.

Run the pipeline using a GPU

Vertex AI supports training using various machine types including support for GPUs. See Machine spec reference for available options.

We already defined our pipeline to support GPU training. All we need to do is setting use_gpu flag to True. Then a pipeline will be created with a machine spec including one NVIDIA_TESLA_K80 and our model training code will use tf.distribute.MirroredStrategy.

Note that use_gpu flag is not a part of the Vertex or TFX API. It is just used to control the training code in this tutorial.

# docs_infra: no_execute
runner.run(
    _create_pipeline(
        pipeline_name=PIPELINE_NAME,
        pipeline_root=PIPELINE_ROOT,
        data_root=DATA_ROOT,
        module_file=os.path.join(MODULE_ROOT, _trainer_module_file),
        endpoint_name=ENDPOINT_NAME,
        project_id=GOOGLE_CLOUD_PROJECT,
        region=GOOGLE_CLOUD_REGION,
        # Updated: Use GPUs. We will use a NVIDIA_TESLA_K80 and 
        # the model code will use tf.distribute.MirroredStrategy.
        use_gpu=True))

job = pipeline_jobs.PipelineJob(template_path=PIPELINE_DEFINITION_FILE,
                                display_name=PIPELINE_NAME)
job.submit()

Now you can visit the link in the output above or visit 'Vertex AI > Pipelines' in Google Cloud Console to see the progress.

Cleaning up

You have created a Vertex AI Model and Endpoint in this tutorial. Please delete these resources to avoid any unwanted charges by going to Endpoints and undeploying the model from the endpoint first. Then you can delete the endpoint and the model separately.