Feature Engineering using TFX Pipeline and TensorFlow Transform

Transform input data and train a model with a TFX pipeline.

In this notebook-based tutorial, we will create and run a TFX pipeline to ingest raw input data and preprocess it appropriately for ML training. This notebook is based on the TFX pipeline we built in Data validation using TFX Pipeline and TensorFlow Data Validation Tutorial. If you have not read that one yet, you should read it before proceeding with this notebook.

You can increase the predictive quality of your data and/or reduce dimensionality with feature engineering. One of the benefits of using TFX is that you will write your transformation code once, and the resulting transforms will be consistent between training and serving in order to avoid training/serving skew.

We will add a Transform component to the pipeline. The Transform component is implemented using the tf.transform library.

Please see Understanding TFX Pipelines to learn more about various concepts in TFX.

Set Up

We first need to install the TFX Python package and download the dataset which we will use for our model.

Upgrade Pip

To avoid upgrading Pip in a system when running locally, check to make sure that we are running in Colab. Local systems can of course be upgraded separately.

try:
  import colab
  !pip install --upgrade pip
except:
  pass

Install TFX

pip install -U tfx

Uninstall shapely

TODO(b/263441833) This is a temporal solution to avoid an ImportError. Ultimately, it should be handled by supporting a recent version of Bigquery, instead of uninstalling other extra dependencies.

pip uninstall shapely -y

Did you restart the runtime?

If you are using Google Colab, the first time that you run the cell above, you must restart the runtime by clicking above "RESTART RUNTIME" button or using "Runtime > Restart runtime ..." menu. This is because of the way that Colab loads packages.

Check the TensorFlow and TFX versions.

import tensorflow as tf
print('TensorFlow version: {}'.format(tf.__version__))
from tfx import v1 as tfx
print('TFX version: {}'.format(tfx.__version__))
TensorFlow version: 2.12.1
TFX version: 1.13.0

Set up variables

There are some variables used to define a pipeline. You can customize these variables as you want. By default all output from the pipeline will be generated under the current directory.

import os

PIPELINE_NAME = "penguin-transform"

# Output directory to store artifacts generated from the pipeline.
PIPELINE_ROOT = os.path.join('pipelines', PIPELINE_NAME)
# Path to a SQLite DB file to use as an MLMD storage.
METADATA_PATH = os.path.join('metadata', PIPELINE_NAME, 'metadata.db')
# Output directory where created models from the pipeline will be exported.
SERVING_MODEL_DIR = os.path.join('serving_model', PIPELINE_NAME)

from absl import logging
logging.set_verbosity(logging.INFO)  # Set default logging level.

Prepare example data

We will download the example dataset for use in our TFX pipeline. The dataset we are using is Palmer Penguins dataset.

However, unlike previous tutorials which used an already preprocessed dataset, we will use the raw Palmer Penguins dataset.

Because the TFX ExampleGen component reads inputs from a directory, we need to create a directory and copy the dataset to it.

import urllib.request
import tempfile

DATA_ROOT = tempfile.mkdtemp(prefix='tfx-data')  # Create a temporary directory.
_data_path = 'https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins_size.csv'
_data_filepath = os.path.join(DATA_ROOT, "data.csv")
urllib.request.urlretrieve(_data_path, _data_filepath)
('/tmpfs/tmp/tfx-data84pyvey3/data.csv',
 <http.client.HTTPMessage at 0x7f49700384c0>)

Take a quick look at what the raw data looks like.

head {_data_filepath}
species,island,culmen_length_mm,culmen_depth_mm,flipper_length_mm,body_mass_g,sex
Adelie,Torgersen,39.1,18.7,181,3750,MALE
Adelie,Torgersen,39.5,17.4,186,3800,FEMALE
Adelie,Torgersen,40.3,18,195,3250,FEMALE
Adelie,Torgersen,NA,NA,NA,NA,NA
Adelie,Torgersen,36.7,19.3,193,3450,FEMALE
Adelie,Torgersen,39.3,20.6,190,3650,MALE
Adelie,Torgersen,38.9,17.8,181,3625,FEMALE
Adelie,Torgersen,39.2,19.6,195,4675,MALE
Adelie,Torgersen,34.1,18.1,193,3475,NA

There are some entries with missing values which are represented as NA. We will just delete those entries in this tutorial.

sed -i '/\bNA\b/d' {_data_filepath}
head {_data_filepath}
species,island,culmen_length_mm,culmen_depth_mm,flipper_length_mm,body_mass_g,sex
Adelie,Torgersen,39.1,18.7,181,3750,MALE
Adelie,Torgersen,39.5,17.4,186,3800,FEMALE
Adelie,Torgersen,40.3,18,195,3250,FEMALE
Adelie,Torgersen,36.7,19.3,193,3450,FEMALE
Adelie,Torgersen,39.3,20.6,190,3650,MALE
Adelie,Torgersen,38.9,17.8,181,3625,FEMALE
Adelie,Torgersen,39.2,19.6,195,4675,MALE
Adelie,Torgersen,41.1,17.6,182,3200,FEMALE
Adelie,Torgersen,38.6,21.2,191,3800,MALE

You should be able to see seven features which describe penguins. We will use the same set of features as the previous tutorials - 'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g' - and will predict the 'species' of a penguin.

The only difference will be that the input data is not preprocessed. Note that we will not use other features like 'island' or 'sex' in this tutorial.

Prepare a schema file

As described in Data validation using TFX Pipeline and TensorFlow Data Validation Tutorial, we need a schema file for the dataset. Because the dataset is different from the previous tutorial we need to generate it again. In this tutorial, we will skip those steps and just use a prepared schema file.

import shutil

SCHEMA_PATH = 'schema'

_schema_uri = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/penguin/schema/raw/schema.pbtxt'
_schema_filename = 'schema.pbtxt'
_schema_filepath = os.path.join(SCHEMA_PATH, _schema_filename)

os.makedirs(SCHEMA_PATH, exist_ok=True)
urllib.request.urlretrieve(_schema_uri, _schema_filepath)
('schema/schema.pbtxt', <http.client.HTTPMessage at 0x7f4adc050ca0>)

This schema file was created with the same pipeline as in the previous tutorial without any manual changes.

Create a pipeline

TFX pipelines are defined using Python APIs. We will add Transform component to the pipeline we created in the Data Validation tutorial.

A Transform component requires input data from an ExampleGen component and a schema from a SchemaGen component, and produces a "transform graph". The output will be used in a Trainer component. Transform can optionally produce "transformed data" in addition, which is the materialized data after transformation. However, we will transform data during training in this tutorial without materialization of the intermediate transformed data.

One thing to note is that we need to define a Python function, preprocessing_fn to describe how input data should be transformed. This is similar to a Trainer component which also requires user code for model definition.

Write preprocessing and training code

We need to define two Python functions. One for Transform and one for Trainer.

preprocessing_fn

The Transform component will find a function named preprocessing_fn in the given module file as we did for Trainer component. You can also specify a specific function using the preprocessing_fn parameter of the Transform component.

In this example, we will do two kinds of transformation. For continuous numeric features like culmen_length_mm and body_mass_g, we will normalize these values using the tft.scale_to_z_score function. For the label feature, we need to convert string labels into numeric index values. We will use tf.lookup.StaticHashTable for conversion.

To identify transformed fields easily, we append a _xf suffix to the transformed feature names.

run_fn

The model itself is almost the same as in the previous tutorials, but this time we will transform the input data using the transform graph from the Transform component.

One more important difference compared to the previous tutorial is that now we export a model for serving which includes not only the computation graph of the model, but also the transform graph for preprocessing, which is generated in Transform component. We need to define a separate function which will be used for serving incoming requests. You can see that the same function _apply_preprocessing was used for both of the training data and the serving request.

_module_file = 'penguin_utils.py'
%%writefile {_module_file}


from typing import List, Text
from absl import logging
import tensorflow as tf
from tensorflow import keras
from tensorflow_metadata.proto.v0 import schema_pb2
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import schema_utils

from tfx import v1 as tfx
from tfx_bsl.public import tfxio

# Specify features that we will use.
_FEATURE_KEYS = [
    'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g'
]
_LABEL_KEY = 'species'

_TRAIN_BATCH_SIZE = 20
_EVAL_BATCH_SIZE = 10


# NEW: TFX Transform will call this function.
def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.

  Args:
    inputs: map from feature keys to raw not-yet-transformed features.

  Returns:
    Map from string feature key to transformed feature.
  """
  outputs = {}

  # Uses features defined in _FEATURE_KEYS only.
  for key in _FEATURE_KEYS:
    # tft.scale_to_z_score computes the mean and variance of the given feature
    # and scales the output based on the result.
    outputs[key] = tft.scale_to_z_score(inputs[key])

  # For the label column we provide the mapping from string to index.
  # We could instead use `tft.compute_and_apply_vocabulary()` in order to
  # compute the vocabulary dynamically and perform a lookup.
  # Since in this example there are only 3 possible values, we use a hard-coded
  # table for simplicity.
  table_keys = ['Adelie', 'Chinstrap', 'Gentoo']
  initializer = tf.lookup.KeyValueTensorInitializer(
      keys=table_keys,
      values=tf.cast(tf.range(len(table_keys)), tf.int64),
      key_dtype=tf.string,
      value_dtype=tf.int64)
  table = tf.lookup.StaticHashTable(initializer, default_value=-1)
  outputs[_LABEL_KEY] = table.lookup(inputs[_LABEL_KEY])

  return outputs


# NEW: This function will apply the same transform operation to training data
#      and serving requests.
def _apply_preprocessing(raw_features, tft_layer):
  transformed_features = tft_layer(raw_features)
  if _LABEL_KEY in raw_features:
    transformed_label = transformed_features.pop(_LABEL_KEY)
    return transformed_features, transformed_label
  else:
    return transformed_features, None


# NEW: This function will create a handler function which gets a serialized
#      tf.example, preprocess and run an inference with it.
def _get_serve_tf_examples_fn(model, tf_transform_output):
  # We must save the tft_layer to the model to ensure its assets are kept and
  # tracked.
  model.tft_layer = tf_transform_output.transform_features_layer()

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
  ])
  def serve_tf_examples_fn(serialized_tf_examples):
    # Expected input is a string which is serialized tf.Example format.
    feature_spec = tf_transform_output.raw_feature_spec()
    # Because input schema includes unnecessary fields like 'species' and
    # 'island', we filter feature_spec to include required keys only.
    required_feature_spec = {
        k: v for k, v in feature_spec.items() if k in _FEATURE_KEYS
    }
    parsed_features = tf.io.parse_example(serialized_tf_examples,
                                          required_feature_spec)

    # Preprocess parsed input with transform operation defined in
    # preprocessing_fn().
    transformed_features, _ = _apply_preprocessing(parsed_features,
                                                   model.tft_layer)
    # Run inference with ML model.
    return model(transformed_features)

  return serve_tf_examples_fn


def _input_fn(file_pattern: List[Text],
              data_accessor: tfx.components.DataAccessor,
              tf_transform_output: tft.TFTransformOutput,
              batch_size: int = 200) -> tf.data.Dataset:
  """Generates features and label for tuning/training.

  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    tf_transform_output: A TFTransformOutput.
    batch_size: representing the number of consecutive elements of returned
      dataset to combine in a single batch

  Returns:
    A dataset that contains (features, indices) tuple where features is a
      dictionary of Tensors, and indices is a single Tensor of label indices.
  """
  dataset = data_accessor.tf_dataset_factory(
      file_pattern,
      tfxio.TensorFlowDatasetOptions(batch_size=batch_size),
      schema=tf_transform_output.raw_metadata.schema)

  transform_layer = tf_transform_output.transform_features_layer()
  def apply_transform(raw_features):
    return _apply_preprocessing(raw_features, transform_layer)

  return dataset.map(apply_transform).repeat()


def _build_keras_model() -> tf.keras.Model:
  """Creates a DNN Keras model for classifying penguin data.

  Returns:
    A Keras Model.
  """
  # The model below is built with Functional API, please refer to
  # https://www.tensorflow.org/guide/keras/overview for all API options.
  inputs = [
      keras.layers.Input(shape=(1,), name=key)
      for key in _FEATURE_KEYS
  ]
  d = keras.layers.concatenate(inputs)
  for _ in range(2):
    d = keras.layers.Dense(8, activation='relu')(d)
  outputs = keras.layers.Dense(3)(d)

  model = keras.Model(inputs=inputs, outputs=outputs)
  model.compile(
      optimizer=keras.optimizers.Adam(1e-2),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[keras.metrics.SparseCategoricalAccuracy()])

  model.summary(print_fn=logging.info)
  return model


# TFX Trainer will call this function.
def run_fn(fn_args: tfx.components.FnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

  train_dataset = _input_fn(
      fn_args.train_files,
      fn_args.data_accessor,
      tf_transform_output,
      batch_size=_TRAIN_BATCH_SIZE)
  eval_dataset = _input_fn(
      fn_args.eval_files,
      fn_args.data_accessor,
      tf_transform_output,
      batch_size=_EVAL_BATCH_SIZE)

  model = _build_keras_model()
  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  # NEW: Save a computation graph including transform layer.
  signatures = {
      'serving_default': _get_serve_tf_examples_fn(model, tf_transform_output),
  }
  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
Writing penguin_utils.py

Now you have completed all of the preparation steps to build a TFX pipeline.

Write a pipeline definition

We define a function to create a TFX pipeline. A Pipeline object represents a TFX pipeline, which can be run using one of the pipeline orchestration systems that TFX supports.

def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     schema_path: str, module_file: str, serving_model_dir: str,
                     metadata_path: str) -> tfx.dsl.Pipeline:
  """Implements the penguin pipeline with TFX."""
  # Brings data into the pipeline or otherwise joins/converts training data.
  example_gen = tfx.components.CsvExampleGen(input_base=data_root)

  # Computes statistics over data for visualization and example validation.
  statistics_gen = tfx.components.StatisticsGen(
      examples=example_gen.outputs['examples'])

  # Import the schema.
  schema_importer = tfx.dsl.Importer(
      source_uri=schema_path,
      artifact_type=tfx.types.standard_artifacts.Schema).with_id(
          'schema_importer')

  # Performs anomaly detection based on statistics and data schema.
  example_validator = tfx.components.ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=schema_importer.outputs['result'])

  # NEW: Transforms input data using preprocessing_fn in the 'module_file'.
  transform = tfx.components.Transform(
      examples=example_gen.outputs['examples'],
      schema=schema_importer.outputs['result'],
      materialize=False,
      module_file=module_file)

  # Uses user-provided Python function that trains a model.
  trainer = tfx.components.Trainer(
      module_file=module_file,
      examples=example_gen.outputs['examples'],

      # NEW: Pass transform_graph to the trainer.
      transform_graph=transform.outputs['transform_graph'],

      train_args=tfx.proto.TrainArgs(num_steps=100),
      eval_args=tfx.proto.EvalArgs(num_steps=5))

  # Pushes the model to a filesystem destination.
  pusher = tfx.components.Pusher(
      model=trainer.outputs['model'],
      push_destination=tfx.proto.PushDestination(
          filesystem=tfx.proto.PushDestination.Filesystem(
              base_directory=serving_model_dir)))

  components = [
      example_gen,
      statistics_gen,
      schema_importer,
      example_validator,

      transform,  # NEW: Transform component was added to the pipeline.

      trainer,
      pusher,
  ]

  return tfx.dsl.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=tfx.orchestration.metadata
      .sqlite_metadata_connection_config(metadata_path),
      components=components)

Run the pipeline

We will use LocalDagRunner as in the previous tutorial.

tfx.orchestration.LocalDagRunner().run(
  _create_pipeline(
      pipeline_name=PIPELINE_NAME,
      pipeline_root=PIPELINE_ROOT,
      data_root=DATA_ROOT,
      schema_path=SCHEMA_PATH,
      module_file=_module_file,
      serving_model_dir=SERVING_MODEL_DIR,
      metadata_path=METADATA_PATH))
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Generating ephemeral wheel package for '/tmpfs/src/temp/docs/tutorials/tfx/penguin_utils.py' (including modules: ['penguin_utils']).
INFO:absl:User module package has hash fingerprint version a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '/tmpfs/tmp/tmpazogc7u6/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmpfs/tmp/tmpv6zsxacn', '--dist-dir', '/tmpfs/tmp/tmpu975z_wo']
running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying penguin_utils.py -> build/lib
installing to /tmpfs/tmp/tmpv6zsxacn
running install
running install_lib
copying build/lib/penguin_utils.py -> /tmpfs/tmp/tmpv6zsxacn
running install_egg_info
running egg_info
creating tfx_user_code_Transform.egg-info
writing tfx_user_code_Transform.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Transform.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Transform.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
reading manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
INFO:absl:Successfully built user code wheel distribution at 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'; target user module is 'penguin_utils'.
INFO:absl:Full user module path is 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'
INFO:absl:Generating ephemeral wheel package for '/tmpfs/src/temp/docs/tutorials/tfx/penguin_utils.py' (including modules: ['penguin_utils']).
INFO:absl:User module package has hash fingerprint version a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '/tmpfs/tmp/tmp06hexlsv/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmpfs/tmp/tmpf_y1n74l', '--dist-dir', '/tmpfs/tmp/tmpa8m5a1tb']
writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt'
Copying tfx_user_code_Transform.egg-info to /tmpfs/tmp/tmpv6zsxacn/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3.9.egg-info
running install_scripts
creating /tmpfs/tmp/tmpv6zsxacn/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/WHEEL
creating '/tmpfs/tmp/tmpu975z_wo/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' and adding '/tmpfs/tmp/tmpv6zsxacn' to it
adding 'penguin_utils.py'
adding 'tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/METADATA'
adding 'tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/WHEEL'
adding 'tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/top_level.txt'
adding 'tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/RECORD'
removing /tmpfs/tmp/tmpv6zsxacn
running bdist_wheel
running build
running build_py
creating build
creating build/lib
copying penguin_utils.py -> build/lib
installing to /tmpfs/tmp/tmpf_y1n74l
running install
running install_lib
copying build/lib/penguin_utils.py -> /tmpfs/tmp/tmpf_y1n74l
running install_egg_info
running egg_info
creating tfx_user_code_Trainer.egg-info
writing tfx_user_code_Trainer.egg-info/PKG-INFO
writing dependency_links to tfx_user_code_Trainer.egg-info/dependency_links.txt
writing top-level names to tfx_user_code_Trainer.egg-info/top_level.txt
writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
INFO:absl:Successfully built user code wheel distribution at 'pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'; target user module is 'penguin_utils'.
INFO:absl:Full user module path is 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'
INFO:absl:Using deployment config:
 executor_specs {
  key: "CsvExampleGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.example_gen.csv_example_gen.executor.Executor"
      }
    }
  }
}
executor_specs {
  key: "ExampleValidator"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.example_validator.executor.Executor"
    }
  }
}
executor_specs {
  key: "Pusher"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.pusher.executor.Executor"
    }
  }
}
executor_specs {
  key: "StatisticsGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.statistics_gen.executor.Executor"
      }
    }
  }
}
executor_specs {
  key: "Trainer"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.trainer.executor.GenericExecutor"
    }
  }
}
executor_specs {
  key: "Transform"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.transform.executor.Executor"
      }
    }
  }
}
custom_driver_specs {
  key: "CsvExampleGen"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.example_gen.driver.FileBasedDriver"
    }
  }
}
metadata_connection_config {
  database_connection_config {
    sqlite {
      filename_uri: "metadata/penguin-transform/metadata.db"
      connection_mode: READWRITE_OPENCREATE
    }
  }
}

INFO:absl:Using connection config:
 sqlite {
  filename_uri: "metadata/penguin-transform/metadata.db"
  connection_mode: READWRITE_OPENCREATE
}

INFO:absl:Component CsvExampleGen is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.components.example_gen.csv_example_gen.component.CsvExampleGen"
  }
  id: "CsvExampleGen"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.CsvExampleGen"
      }
    }
  }
}
outputs {
  outputs {
    key: "examples"
    value {
      artifact_spec {
        type {
          name: "Examples"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          properties {
            key: "version"
            value: INT
          }
          base_type: DATASET
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "input_base"
    value {
      field_value {
        string_value: "/tmpfs/tmp/tfx-data84pyvey3"
      }
    }
  }
  parameters {
    key: "input_config"
    value {
      field_value {
        string_value: "{\n  \"splits\": [\n    {\n      \"name\": \"single_split\",\n      \"pattern\": \"*\"\n    }\n  ]\n}"
      }
    }
  }
  parameters {
    key: "output_config"
    value {
      field_value {
        string_value: "{\n  \"split_config\": {\n    \"splits\": [\n      {\n        \"hash_buckets\": 2,\n        \"name\": \"train\"\n      },\n      {\n        \"hash_buckets\": 1,\n        \"name\": \"eval\"\n      }\n    ]\n  }\n}"
      }
    }
  }
  parameters {
    key: "output_data_format"
    value {
      field_value {
        int_value: 6
      }
    }
  }
  parameters {
    key: "output_file_format"
    value {
      field_value {
        int_value: 5
      }
    }
  }
}
downstream_nodes: "StatisticsGen"
downstream_nodes: "Trainer"
downstream_nodes: "Transform"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
reading manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt'
Copying tfx_user_code_Trainer.egg-info to /tmpfs/tmp/tmpf_y1n74l/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3.9.egg-info
running install_scripts
creating /tmpfs/tmp/tmpf_y1n74l/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/WHEEL
creating '/tmpfs/tmp/tmpa8m5a1tb/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' and adding '/tmpfs/tmp/tmpf_y1n74l' to it
adding 'penguin_utils.py'
adding 'tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/METADATA'
adding 'tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/WHEEL'
adding 'tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/top_level.txt'
adding 'tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/RECORD'
removing /tmpfs/tmp/tmpf_y1n74l
INFO:absl:[CsvExampleGen] Resolved inputs: ({},)
INFO:absl:select span and version = (0, None)
INFO:absl:latest span and version = (0, None)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 1
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=1, input_dict={}, output_dict=defaultdict(<class 'list'>, {'examples': [Artifact(artifact: uri: "pipelines/penguin-transform/CsvExampleGen/examples/1"
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1690538432,sum_checksum:1690538432"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
, artifact_type: name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)]}), exec_properties={'output_config': '{\n  "split_config": {\n    "splits": [\n      {\n        "hash_buckets": 2,\n        "name": "train"\n      },\n      {\n        "hash_buckets": 1,\n        "name": "eval"\n      }\n    ]\n  }\n}', 'output_data_format': 6, 'output_file_format': 5, 'input_config': '{\n  "splits": [\n    {\n      "name": "single_split",\n      "pattern": "*"\n    }\n  ]\n}', 'input_base': '/tmpfs/tmp/tfx-data84pyvey3', 'span': 0, 'version': None, 'input_fingerprint': 'split:single_split,num_files:1,total_bytes:13161,xor_checksum:1690538432,sum_checksum:1690538432'}, execution_output_uri='pipelines/penguin-transform/CsvExampleGen/.system/executor_execution/1/executor_output.pb', stateful_working_dir='pipelines/penguin-transform/CsvExampleGen/.system/stateful_working_dir/2023-07-28T10:00:34.516274', tmp_dir='pipelines/penguin-transform/CsvExampleGen/.system/executor_execution/1/.temp/', pipeline_node=node_info {
  type {
    name: "tfx.components.example_gen.csv_example_gen.component.CsvExampleGen"
  }
  id: "CsvExampleGen"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.CsvExampleGen"
      }
    }
  }
}
outputs {
  outputs {
    key: "examples"
    value {
      artifact_spec {
        type {
          name: "Examples"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          properties {
            key: "version"
            value: INT
          }
          base_type: DATASET
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "input_base"
    value {
      field_value {
        string_value: "/tmpfs/tmp/tfx-data84pyvey3"
      }
    }
  }
  parameters {
    key: "input_config"
    value {
      field_value {
        string_value: "{\n  \"splits\": [\n    {\n      \"name\": \"single_split\",\n      \"pattern\": \"*\"\n    }\n  ]\n}"
      }
    }
  }
  parameters {
    key: "output_config"
    value {
      field_value {
        string_value: "{\n  \"split_config\": {\n    \"splits\": [\n      {\n        \"hash_buckets\": 2,\n        \"name\": \"train\"\n      },\n      {\n        \"hash_buckets\": 1,\n        \"name\": \"eval\"\n      }\n    ]\n  }\n}"
      }
    }
  }
  parameters {
    key: "output_data_format"
    value {
      field_value {
        int_value: 6
      }
    }
  }
  parameters {
    key: "output_file_format"
    value {
      field_value {
        int_value: 5
      }
    }
  }
}
downstream_nodes: "StatisticsGen"
downstream_nodes: "Trainer"
downstream_nodes: "Transform"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "penguin-transform"
, pipeline_run_id='2023-07-28T10:00:34.516274')
INFO:absl:Generating examples.
WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
INFO:absl:Processing input csv data /tmpfs/tmp/tfx-data84pyvey3/* to TFExample.
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
INFO:absl:Examples generated.
INFO:absl:Value type <class 'NoneType'> of key version in exec_properties is not supported, going to drop it
INFO:absl:Value type <class 'list'> of key _beam_pipeline_args in exec_properties is not supported, going to drop it
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 1 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'examples': [Artifact(artifact: uri: "pipelines/penguin-transform/CsvExampleGen/examples/1"
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1690538432,sum_checksum:1690538432"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
, artifact_type: name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)]}) for execution 1
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component CsvExampleGen is finished.
INFO:absl:Component schema_importer is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.dsl.components.common.importer.Importer"
  }
  id: "schema_importer"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.schema_importer"
      }
    }
  }
}
outputs {
  outputs {
    key: "result"
    value {
      artifact_spec {
        type {
          name: "Schema"
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "artifact_uri"
    value {
      field_value {
        string_value: "schema"
      }
    }
  }
  parameters {
    key: "output_key"
    value {
      field_value {
        string_value: "result"
      }
    }
  }
  parameters {
    key: "reimport"
    value {
      field_value {
        int_value: 0
      }
    }
  }
}
downstream_nodes: "ExampleValidator"
downstream_nodes: "Transform"
execution_options {
  caching_options {
  }
}

INFO:absl:Running as an importer node.
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Processing source uri: schema, properties: {}, custom_properties: {}
INFO:absl:Component schema_importer is finished.
INFO:absl:Component StatisticsGen is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.components.statistics_gen.component.StatisticsGen"
    base_type: PROCESS
  }
  id: "StatisticsGen"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.StatisticsGen"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "CsvExampleGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.CsvExampleGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "examples"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "statistics"
    value {
      artifact_spec {
        type {
          name: "ExampleStatistics"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          base_type: STATISTICS
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "exclude_splits"
    value {
      field_value {
        string_value: "[]"
      }
    }
  }
}
upstream_nodes: "CsvExampleGen"
downstream_nodes: "ExampleValidator"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
WARNING:absl:ArtifactQuery.property_predicate is not supported.
INFO:absl:[StatisticsGen] Resolved inputs: ({'examples': [Artifact(artifact: id: 1
type_id: 15
uri: "pipelines/penguin-transform/CsvExampleGen/examples/1"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "file_format"
  value {
    string_value: "tfrecords_gzip"
  }
}
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1690538432,sum_checksum:1690538432"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "payload_format"
  value {
    string_value: "FORMAT_TF_EXAMPLE"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1690538436516
last_update_time_since_epoch: 1690538436516
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)]},)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 3
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=3, input_dict={'examples': [Artifact(artifact: id: 1
type_id: 15
uri: "pipelines/penguin-transform/CsvExampleGen/examples/1"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "file_format"
  value {
    string_value: "tfrecords_gzip"
  }
}
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1690538432,sum_checksum:1690538432"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "payload_format"
  value {
    string_value: "FORMAT_TF_EXAMPLE"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1690538436516
last_update_time_since_epoch: 1690538436516
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)]}, output_dict=defaultdict(<class 'list'>, {'statistics': [Artifact(artifact: uri: "pipelines/penguin-transform/StatisticsGen/statistics/3"
, artifact_type: name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)]}), exec_properties={'exclude_splits': '[]'}, execution_output_uri='pipelines/penguin-transform/StatisticsGen/.system/executor_execution/3/executor_output.pb', stateful_working_dir='pipelines/penguin-transform/StatisticsGen/.system/stateful_working_dir/2023-07-28T10:00:34.516274', tmp_dir='pipelines/penguin-transform/StatisticsGen/.system/executor_execution/3/.temp/', pipeline_node=node_info {
  type {
    name: "tfx.components.statistics_gen.component.StatisticsGen"
    base_type: PROCESS
  }
  id: "StatisticsGen"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.StatisticsGen"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "CsvExampleGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.CsvExampleGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "examples"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "statistics"
    value {
      artifact_spec {
        type {
          name: "ExampleStatistics"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          base_type: STATISTICS
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "exclude_splits"
    value {
      field_value {
        string_value: "[]"
      }
    }
  }
}
upstream_nodes: "CsvExampleGen"
downstream_nodes: "ExampleValidator"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "penguin-transform"
, pipeline_run_id='2023-07-28T10:00:34.516274')
INFO:absl:Generating statistics for split train.
INFO:absl:Statistics for split train written to pipelines/penguin-transform/StatisticsGen/statistics/3/Split-train.
INFO:absl:Generating statistics for split eval.
INFO:absl:Statistics for split eval written to pipelines/penguin-transform/StatisticsGen/statistics/3/Split-eval.
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 3 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'statistics': [Artifact(artifact: uri: "pipelines/penguin-transform/StatisticsGen/statistics/3"
, artifact_type: name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)]}) for execution 3
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component StatisticsGen is finished.
INFO:absl:Component Transform is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.components.transform.component.Transform"
    base_type: TRANSFORM
  }
  id: "Transform"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.Transform"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "CsvExampleGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.CsvExampleGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "examples"
      }
      min_count: 1
    }
  }
  inputs {
    key: "schema"
    value {
      channels {
        producer_node_query {
          id: "schema_importer"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.schema_importer"
            }
          }
        }
        artifact_query {
          type {
            name: "Schema"
          }
        }
        output_key: "result"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "post_transform_anomalies"
    value {
      artifact_spec {
        type {
          name: "ExampleAnomalies"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
        }
      }
    }
  }
  outputs {
    key: "post_transform_schema"
    value {
      artifact_spec {
        type {
          name: "Schema"
        }
      }
    }
  }
  outputs {
    key: "post_transform_stats"
    value {
      artifact_spec {
        type {
          name: "ExampleStatistics"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          base_type: STATISTICS
        }
      }
    }
  }
  outputs {
    key: "pre_transform_schema"
    value {
      artifact_spec {
        type {
          name: "Schema"
        }
      }
    }
  }
  outputs {
    key: "pre_transform_stats"
    value {
      artifact_spec {
        type {
          name: "ExampleStatistics"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          base_type: STATISTICS
        }
      }
    }
  }
  outputs {
    key: "transform_graph"
    value {
      artifact_spec {
        type {
          name: "TransformGraph"
        }
      }
    }
  }
  outputs {
    key: "updated_analyzer_cache"
    value {
      artifact_spec {
        type {
          name: "TransformCache"
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "custom_config"
    value {
      field_value {
        string_value: "null"
      }
    }
  }
  parameters {
    key: "disable_statistics"
    value {
      field_value {
        int_value: 0
      }
    }
  }
  parameters {
    key: "force_tf_compat_v1"
    value {
      field_value {
        int_value: 0
      }
    }
  }
  parameters {
    key: "module_path"
    value {
      field_value {
        string_value: "penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl"
      }
    }
  }
}
upstream_nodes: "CsvExampleGen"
upstream_nodes: "schema_importer"
downstream_nodes: "Trainer"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
WARNING:absl:ArtifactQuery.property_predicate is not supported.
WARNING:absl:ArtifactQuery.property_predicate is not supported.
INFO:absl:[Transform] Resolved inputs: ({'examples': [Artifact(artifact: id: 1
type_id: 15
uri: "pipelines/penguin-transform/CsvExampleGen/examples/1"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "file_format"
  value {
    string_value: "tfrecords_gzip"
  }
}
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1690538432,sum_checksum:1690538432"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "payload_format"
  value {
    string_value: "FORMAT_TF_EXAMPLE"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1690538436516
last_update_time_since_epoch: 1690538436516
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)], 'schema': [Artifact(artifact: id: 2
type_id: 17
uri: "schema"
custom_properties {
  key: "is_external"
  value {
    int_value: 1
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "Schema"
create_time_since_epoch: 1690538436638
last_update_time_since_epoch: 1690538436638
, artifact_type: id: 17
name: "Schema"
)]},)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 4
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=4, input_dict={'examples': [Artifact(artifact: id: 1
type_id: 15
uri: "pipelines/penguin-transform/CsvExampleGen/examples/1"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "file_format"
  value {
    string_value: "tfrecords_gzip"
  }
}
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1690538432,sum_checksum:1690538432"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "payload_format"
  value {
    string_value: "FORMAT_TF_EXAMPLE"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1690538436516
last_update_time_since_epoch: 1690538436516
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)], 'schema': [Artifact(artifact: id: 2
type_id: 17
uri: "schema"
custom_properties {
  key: "is_external"
  value {
    int_value: 1
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "Schema"
create_time_since_epoch: 1690538436638
last_update_time_since_epoch: 1690538436638
, artifact_type: id: 17
name: "Schema"
)]}, output_dict=defaultdict(<class 'list'>, {'pre_transform_schema': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/pre_transform_schema/4"
, artifact_type: name: "Schema"
)], 'post_transform_stats': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/post_transform_stats/4"
, artifact_type: name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)], 'post_transform_anomalies': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/post_transform_anomalies/4"
, artifact_type: name: "ExampleAnomalies"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
)], 'updated_analyzer_cache': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/updated_analyzer_cache/4"
, artifact_type: name: "TransformCache"
)], 'pre_transform_stats': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/pre_transform_stats/4"
, artifact_type: name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)], 'post_transform_schema': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/post_transform_schema/4"
, artifact_type: name: "Schema"
)], 'transform_graph': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/transform_graph/4"
, artifact_type: name: "TransformGraph"
)]}), exec_properties={'disable_statistics': 0, 'custom_config': 'null', 'force_tf_compat_v1': 0, 'module_path': 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'}, execution_output_uri='pipelines/penguin-transform/Transform/.system/executor_execution/4/executor_output.pb', stateful_working_dir='pipelines/penguin-transform/Transform/.system/stateful_working_dir/2023-07-28T10:00:34.516274', tmp_dir='pipelines/penguin-transform/Transform/.system/executor_execution/4/.temp/', pipeline_node=node_info {
  type {
    name: "tfx.components.transform.component.Transform"
    base_type: TRANSFORM
  }
  id: "Transform"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.Transform"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "CsvExampleGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.CsvExampleGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "examples"
      }
      min_count: 1
    }
  }
  inputs {
    key: "schema"
    value {
      channels {
        producer_node_query {
          id: "schema_importer"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.schema_importer"
            }
          }
        }
        artifact_query {
          type {
            name: "Schema"
          }
        }
        output_key: "result"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "post_transform_anomalies"
    value {
      artifact_spec {
        type {
          name: "ExampleAnomalies"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
        }
      }
    }
  }
  outputs {
    key: "post_transform_schema"
    value {
      artifact_spec {
        type {
          name: "Schema"
        }
      }
    }
  }
  outputs {
    key: "post_transform_stats"
    value {
      artifact_spec {
        type {
          name: "ExampleStatistics"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          base_type: STATISTICS
        }
      }
    }
  }
  outputs {
    key: "pre_transform_schema"
    value {
      artifact_spec {
        type {
          name: "Schema"
        }
      }
    }
  }
  outputs {
    key: "pre_transform_stats"
    value {
      artifact_spec {
        type {
          name: "ExampleStatistics"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
          base_type: STATISTICS
        }
      }
    }
  }
  outputs {
    key: "transform_graph"
    value {
      artifact_spec {
        type {
          name: "TransformGraph"
        }
      }
    }
  }
  outputs {
    key: "updated_analyzer_cache"
    value {
      artifact_spec {
        type {
          name: "TransformCache"
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "custom_config"
    value {
      field_value {
        string_value: "null"
      }
    }
  }
  parameters {
    key: "disable_statistics"
    value {
      field_value {
        int_value: 0
      }
    }
  }
  parameters {
    key: "force_tf_compat_v1"
    value {
      field_value {
        int_value: 0
      }
    }
  }
  parameters {
    key: "module_path"
    value {
      field_value {
        string_value: "penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl"
      }
    }
  }
}
upstream_nodes: "CsvExampleGen"
upstream_nodes: "schema_importer"
downstream_nodes: "Trainer"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "penguin-transform"
, pipeline_run_id='2023-07-28T10:00:34.516274')
INFO:absl:Analyze the 'train' split and transform all splits when splits_config is not set.
INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl', 'preprocessing_fn': None} 'preprocessing_fn'
INFO:absl:Installing 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmp2v5b0xtj', 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl']
Processing ./pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl
INFO:absl:Successfully installed 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'.
INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl', 'stats_options_updater_fn': None} 'stats_options_updater_fn'
INFO:absl:Installing 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmp0v9sxxr7', 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl']
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9
Processing ./pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl
INFO:absl:Successfully installed 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'.
INFO:absl:Installing 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmp2xfiz2z0', 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl']
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9
Processing ./pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl
INFO:absl:Successfully installed 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature island has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature sex has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
Installing collected packages: tfx-user-code-Transform
Successfully installed tfx-user-code-Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature island has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature sex has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature island has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature sex has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature island has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature sex has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature island has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature sex has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature island has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature sex has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: key_value_init/LookupTableImportV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: key_value_init/LookupTableImportV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: key_value_init/LookupTableImportV2
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: key_value_init/LookupTableImportV2
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature island has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature sex has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature island has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature sex has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:tensorflow:Assets written to: pipelines/penguin-transform/Transform/transform_graph/4/.temp_path/tftransform_tmp/d8544d8785a745768246e7c4a78d1ef3/assets
INFO:tensorflow:Assets written to: pipelines/penguin-transform/Transform/transform_graph/4/.temp_path/tftransform_tmp/d8544d8785a745768246e7c4a78d1ef3/assets
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: key_value_init/LookupTableImportV2
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:tensorflow:Assets written to: pipelines/penguin-transform/Transform/transform_graph/4/.temp_path/tftransform_tmp/045c01afab3547e4a60d242c20a541f5/assets
INFO:tensorflow:Assets written to: pipelines/penguin-transform/Transform/transform_graph/4/.temp_path/tftransform_tmp/045c01afab3547e4a60d242c20a541f5/assets
WARNING:absl:Tables initialized inside a tf.function  will be re-initialized on every invocation of the function. This  re-initialization can have significant impact on performance. Consider lifting  them out of the graph context using  `tf.init_scope`.: key_value_init/LookupTableImportV2
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 4 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'pre_transform_schema': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/pre_transform_schema/4"
, artifact_type: name: "Schema"
)], 'post_transform_stats': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/post_transform_stats/4"
, artifact_type: name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)], 'post_transform_anomalies': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/post_transform_anomalies/4"
, artifact_type: name: "ExampleAnomalies"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
)], 'updated_analyzer_cache': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/updated_analyzer_cache/4"
, artifact_type: name: "TransformCache"
)], 'pre_transform_stats': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/pre_transform_stats/4"
, artifact_type: name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)], 'post_transform_schema': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/post_transform_schema/4"
, artifact_type: name: "Schema"
)], 'transform_graph': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/transform_graph/4"
, artifact_type: name: "TransformGraph"
)]}) for execution 4
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component Transform is finished.
INFO:absl:Component ExampleValidator is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.components.example_validator.component.ExampleValidator"
  }
  id: "ExampleValidator"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.ExampleValidator"
      }
    }
  }
}
inputs {
  inputs {
    key: "schema"
    value {
      channels {
        producer_node_query {
          id: "schema_importer"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.schema_importer"
            }
          }
        }
        artifact_query {
          type {
            name: "Schema"
          }
        }
        output_key: "result"
      }
      min_count: 1
    }
  }
  inputs {
    key: "statistics"
    value {
      channels {
        producer_node_query {
          id: "StatisticsGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.StatisticsGen"
            }
          }
        }
        artifact_query {
          type {
            name: "ExampleStatistics"
            base_type: STATISTICS
          }
        }
        output_key: "statistics"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "anomalies"
    value {
      artifact_spec {
        type {
          name: "ExampleAnomalies"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "exclude_splits"
    value {
      field_value {
        string_value: "[]"
      }
    }
  }
}
upstream_nodes: "StatisticsGen"
upstream_nodes: "schema_importer"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
WARNING:absl:ArtifactQuery.property_predicate is not supported.
WARNING:absl:ArtifactQuery.property_predicate is not supported.
INFO:absl:[ExampleValidator] Resolved inputs: ({'schema': [Artifact(artifact: id: 2
type_id: 17
uri: "schema"
custom_properties {
  key: "is_external"
  value {
    int_value: 1
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "Schema"
create_time_since_epoch: 1690538436638
last_update_time_since_epoch: 1690538436638
, artifact_type: id: 17
name: "Schema"
)], 'statistics': [Artifact(artifact: id: 3
type_id: 19
uri: "pipelines/penguin-transform/StatisticsGen/statistics/3"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "stats_dashboard_link"
  value {
    string_value: ""
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "ExampleStatistics"
create_time_since_epoch: 1690538440107
last_update_time_since_epoch: 1690538440107
, artifact_type: id: 19
name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)]},)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 5
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=5, input_dict={'schema': [Artifact(artifact: id: 2
type_id: 17
uri: "schema"
custom_properties {
  key: "is_external"
  value {
    int_value: 1
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "Schema"
create_time_since_epoch: 1690538436638
last_update_time_since_epoch: 1690538436638
, artifact_type: id: 17
name: "Schema"
)], 'statistics': [Artifact(artifact: id: 3
type_id: 19
uri: "pipelines/penguin-transform/StatisticsGen/statistics/3"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "stats_dashboard_link"
  value {
    string_value: ""
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "ExampleStatistics"
create_time_since_epoch: 1690538440107
last_update_time_since_epoch: 1690538440107
, artifact_type: id: 19
name: "ExampleStatistics"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
base_type: STATISTICS
)]}, output_dict=defaultdict(<class 'list'>, {'anomalies': [Artifact(artifact: uri: "pipelines/penguin-transform/ExampleValidator/anomalies/5"
, artifact_type: name: "ExampleAnomalies"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
)]}), exec_properties={'exclude_splits': '[]'}, execution_output_uri='pipelines/penguin-transform/ExampleValidator/.system/executor_execution/5/executor_output.pb', stateful_working_dir='pipelines/penguin-transform/ExampleValidator/.system/stateful_working_dir/2023-07-28T10:00:34.516274', tmp_dir='pipelines/penguin-transform/ExampleValidator/.system/executor_execution/5/.temp/', pipeline_node=node_info {
  type {
    name: "tfx.components.example_validator.component.ExampleValidator"
  }
  id: "ExampleValidator"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.ExampleValidator"
      }
    }
  }
}
inputs {
  inputs {
    key: "schema"
    value {
      channels {
        producer_node_query {
          id: "schema_importer"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.schema_importer"
            }
          }
        }
        artifact_query {
          type {
            name: "Schema"
          }
        }
        output_key: "result"
      }
      min_count: 1
    }
  }
  inputs {
    key: "statistics"
    value {
      channels {
        producer_node_query {
          id: "StatisticsGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.StatisticsGen"
            }
          }
        }
        artifact_query {
          type {
            name: "ExampleStatistics"
            base_type: STATISTICS
          }
        }
        output_key: "statistics"
      }
      min_count: 1
    }
  }
}
outputs {
  outputs {
    key: "anomalies"
    value {
      artifact_spec {
        type {
          name: "ExampleAnomalies"
          properties {
            key: "span"
            value: INT
          }
          properties {
            key: "split_names"
            value: STRING
          }
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "exclude_splits"
    value {
      field_value {
        string_value: "[]"
      }
    }
  }
}
upstream_nodes: "StatisticsGen"
upstream_nodes: "schema_importer"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "penguin-transform"
, pipeline_run_id='2023-07-28T10:00:34.516274')
INFO:absl:Validating schema against the computed statistics for split train.
INFO:absl:Validation complete for split train. Anomalies written to pipelines/penguin-transform/ExampleValidator/anomalies/5/Split-train.
INFO:absl:Validating schema against the computed statistics for split eval.
INFO:absl:Validation complete for split eval. Anomalies written to pipelines/penguin-transform/ExampleValidator/anomalies/5/Split-eval.
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 5 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'anomalies': [Artifact(artifact: uri: "pipelines/penguin-transform/ExampleValidator/anomalies/5"
, artifact_type: name: "ExampleAnomalies"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
)]}) for execution 5
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component ExampleValidator is finished.
INFO:absl:Component Trainer is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.components.trainer.component.Trainer"
    base_type: TRAIN
  }
  id: "Trainer"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.Trainer"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "CsvExampleGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.CsvExampleGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "examples"
      }
      min_count: 1
    }
  }
  inputs {
    key: "transform_graph"
    value {
      channels {
        producer_node_query {
          id: "Transform"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.Transform"
            }
          }
        }
        artifact_query {
          type {
            name: "TransformGraph"
          }
        }
        output_key: "transform_graph"
      }
    }
  }
}
outputs {
  outputs {
    key: "model"
    value {
      artifact_spec {
        type {
          name: "Model"
          base_type: MODEL
        }
      }
    }
  }
  outputs {
    key: "model_run"
    value {
      artifact_spec {
        type {
          name: "ModelRun"
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "custom_config"
    value {
      field_value {
        string_value: "null"
      }
    }
  }
  parameters {
    key: "eval_args"
    value {
      field_value {
        string_value: "{\n  \"num_steps\": 5\n}"
      }
    }
  }
  parameters {
    key: "module_path"
    value {
      field_value {
        string_value: "penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl"
      }
    }
  }
  parameters {
    key: "train_args"
    value {
      field_value {
        string_value: "{\n  \"num_steps\": 100\n}"
      }
    }
  }
}
upstream_nodes: "CsvExampleGen"
upstream_nodes: "Transform"
downstream_nodes: "Pusher"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
WARNING:absl:ArtifactQuery.property_predicate is not supported.
WARNING:absl:ArtifactQuery.property_predicate is not supported.
INFO:absl:[Trainer] Resolved inputs: ({'examples': [Artifact(artifact: id: 1
type_id: 15
uri: "pipelines/penguin-transform/CsvExampleGen/examples/1"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "file_format"
  value {
    string_value: "tfrecords_gzip"
  }
}
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1690538432,sum_checksum:1690538432"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "payload_format"
  value {
    string_value: "FORMAT_TF_EXAMPLE"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1690538436516
last_update_time_since_epoch: 1690538436516
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)], 'transform_graph': [Artifact(artifact: id: 10
type_id: 23
uri: "pipelines/penguin-transform/Transform/transform_graph/4"
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "TransformGraph"
create_time_since_epoch: 1690538464650
last_update_time_since_epoch: 1690538464650
, artifact_type: id: 23
name: "TransformGraph"
)]},)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 6
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=6, input_dict={'examples': [Artifact(artifact: id: 1
type_id: 15
uri: "pipelines/penguin-transform/CsvExampleGen/examples/1"
properties {
  key: "split_names"
  value {
    string_value: "[\"train\", \"eval\"]"
  }
}
custom_properties {
  key: "file_format"
  value {
    string_value: "tfrecords_gzip"
  }
}
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1690538432,sum_checksum:1690538432"
  }
}
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "payload_format"
  value {
    string_value: "FORMAT_TF_EXAMPLE"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "Examples"
create_time_since_epoch: 1690538436516
last_update_time_since_epoch: 1690538436516
, artifact_type: id: 15
name: "Examples"
properties {
  key: "span"
  value: INT
}
properties {
  key: "split_names"
  value: STRING
}
properties {
  key: "version"
  value: INT
}
base_type: DATASET
)], 'transform_graph': [Artifact(artifact: id: 10
type_id: 23
uri: "pipelines/penguin-transform/Transform/transform_graph/4"
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "TransformGraph"
create_time_since_epoch: 1690538464650
last_update_time_since_epoch: 1690538464650
, artifact_type: id: 23
name: "TransformGraph"
)]}, output_dict=defaultdict(<class 'list'>, {'model': [Artifact(artifact: uri: "pipelines/penguin-transform/Trainer/model/6"
, artifact_type: name: "Model"
base_type: MODEL
)], 'model_run': [Artifact(artifact: uri: "pipelines/penguin-transform/Trainer/model_run/6"
, artifact_type: name: "ModelRun"
)]}), exec_properties={'eval_args': '{\n  "num_steps": 5\n}', 'module_path': 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl', 'train_args': '{\n  "num_steps": 100\n}', 'custom_config': 'null'}, execution_output_uri='pipelines/penguin-transform/Trainer/.system/executor_execution/6/executor_output.pb', stateful_working_dir='pipelines/penguin-transform/Trainer/.system/stateful_working_dir/2023-07-28T10:00:34.516274', tmp_dir='pipelines/penguin-transform/Trainer/.system/executor_execution/6/.temp/', pipeline_node=node_info {
  type {
    name: "tfx.components.trainer.component.Trainer"
    base_type: TRAIN
  }
  id: "Trainer"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.Trainer"
      }
    }
  }
}
inputs {
  inputs {
    key: "examples"
    value {
      channels {
        producer_node_query {
          id: "CsvExampleGen"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.CsvExampleGen"
            }
          }
        }
        artifact_query {
          type {
            name: "Examples"
            base_type: DATASET
          }
        }
        output_key: "examples"
      }
      min_count: 1
    }
  }
  inputs {
    key: "transform_graph"
    value {
      channels {
        producer_node_query {
          id: "Transform"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.Transform"
            }
          }
        }
        artifact_query {
          type {
            name: "TransformGraph"
          }
        }
        output_key: "transform_graph"
      }
    }
  }
}
outputs {
  outputs {
    key: "model"
    value {
      artifact_spec {
        type {
          name: "Model"
          base_type: MODEL
        }
      }
    }
  }
  outputs {
    key: "model_run"
    value {
      artifact_spec {
        type {
          name: "ModelRun"
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "custom_config"
    value {
      field_value {
        string_value: "null"
      }
    }
  }
  parameters {
    key: "eval_args"
    value {
      field_value {
        string_value: "{\n  \"num_steps\": 5\n}"
      }
    }
  }
  parameters {
    key: "module_path"
    value {
      field_value {
        string_value: "penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl"
      }
    }
  }
  parameters {
    key: "train_args"
    value {
      field_value {
        string_value: "{\n  \"num_steps\": 100\n}"
      }
    }
  }
}
upstream_nodes: "CsvExampleGen"
upstream_nodes: "Transform"
downstream_nodes: "Pusher"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "penguin-transform"
, pipeline_run_id='2023-07-28T10:00:34.516274')
INFO:absl:Train on the 'train' split when train_args.splits is not set.
INFO:absl:Evaluate on the 'eval' split when eval_args.splits is not set.
INFO:absl:udf_utils.get_fn {'eval_args': '{\n  "num_steps": 5\n}', 'module_path': 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl', 'train_args': '{\n  "num_steps": 100\n}', 'custom_config': 'null'} 'run_fn'
INFO:absl:Installing 'pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' to a temporary directory.
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmp37t0ji0r', 'pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl']
Processing ./pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl
INFO:absl:Successfully installed 'pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'.
INFO:absl:Training model.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature island has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature sex has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
Installing collected packages: tfx-user-code-Trainer
Successfully installed tfx-user-code-Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tfx_bsl/tfxio/tf_example_record.py:339: parse_example_dataset (from tensorflow.python.data.experimental.ops.parsing_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.map(tf.io.parse_example(...))` instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tfx_bsl/tfxio/tf_example_record.py:339: parse_example_dataset (from tensorflow.python.data.experimental.ops.parsing_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.map(tf.io.parse_example(...))` instead.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature island has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature sex has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Model: "model"
INFO:absl:__________________________________________________________________________________________________
INFO:absl: Layer (type)                   Output Shape         Param #     Connected to                     
INFO:absl:==================================================================================================
INFO:absl: culmen_length_mm (InputLayer)  [(None, 1)]          0           []                               
INFO:absl:                                                                                                  
INFO:absl: culmen_depth_mm (InputLayer)   [(None, 1)]          0           []                               
INFO:absl:                                                                                                  
INFO:absl: flipper_length_mm (InputLayer)  [(None, 1)]         0           []                               
INFO:absl:                                                                                                  
INFO:absl: body_mass_g (InputLayer)       [(None, 1)]          0           []                               
INFO:absl:                                                                                                  
INFO:absl: concatenate (Concatenate)      (None, 4)            0           ['culmen_length_mm[0][0]',       
INFO:absl:                                                                  'culmen_depth_mm[0][0]',        
INFO:absl:                                                                  'flipper_length_mm[0][0]',      
INFO:absl:                                                                  'body_mass_g[0][0]']            
INFO:absl:                                                                                                  
INFO:absl: dense (Dense)                  (None, 8)            40          ['concatenate[0][0]']            
INFO:absl:                                                                                                  
INFO:absl: dense_1 (Dense)                (None, 8)            72          ['dense[0][0]']                  
INFO:absl:                                                                                                  
INFO:absl: dense_2 (Dense)                (None, 3)            27          ['dense_1[0][0]']                
INFO:absl:                                                                                                  
INFO:absl:==================================================================================================
INFO:absl:Total params: 139
INFO:absl:Trainable params: 139
INFO:absl:Non-trainable params: 0
INFO:absl:__________________________________________________________________________________________________
100/100 [==============================] - 3s 6ms/step - loss: 0.3654 - sparse_categorical_accuracy: 0.9520 - val_loss: 0.1495 - val_sparse_categorical_accuracy: 1.0000
INFO:absl:Feature body_mass_g has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_depth_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature culmen_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature flipper_length_mm has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature island has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature sex has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:absl:Feature species has a shape dim {
  size: 1
}
. Setting to DenseTensor.
INFO:tensorflow:Assets written to: pipelines/penguin-transform/Trainer/model/6/Format-Serving/assets
INFO:tensorflow:Assets written to: pipelines/penguin-transform/Trainer/model/6/Format-Serving/assets
INFO:absl:Training complete. Model written to pipelines/penguin-transform/Trainer/model/6/Format-Serving. ModelRun written to pipelines/penguin-transform/Trainer/model_run/6
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 6 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'model': [Artifact(artifact: uri: "pipelines/penguin-transform/Trainer/model/6"
, artifact_type: name: "Model"
base_type: MODEL
)], 'model_run': [Artifact(artifact: uri: "pipelines/penguin-transform/Trainer/model_run/6"
, artifact_type: name: "ModelRun"
)]}) for execution 6
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component Trainer is finished.
INFO:absl:Component Pusher is running.
INFO:absl:Running launcher for node_info {
  type {
    name: "tfx.components.pusher.component.Pusher"
    base_type: DEPLOY
  }
  id: "Pusher"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.Pusher"
      }
    }
  }
}
inputs {
  inputs {
    key: "model"
    value {
      channels {
        producer_node_query {
          id: "Trainer"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.Trainer"
            }
          }
        }
        artifact_query {
          type {
            name: "Model"
            base_type: MODEL
          }
        }
        output_key: "model"
      }
    }
  }
}
outputs {
  outputs {
    key: "pushed_model"
    value {
      artifact_spec {
        type {
          name: "PushedModel"
          base_type: MODEL
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "custom_config"
    value {
      field_value {
        string_value: "null"
      }
    }
  }
  parameters {
    key: "push_destination"
    value {
      field_value {
        string_value: "{\n  \"filesystem\": {\n    \"base_directory\": \"serving_model/penguin-transform\"\n  }\n}"
      }
    }
  }
}
upstream_nodes: "Trainer"
execution_options {
  caching_options {
  }
}

INFO:absl:MetadataStore with DB connection initialized
WARNING:absl:ArtifactQuery.property_predicate is not supported.
INFO:absl:[Pusher] Resolved inputs: ({'model': [Artifact(artifact: id: 12
type_id: 26
uri: "pipelines/penguin-transform/Trainer/model/6"
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "Model"
create_time_since_epoch: 1690538476181
last_update_time_since_epoch: 1690538476181
, artifact_type: id: 26
name: "Model"
base_type: MODEL
)]},)
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Going to run a new execution 7
INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=7, input_dict={'model': [Artifact(artifact: id: 12
type_id: 26
uri: "pipelines/penguin-transform/Trainer/model/6"
custom_properties {
  key: "is_external"
  value {
    int_value: 0
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.13.0"
  }
}
state: LIVE
type: "Model"
create_time_since_epoch: 1690538476181
last_update_time_since_epoch: 1690538476181
, artifact_type: id: 26
name: "Model"
base_type: MODEL
)]}, output_dict=defaultdict(<class 'list'>, {'pushed_model': [Artifact(artifact: uri: "pipelines/penguin-transform/Pusher/pushed_model/7"
, artifact_type: name: "PushedModel"
base_type: MODEL
)]}), exec_properties={'push_destination': '{\n  "filesystem": {\n    "base_directory": "serving_model/penguin-transform"\n  }\n}', 'custom_config': 'null'}, execution_output_uri='pipelines/penguin-transform/Pusher/.system/executor_execution/7/executor_output.pb', stateful_working_dir='pipelines/penguin-transform/Pusher/.system/stateful_working_dir/2023-07-28T10:00:34.516274', tmp_dir='pipelines/penguin-transform/Pusher/.system/executor_execution/7/.temp/', pipeline_node=node_info {
  type {
    name: "tfx.components.pusher.component.Pusher"
    base_type: DEPLOY
  }
  id: "Pusher"
}
contexts {
  contexts {
    type {
      name: "pipeline"
    }
    name {
      field_value {
        string_value: "penguin-transform"
      }
    }
  }
  contexts {
    type {
      name: "pipeline_run"
    }
    name {
      field_value {
        string_value: "2023-07-28T10:00:34.516274"
      }
    }
  }
  contexts {
    type {
      name: "node"
    }
    name {
      field_value {
        string_value: "penguin-transform.Pusher"
      }
    }
  }
}
inputs {
  inputs {
    key: "model"
    value {
      channels {
        producer_node_query {
          id: "Trainer"
        }
        context_queries {
          type {
            name: "pipeline"
          }
          name {
            field_value {
              string_value: "penguin-transform"
            }
          }
        }
        context_queries {
          type {
            name: "pipeline_run"
          }
          name {
            field_value {
              string_value: "2023-07-28T10:00:34.516274"
            }
          }
        }
        context_queries {
          type {
            name: "node"
          }
          name {
            field_value {
              string_value: "penguin-transform.Trainer"
            }
          }
        }
        artifact_query {
          type {
            name: "Model"
            base_type: MODEL
          }
        }
        output_key: "model"
      }
    }
  }
}
outputs {
  outputs {
    key: "pushed_model"
    value {
      artifact_spec {
        type {
          name: "PushedModel"
          base_type: MODEL
        }
      }
    }
  }
}
parameters {
  parameters {
    key: "custom_config"
    value {
      field_value {
        string_value: "null"
      }
    }
  }
  parameters {
    key: "push_destination"
    value {
      field_value {
        string_value: "{\n  \"filesystem\": {\n    \"base_directory\": \"serving_model/penguin-transform\"\n  }\n}"
      }
    }
  }
}
upstream_nodes: "Trainer"
execution_options {
  caching_options {
  }
}
, pipeline_info=id: "penguin-transform"
, pipeline_run_id='2023-07-28T10:00:34.516274')
WARNING:absl:Pusher is going to push the model without validation. Consider using Evaluator or InfraValidator in your pipeline.
INFO:absl:Model version: 1690538476
INFO:absl:Model written to serving path serving_model/penguin-transform/1690538476.
INFO:absl:Model pushed to pipelines/penguin-transform/Pusher/pushed_model/7.
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 7 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'pushed_model': [Artifact(artifact: uri: "pipelines/penguin-transform/Pusher/pushed_model/7"
, artifact_type: name: "PushedModel"
base_type: MODEL
)]}) for execution 7
INFO:absl:MetadataStore with DB connection initialized
INFO:absl:Component Pusher is finished.

You should see "INFO:absl:Component Pusher is finished." if the pipeline finished successfully.

The pusher component pushes the trained model to the SERVING_MODEL_DIR which is the serving_model/penguin-transform directory if you did not change the variables in the previous steps. You can see the result from the file browser in the left-side panel in Colab, or using the following command:

# List files in created model directory.
find {SERVING_MODEL_DIR}
serving_model/penguin-transform
serving_model/penguin-transform/1690538476
serving_model/penguin-transform/1690538476/keras_metadata.pb
serving_model/penguin-transform/1690538476/assets
serving_model/penguin-transform/1690538476/fingerprint.pb
serving_model/penguin-transform/1690538476/variables
serving_model/penguin-transform/1690538476/variables/variables.data-00000-of-00001
serving_model/penguin-transform/1690538476/variables/variables.index
serving_model/penguin-transform/1690538476/saved_model.pb

You can also check the signature of the generated model using the saved_model_cli tool.

saved_model_cli show --dir {SERVING_MODEL_DIR}/$(ls -1 {SERVING_MODEL_DIR} | sort -nr | head -1) --tag_set serve --signature_def serving_default
The given SavedModel SignatureDef contains the following input(s):
  inputs['examples'] tensor_info:
      dtype: DT_STRING
      shape: (-1)
      name: serving_default_examples:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['output_0'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 3)
      name: StatefulPartitionedCall_1:0
Method name is: tensorflow/serving/predict

Because we defined serving_default with our own serve_tf_examples_fn function, the signature shows that it takes a single string. This string is a serialized string of tf.Examples and will be parsed with the tf.io.parse_example() function as we defined earlier (learn more about tf.Examples here).

We can load the exported model and try some inferences with a few examples.

# Find a model with the latest timestamp.
model_dirs = (item for item in os.scandir(SERVING_MODEL_DIR) if item.is_dir())
model_path = max(model_dirs, key=lambda i: int(i.name)).path

loaded_model = tf.keras.models.load_model(model_path)
inference_fn = loaded_model.signatures['serving_default']
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. For example, in the saved checkpoint object, `model.layer.weight` and `model.layer_copy.weight` reference the same variable, while in the current object these are two different variables. The referenced variables are:(<keras.saving.legacy.saved_model.load.TensorFlowTransform>TransformFeaturesLayer object at 0x7f4a50d10700> and <keras.engine.input_layer.InputLayer object at 0x7f4a50967f10>).
WARNING:tensorflow:Inconsistent references when loading the checkpoint into this object graph. For example, in the saved checkpoint object, `model.layer.weight` and `model.layer_copy.weight` reference the same variable, while in the current object these are two different variables. The referenced variables are:(<keras.saving.legacy.saved_model.load.TensorFlowTransform>TransformFeaturesLayer object at 0x7f4a50d10700> and <keras.engine.input_layer.InputLayer object at 0x7f4a50967f10>).
# Prepare an example and run inference.
features = {
  'culmen_length_mm': tf.train.Feature(float_list=tf.train.FloatList(value=[49.9])),
  'culmen_depth_mm': tf.train.Feature(float_list=tf.train.FloatList(value=[16.1])),
  'flipper_length_mm': tf.train.Feature(int64_list=tf.train.Int64List(value=[213])),
  'body_mass_g': tf.train.Feature(int64_list=tf.train.Int64List(value=[5400])),
}
example_proto = tf.train.Example(features=tf.train.Features(feature=features))
examples = example_proto.SerializeToString()

result = inference_fn(examples=tf.constant([examples]))
print(result['output_0'].numpy())
[[-2.1244783 -2.7389686  4.3641634]]

The third element, which corresponds to 'Gentoo' species, is expected to be the largest among three.

Next steps

If you want to learn more about Transform component, see Transform Component guide. You can find more resources on https://www.tensorflow.org/tfx/tutorials

Please see Understanding TFX Pipelines to learn more about various concepts in TFX.