Transform input data and train a model with a TFX pipeline.
In this notebook-based tutorial, we will create and run a TFX pipeline to ingest raw input data and preprocess it appropriately for ML training. This notebook is based on the TFX pipeline we built in Data validation using TFX Pipeline and TensorFlow Data Validation Tutorial. If you have not read that one yet, you should read it before proceeding with this notebook.
You can increase the predictive quality of your data and/or reduce dimensionality with feature engineering. One of the benefits of using TFX is that you will write your transformation code once, and the resulting transforms will be consistent between training and serving in order to avoid training/serving skew.
We will add a Transform
component to the pipeline. The Transform component is
implemented using the
tf.transform library.
Please see Understanding TFX Pipelines to learn more about various concepts in TFX.
Set Up
We first need to install the TFX Python package and download the dataset which we will use for our model.
Upgrade Pip
To avoid upgrading Pip in a system when running locally, check to make sure that we are running in Colab. Local systems can of course be upgraded separately.
try:
import colab
!pip install --upgrade pip
except:
pass
Install TFX
pip install -U tfx
Did you restart the runtime?
If you are using Google Colab, the first time that you run the cell above, you must restart the runtime by clicking the "RESTART RUNTIME" button above or by using the "Runtime > Restart runtime ..." menu. This is because of the way that Colab loads packages.
Check the TensorFlow and TFX versions.
import tensorflow as tf
print('TensorFlow version: {}'.format(tf.__version__))
from tfx import v1 as tfx
print('TFX version: {}'.format(tfx.__version__))
2024-05-08 09:20:08.101132: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-08 09:20:08.101177: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-08 09:20:08.102632: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
TensorFlow version: 2.15.1
TFX version: 1.15.0
Set up variables
There are some variables used to define a pipeline. You can customize these variables as you want. By default, all output from the pipeline will be generated under the current directory.
import os
PIPELINE_NAME = "penguin-transform"
# Output directory to store artifacts generated from the pipeline.
PIPELINE_ROOT = os.path.join('pipelines', PIPELINE_NAME)
# Path to a SQLite DB file to use as an MLMD storage.
METADATA_PATH = os.path.join('metadata', PIPELINE_NAME, 'metadata.db')
# Output directory where created models from the pipeline will be exported.
SERVING_MODEL_DIR = os.path.join('serving_model', PIPELINE_NAME)
from absl import logging
logging.set_verbosity(logging.INFO) # Set default logging level.
Prepare example data
We will download the example dataset for use in our TFX pipeline. The dataset we are using is the Palmer Penguins dataset.
However, unlike the previous tutorials, which used an already preprocessed dataset, this time we will use the raw Palmer Penguins dataset.
Because the TFX ExampleGen component reads inputs from a directory, we need to create a directory and copy the dataset to it.
import urllib.request
import tempfile
DATA_ROOT = tempfile.mkdtemp(prefix='tfx-data') # Create a temporary directory.
_data_path = 'https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins_size.csv'
_data_filepath = os.path.join(DATA_ROOT, "data.csv")
urllib.request.urlretrieve(_data_path, _data_filepath)
('/tmpfs/tmp/tfx-data244l5nap/data.csv', <http.client.HTTPMessage at 0x7f9eee0ce2b0>)
Take a quick look at what the raw data looks like.
head {_data_filepath}
species,island,culmen_length_mm,culmen_depth_mm,flipper_length_mm,body_mass_g,sex
Adelie,Torgersen,39.1,18.7,181,3750,MALE
Adelie,Torgersen,39.5,17.4,186,3800,FEMALE
Adelie,Torgersen,40.3,18,195,3250,FEMALE
Adelie,Torgersen,NA,NA,NA,NA,NA
Adelie,Torgersen,36.7,19.3,193,3450,FEMALE
Adelie,Torgersen,39.3,20.6,190,3650,MALE
Adelie,Torgersen,38.9,17.8,181,3625,FEMALE
Adelie,Torgersen,39.2,19.6,195,4675,MALE
Adelie,Torgersen,34.1,18.1,193,3475,NA
There are some entries with missing values, which are represented as NA. We will simply delete those entries in this tutorial.
sed -i '/\bNA\b/d' {_data_filepath}
head {_data_filepath}
species,island,culmen_length_mm,culmen_depth_mm,flipper_length_mm,body_mass_g,sex
Adelie,Torgersen,39.1,18.7,181,3750,MALE
Adelie,Torgersen,39.5,17.4,186,3800,FEMALE
Adelie,Torgersen,40.3,18,195,3250,FEMALE
Adelie,Torgersen,36.7,19.3,193,3450,FEMALE
Adelie,Torgersen,39.3,20.6,190,3650,MALE
Adelie,Torgersen,38.9,17.8,181,3625,FEMALE
Adelie,Torgersen,39.2,19.6,195,4675,MALE
Adelie,Torgersen,41.1,17.6,182,3200,FEMALE
Adelie,Torgersen,38.6,21.2,191,3800,MALE
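If you prefer to do this cleanup in Python instead of sed, a minimal sketch using pandas (assumed to be available in the environment, but not used elsewhere in this tutorial) would be:

import pandas as pd  # Assumption: pandas is installed in this environment.

df = pd.read_csv(_data_filepath)  # 'NA' strings are parsed as missing values by default.
df = df.dropna()                  # Drop every row that contains a missing value.
df.to_csv(_data_filepath, index=False)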
You should be able to see seven features which describe penguins. We will use the same set of features as in the previous tutorials - 'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g' - and will predict the 'species' of a penguin.
The only difference is that the input data is not preprocessed. Note that we will not use other features like 'island' or 'sex' in this tutorial.
Prepare a schema file
As described in Data validation using TFX Pipeline and TensorFlow Data Validation Tutorial, we need a schema file for the dataset. Because this dataset is different from the one used in the previous tutorial, its schema would normally have to be generated again. In this tutorial, we will skip those steps and just use a prepared schema file.
import shutil
SCHEMA_PATH = 'schema'
_schema_uri = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/penguin/schema/raw/schema.pbtxt'
_schema_filename = 'schema.pbtxt'
_schema_filepath = os.path.join(SCHEMA_PATH, _schema_filename)
os.makedirs(SCHEMA_PATH, exist_ok=True)
urllib.request.urlretrieve(_schema_uri, _schema_filepath)
('schema/schema.pbtxt', <http.client.HTTPMessage at 0x7fa026398e80>)
This schema file was created with the same pipeline as in the previous tutorial without any manual changes.
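If you want to take a quick look at the downloaded schema, the following sketch parses the text-format proto and prints the declared feature names. It uses only packages that TFX already depends on, but it is not part of the pipeline itself.

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2

# Parse the text-format schema proto and list the features it declares.
with open(_schema_filepath) as f:
  schema = text_format.Parse(f.read(), schema_pb2.Schema())
print([feature.name for feature in schema.feature])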
Create a pipeline
TFX pipelines are defined using Python APIs. We will add a Transform component to the pipeline we created in the Data Validation tutorial.
A Transform component requires input data from an ExampleGen component and a schema from a SchemaGen component (or, as in this pipeline, a schema Importer), and produces a "transform graph". The output will be used in a Trainer component. Transform can optionally produce "transformed data" in addition, which is the materialized data after transformation. However, in this tutorial we will transform data during training without materializing the intermediate transformed data.
One thing to note is that we need to define a Python function, preprocessing_fn, to describe how input data should be transformed. This is similar to the Trainer component, which also requires user code for the model definition.
Write preprocessing and training code
We need to define two Python functions: one for Transform and one for Trainer.
preprocessing_fn
The Transform component will find a function named preprocessing_fn in the given module file, just as we did for the Trainer component. You can also point to a specific function using the preprocessing_fn parameter of the Transform component.
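For example, if your preprocessing function lives in an importable Python module instead of a module file, you could reference it by its fully-qualified name. The snippet below is only a sketch; 'my_package.my_module' is a hypothetical module, and in this tutorial we will stick with module_file.

# Sketch only: 'my_package.my_module' is a hypothetical importable module.
# (example_gen and schema_importer are defined later in the pipeline definition.)
transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_importer.outputs['result'],
    preprocessing_fn='my_package.my_module.preprocessing_fn')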
In this example, we will do two kinds of transformation. For continuous numeric features like culmen_length_mm and body_mass_g, we will normalize the values using the tft.scale_to_z_score function. For the label feature, we need to convert string labels into numeric index values. We will use tf.lookup.StaticHashTable for the conversion.
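As a standalone illustration (outside the pipeline) of how this label conversion behaves, tf.lookup.StaticHashTable maps each known species string to an integer index and any unknown key to the default value:

import tensorflow as tf

# Illustration only: the same kind of lookup table is built in preprocessing_fn below.
keys = tf.constant(['Adelie', 'Chinstrap', 'Gentoo'])
values = tf.constant([0, 1, 2], dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)
print(table.lookup(tf.constant(['Gentoo', 'Adelie', 'Emperor'])).numpy())  # [ 2  0 -1]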
run_fn
The model itself is almost the same as in the previous tutorials, but this time we will transform the input data using the transform graph from the Transform component.
One more important difference compared to the previous tutorial is that we now export a model for serving which includes not only the computation graph of the model, but also the transform graph for preprocessing, which is generated by the Transform component. We need to define a separate function which will be used to serve incoming requests. You can see that the same function, _apply_preprocessing, is used for both the training data and the serving requests.
_module_file = 'penguin_utils.py'
%%writefile {_module_file}
from typing import List, Text
from absl import logging
import tensorflow as tf
from tensorflow import keras
from tensorflow_metadata.proto.v0 import schema_pb2
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import schema_utils
from tfx import v1 as tfx
from tfx_bsl.public import tfxio
# Specify features that we will use.
_FEATURE_KEYS = [
'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g'
]
_LABEL_KEY = 'species'
_TRAIN_BATCH_SIZE = 20
_EVAL_BATCH_SIZE = 10
# NEW: TFX Transform will call this function.
def preprocessing_fn(inputs):
"""tf.transform's callback function for preprocessing inputs.
Args:
inputs: map from feature keys to raw not-yet-transformed features.
Returns:
Map from string feature key to transformed feature.
"""
outputs = {}
# Uses features defined in _FEATURE_KEYS only.
for key in _FEATURE_KEYS:
# tft.scale_to_z_score computes the mean and variance of the given feature
# and scales the output based on the result.
outputs[key] = tft.scale_to_z_score(inputs[key])
# For the label column we provide the mapping from string to index.
# We could instead use `tft.compute_and_apply_vocabulary()` in order to
# compute the vocabulary dynamically and perform a lookup.
# Since in this example there are only 3 possible values, we use a hard-coded
# table for simplicity.
table_keys = ['Adelie', 'Chinstrap', 'Gentoo']
initializer = tf.lookup.KeyValueTensorInitializer(
keys=table_keys,
values=tf.cast(tf.range(len(table_keys)), tf.int64),
key_dtype=tf.string,
value_dtype=tf.int64)
table = tf.lookup.StaticHashTable(initializer, default_value=-1)
outputs[_LABEL_KEY] = table.lookup(inputs[_LABEL_KEY])
return outputs
# NEW: This function will apply the same transform operation to training data
# and serving requests.
def _apply_preprocessing(raw_features, tft_layer):
transformed_features = tft_layer(raw_features)
if _LABEL_KEY in raw_features:
transformed_label = transformed_features.pop(_LABEL_KEY)
return transformed_features, transformed_label
else:
return transformed_features, None
# NEW: This function will create a handler function which gets a serialized
# tf.Example, preprocesses it, and runs inference with it.
def _get_serve_tf_examples_fn(model, tf_transform_output):
# We must save the tft_layer to the model to ensure its assets are kept and
# tracked.
model.tft_layer = tf_transform_output.transform_features_layer()
@tf.function(input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')
])
def serve_tf_examples_fn(serialized_tf_examples):
# Expected input is a string which is serialized tf.Example format.
feature_spec = tf_transform_output.raw_feature_spec()
# Because input schema includes unnecessary fields like 'species' and
# 'island', we filter feature_spec to include required keys only.
required_feature_spec = {
k: v for k, v in feature_spec.items() if k in _FEATURE_KEYS
}
parsed_features = tf.io.parse_example(serialized_tf_examples,
required_feature_spec)
# Preprocess parsed input with transform operation defined in
# preprocessing_fn().
transformed_features, _ = _apply_preprocessing(parsed_features,
model.tft_layer)
# Run inference with ML model.
return model(transformed_features)
return serve_tf_examples_fn
def _input_fn(file_pattern: List[Text],
data_accessor: tfx.components.DataAccessor,
tf_transform_output: tft.TFTransformOutput,
batch_size: int = 200) -> tf.data.Dataset:
"""Generates features and label for tuning/training.
Args:
file_pattern: List of paths or patterns of input tfrecord files.
data_accessor: DataAccessor for converting input to RecordBatch.
tf_transform_output: A TFTransformOutput.
batch_size: representing the number of consecutive elements of returned
dataset to combine in a single batch
Returns:
A dataset that contains (features, indices) tuple where features is a
dictionary of Tensors, and indices is a single Tensor of label indices.
"""
dataset = data_accessor.tf_dataset_factory(
file_pattern,
tfxio.TensorFlowDatasetOptions(batch_size=batch_size),
schema=tf_transform_output.raw_metadata.schema)
transform_layer = tf_transform_output.transform_features_layer()
def apply_transform(raw_features):
return _apply_preprocessing(raw_features, transform_layer)
return dataset.map(apply_transform).repeat()
def _build_keras_model() -> tf.keras.Model:
"""Creates a DNN Keras model for classifying penguin data.
Returns:
A Keras Model.
"""
# The model below is built with Functional API, please refer to
# https://www.tensorflow.org/guide/keras/overview for all API options.
inputs = [
keras.layers.Input(shape=(1,), name=key)
for key in _FEATURE_KEYS
]
d = keras.layers.concatenate(inputs)
for _ in range(2):
d = keras.layers.Dense(8, activation='relu')(d)
outputs = keras.layers.Dense(3)(d)
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(
optimizer=keras.optimizers.Adam(1e-2),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()])
model.summary(print_fn=logging.info)
return model
# TFX Trainer will call this function.
def run_fn(fn_args: tfx.components.FnArgs):
"""Train the model based on given args.
Args:
fn_args: Holds args used to train the model as name/value pairs.
"""
tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
train_dataset = _input_fn(
fn_args.train_files,
fn_args.data_accessor,
tf_transform_output,
batch_size=_TRAIN_BATCH_SIZE)
eval_dataset = _input_fn(
fn_args.eval_files,
fn_args.data_accessor,
tf_transform_output,
batch_size=_EVAL_BATCH_SIZE)
model = _build_keras_model()
model.fit(
train_dataset,
steps_per_epoch=fn_args.train_steps,
validation_data=eval_dataset,
validation_steps=fn_args.eval_steps)
# NEW: Save a computation graph including transform layer.
signatures = {
'serving_default': _get_serve_tf_examples_fn(model, tf_transform_output),
}
model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
Writing penguin_utils.py
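After the pipeline has been run and the model has been pushed, the exported SavedModel accepts serialized tf.Example protos through the 'examples' input defined in _get_serve_tf_examples_fn above. The following is only a sketch of calling that signature directly; the exact directory layout under SERVING_MODEL_DIR depends on your run.

import glob
import os
import tensorflow as tf

# Sketch only: pick the most recently pushed model directory.
model_dirs = sorted(glob.glob(os.path.join(SERVING_MODEL_DIR, '*')))
serve_fn = tf.saved_model.load(model_dirs[-1]).signatures['serving_default']

# Build one serialized tf.Example using the raw (not yet scaled) feature values.
example = tf.train.Example(features=tf.train.Features(feature={
    'culmen_length_mm': tf.train.Feature(float_list=tf.train.FloatList(value=[49.9])),
    'culmen_depth_mm': tf.train.Feature(float_list=tf.train.FloatList(value=[16.1])),
    'flipper_length_mm': tf.train.Feature(int64_list=tf.train.Int64List(value=[213])),
    'body_mass_g': tf.train.Feature(int64_list=tf.train.Int64List(value=[5400])),
}))
print(serve_fn(examples=tf.constant([example.SerializeToString()])))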
Now you have completed all of the preparation steps to build a TFX pipeline.
Write a pipeline definition
We define a function to create a TFX pipeline. A Pipeline
object
represents a TFX pipeline, which can be run using one of the pipeline
orchestration systems that TFX supports.
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
schema_path: str, module_file: str, serving_model_dir: str,
metadata_path: str) -> tfx.dsl.Pipeline:
"""Implements the penguin pipeline with TFX."""
# Brings data into the pipeline or otherwise joins/converts training data.
example_gen = tfx.components.CsvExampleGen(input_base=data_root)
# Computes statistics over data for visualization and example validation.
statistics_gen = tfx.components.StatisticsGen(
examples=example_gen.outputs['examples'])
# Import the schema.
schema_importer = tfx.dsl.Importer(
source_uri=schema_path,
artifact_type=tfx.types.standard_artifacts.Schema).with_id(
'schema_importer')
# Performs anomaly detection based on statistics and data schema.
example_validator = tfx.components.ExampleValidator(
statistics=statistics_gen.outputs['statistics'],
schema=schema_importer.outputs['result'])
# NEW: Transforms input data using preprocessing_fn in the 'module_file'.
transform = tfx.components.Transform(
examples=example_gen.outputs['examples'],
schema=schema_importer.outputs['result'],
materialize=False,
module_file=module_file)
# Uses user-provided Python function that trains a model.
trainer = tfx.components.Trainer(
module_file=module_file,
examples=example_gen.outputs['examples'],
# NEW: Pass transform_graph to the trainer.
transform_graph=transform.outputs['transform_graph'],
train_args=tfx.proto.TrainArgs(num_steps=100),
eval_args=tfx.proto.EvalArgs(num_steps=5))
# Pushes the model to a filesystem destination.
pusher = tfx.components.Pusher(
model=trainer.outputs['model'],
push_destination=tfx.proto.PushDestination(
filesystem=tfx.proto.PushDestination.Filesystem(
base_directory=serving_model_dir)))
components = [
example_gen,
statistics_gen,
schema_importer,
example_validator,
transform, # NEW: Transform component was added to the pipeline.
trainer,
pusher,
]
return tfx.dsl.Pipeline(
pipeline_name=pipeline_name,
pipeline_root=pipeline_root,
metadata_connection_config=tfx.orchestration.metadata
.sqlite_metadata_connection_config(metadata_path),
components=components)
Run the pipeline
We will use LocalDagRunner
as in the previous tutorial.
tfx.orchestration.LocalDagRunner().run(
_create_pipeline(
pipeline_name=PIPELINE_NAME,
pipeline_root=PIPELINE_ROOT,
data_root=DATA_ROOT,
schema_path=SCHEMA_PATH,
module_file=_module_file,
serving_model_dir=SERVING_MODEL_DIR,
metadata_path=METADATA_PATH))
INFO:absl:Excluding no splits because exclude_splits is not set. INFO:absl:Excluding no splits because exclude_splits is not set. INFO:absl:Generating ephemeral wheel package for '/tmpfs/src/temp/docs/tutorials/tfx/penguin_utils.py' (including modules: ['penguin_utils']). INFO:absl:User module package has hash fingerprint version a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9. INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '/tmpfs/tmp/tmpkolg3eiy/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmpfs/tmp/tmp1yspllww', '--dist-dir', '/tmpfs/tmp/tmp0f9tyt4n'] /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated. !! ******************************************************************************** Please avoid running ``setup.py`` directly. Instead, use pypa/build, pypa/installer or other standards-based tools. See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details. ******************************************************************************** !! self.initialize_options() INFO:absl:Successfully built user code wheel distribution at 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'; target user module is 'penguin_utils'. INFO:absl:Full user module path is 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' INFO:absl:Generating ephemeral wheel package for '/tmpfs/src/temp/docs/tutorials/tfx/penguin_utils.py' (including modules: ['penguin_utils']). INFO:absl:User module package has hash fingerprint version a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9. 
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '/tmpfs/tmp/tmphxgs65wf/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmpfs/tmp/tmphmpuvvpe', '--dist-dir', '/tmpfs/tmp/tmpmyjp2nkq'] running bdist_wheel running build running build_py creating build creating build/lib copying penguin_utils.py -> build/lib installing to /tmpfs/tmp/tmp1yspllww running install running install_lib copying build/lib/penguin_utils.py -> /tmpfs/tmp/tmp1yspllww running install_egg_info running egg_info creating tfx_user_code_Transform.egg-info writing tfx_user_code_Transform.egg-info/PKG-INFO writing dependency_links to tfx_user_code_Transform.egg-info/dependency_links.txt writing top-level names to tfx_user_code_Transform.egg-info/top_level.txt writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt' reading manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt' writing manifest file 'tfx_user_code_Transform.egg-info/SOURCES.txt' Copying tfx_user_code_Transform.egg-info to /tmpfs/tmp/tmp1yspllww/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3.9.egg-info running install_scripts creating /tmpfs/tmp/tmp1yspllww/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/WHEEL creating '/tmpfs/tmp/tmp0f9tyt4n/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' and adding '/tmpfs/tmp/tmp1yspllww' to it adding 'penguin_utils.py' adding 'tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/METADATA' adding 'tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/WHEEL' adding 'tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/top_level.txt' adding 'tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/RECORD' removing /tmpfs/tmp/tmp1yspllww /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated. !! ******************************************************************************** Please avoid running ``setup.py`` directly. Instead, use pypa/build, pypa/installer or other standards-based tools. See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details. ******************************************************************************** !! self.initialize_options() INFO:absl:Successfully built user code wheel distribution at 'pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'; target user module is 'penguin_utils'. 
INFO:absl:Full user module path is 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' INFO:absl:Using deployment config: executor_specs { key: "CsvExampleGen" value { beam_executable_spec { python_executor_spec { class_path: "tfx.components.example_gen.csv_example_gen.executor.Executor" } } } } executor_specs { key: "ExampleValidator" value { python_class_executable_spec { class_path: "tfx.components.example_validator.executor.Executor" } } } executor_specs { key: "Pusher" value { python_class_executable_spec { class_path: "tfx.components.pusher.executor.Executor" } } } executor_specs { key: "StatisticsGen" value { beam_executable_spec { python_executor_spec { class_path: "tfx.components.statistics_gen.executor.Executor" } } } } executor_specs { key: "Trainer" value { python_class_executable_spec { class_path: "tfx.components.trainer.executor.GenericExecutor" } } } executor_specs { key: "Transform" value { beam_executable_spec { python_executor_spec { class_path: "tfx.components.transform.executor.Executor" } } } } custom_driver_specs { key: "CsvExampleGen" value { python_class_executable_spec { class_path: "tfx.components.example_gen.driver.FileBasedDriver" } } } metadata_connection_config { database_connection_config { sqlite { filename_uri: "metadata/penguin-transform/metadata.db" connection_mode: READWRITE_OPENCREATE } } } INFO:absl:Using connection config: sqlite { filename_uri: "metadata/penguin-transform/metadata.db" connection_mode: READWRITE_OPENCREATE } INFO:absl:Component CsvExampleGen is running. INFO:absl:Running launcher for node_info { type { name: "tfx.components.example_gen.csv_example_gen.component.CsvExampleGen" } id: "CsvExampleGen" } contexts { contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.CsvExampleGen" } } } } outputs { outputs { key: "examples" value { artifact_spec { type { name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } base_type: DATASET } } } } } parameters { parameters { key: "input_base" value { field_value { string_value: "/tmpfs/tmp/tfx-data244l5nap" } } } parameters { key: "input_config" value { field_value { string_value: "{\n \"splits\": [\n {\n \"name\": \"single_split\",\n \"pattern\": \"*\"\n }\n ]\n}" } } } parameters { key: "output_config" value { field_value { string_value: "{\n \"split_config\": {\n \"splits\": [\n {\n \"hash_buckets\": 2,\n \"name\": \"train\"\n },\n {\n \"hash_buckets\": 1,\n \"name\": \"eval\"\n }\n ]\n }\n}" } } } parameters { key: "output_data_format" value { field_value { int_value: 6 } } } parameters { key: "output_file_format" value { field_value { int_value: 5 } } } } downstream_nodes: "StatisticsGen" downstream_nodes: "Trainer" downstream_nodes: "Transform" execution_options { caching_options { } } INFO:absl:MetadataStore with DB connection initialized running bdist_wheel running build running build_py creating build creating build/lib copying penguin_utils.py -> build/lib installing to /tmpfs/tmp/tmphmpuvvpe running install running install_lib copying build/lib/penguin_utils.py -> /tmpfs/tmp/tmphmpuvvpe running install_egg_info running egg_info creating 
tfx_user_code_Trainer.egg-info writing tfx_user_code_Trainer.egg-info/PKG-INFO writing dependency_links to tfx_user_code_Trainer.egg-info/dependency_links.txt writing top-level names to tfx_user_code_Trainer.egg-info/top_level.txt writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt' reading manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt' writing manifest file 'tfx_user_code_Trainer.egg-info/SOURCES.txt' Copying tfx_user_code_Trainer.egg-info to /tmpfs/tmp/tmphmpuvvpe/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3.9.egg-info running install_scripts creating /tmpfs/tmp/tmphmpuvvpe/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/WHEEL creating '/tmpfs/tmp/tmpmyjp2nkq/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' and adding '/tmpfs/tmp/tmphmpuvvpe' to it adding 'penguin_utils.py' adding 'tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/METADATA' adding 'tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/WHEEL' adding 'tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/top_level.txt' adding 'tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9.dist-info/RECORD' removing /tmpfs/tmp/tmphmpuvvpe INFO:absl:[CsvExampleGen] Resolved inputs: ({},) INFO:absl:select span and version = (0, None) INFO:absl:latest span and version = (0, None) INFO:absl:MetadataStore with DB connection initialized INFO:absl:Going to run a new execution 1 INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=1, input_dict={}, output_dict=defaultdict(<class 'list'>, {'examples': [Artifact(artifact: uri: "pipelines/penguin-transform/CsvExampleGen/examples/1" custom_properties { key: "input_fingerprint" value { string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1715160013,sum_checksum:1715160013" } } custom_properties { key: "span" value { int_value: 0 } } , artifact_type: name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } base_type: DATASET )]}), exec_properties={'input_base': '/tmpfs/tmp/tfx-data244l5nap', 'output_data_format': 6, 'input_config': '{\n "splits": [\n {\n "name": "single_split",\n "pattern": "*"\n }\n ]\n}', 'output_config': '{\n "split_config": {\n "splits": [\n {\n "hash_buckets": 2,\n "name": "train"\n },\n {\n "hash_buckets": 1,\n "name": "eval"\n }\n ]\n }\n}', 'output_file_format': 5, 'span': 0, 'version': None, 'input_fingerprint': 'split:single_split,num_files:1,total_bytes:13161,xor_checksum:1715160013,sum_checksum:1715160013'}, execution_output_uri='pipelines/penguin-transform/CsvExampleGen/.system/executor_execution/1/executor_output.pb', stateful_working_dir='pipelines/penguin-transform/CsvExampleGen/.system/stateful_working_dir/2a68db53-a342-4b70-b5ef-4bea08945c28', tmp_dir='pipelines/penguin-transform/CsvExampleGen/.system/executor_execution/1/.temp/', pipeline_node=node_info { type { name: "tfx.components.example_gen.csv_example_gen.component.CsvExampleGen" } id: "CsvExampleGen" } contexts { contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } 
contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.CsvExampleGen" } } } } outputs { outputs { key: "examples" value { artifact_spec { type { name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } base_type: DATASET } } } } } parameters { parameters { key: "input_base" value { field_value { string_value: "/tmpfs/tmp/tfx-data244l5nap" } } } parameters { key: "input_config" value { field_value { string_value: "{\n \"splits\": [\n {\n \"name\": \"single_split\",\n \"pattern\": \"*\"\n }\n ]\n}" } } } parameters { key: "output_config" value { field_value { string_value: "{\n \"split_config\": {\n \"splits\": [\n {\n \"hash_buckets\": 2,\n \"name\": \"train\"\n },\n {\n \"hash_buckets\": 1,\n \"name\": \"eval\"\n }\n ]\n }\n}" } } } parameters { key: "output_data_format" value { field_value { int_value: 6 } } } parameters { key: "output_file_format" value { field_value { int_value: 5 } } } } downstream_nodes: "StatisticsGen" downstream_nodes: "Trainer" downstream_nodes: "Transform" execution_options { caching_options { } } , pipeline_info=id: "penguin-transform" , pipeline_run_id='2024-05-08T09:20:15.209892', top_level_pipeline_run_id=None, frontend_url=None) INFO:absl:Generating examples. WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features. INFO:absl:Processing input csv data /tmpfs/tmp/tfx-data244l5nap/* to TFExample. WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be. INFO:absl:Examples generated. INFO:absl:Value type <class 'NoneType'> of key version in exec_properties is not supported, going to drop it INFO:absl:Value type <class 'list'> of key _beam_pipeline_args in exec_properties is not supported, going to drop it INFO:absl:Cleaning up stateless execution info. INFO:absl:Execution 1 succeeded. INFO:absl:Cleaning up stateful execution info. INFO:absl:Deleted stateful_working_dir pipelines/penguin-transform/CsvExampleGen/.system/stateful_working_dir/2a68db53-a342-4b70-b5ef-4bea08945c28 INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'examples': [Artifact(artifact: uri: "pipelines/penguin-transform/CsvExampleGen/examples/1" custom_properties { key: "input_fingerprint" value { string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1715160013,sum_checksum:1715160013" } } custom_properties { key: "span" value { int_value: 0 } } , artifact_type: name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } base_type: DATASET )]}) for execution 1 INFO:absl:MetadataStore with DB connection initialized INFO:absl:Component CsvExampleGen is finished. INFO:absl:Component schema_importer is running. 
INFO:absl:Running launcher for node_info { type { name: "tfx.dsl.components.common.importer.Importer" } id: "schema_importer" } contexts { contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.schema_importer" } } } } outputs { outputs { key: "result" value { artifact_spec { type { name: "Schema" } } } } } parameters { parameters { key: "artifact_uri" value { field_value { string_value: "schema" } } } parameters { key: "output_key" value { field_value { string_value: "result" } } } parameters { key: "reimport" value { field_value { int_value: 0 } } } } downstream_nodes: "ExampleValidator" downstream_nodes: "Transform" execution_options { caching_options { } } INFO:absl:Running as an importer node. INFO:absl:MetadataStore with DB connection initialized INFO:absl:Processing source uri: schema, properties: {}, custom_properties: {} INFO:absl:Component schema_importer is finished. INFO:absl:Component StatisticsGen is running. INFO:absl:Running launcher for node_info { type { name: "tfx.components.statistics_gen.component.StatisticsGen" base_type: PROCESS } id: "StatisticsGen" } contexts { contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.StatisticsGen" } } } } inputs { inputs { key: "examples" value { channels { producer_node_query { id: "CsvExampleGen" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.CsvExampleGen" } } } artifact_query { type { name: "Examples" base_type: DATASET } } output_key: "examples" } min_count: 1 } } } outputs { outputs { key: "statistics" value { artifact_spec { type { name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS } } } } } parameters { parameters { key: "exclude_splits" value { field_value { string_value: "[]" } } } } upstream_nodes: "CsvExampleGen" downstream_nodes: "ExampleValidator" execution_options { caching_options { } } INFO:absl:MetadataStore with DB connection initialized WARNING:absl:ArtifactQuery.property_predicate is not supported. 
INFO:absl:[StatisticsGen] Resolved inputs: ({'examples': [Artifact(artifact: id: 1 type_id: 15 uri: "pipelines/penguin-transform/CsvExampleGen/examples/1" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "file_format" value { string_value: "tfrecords_gzip" } } custom_properties { key: "input_fingerprint" value { string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1715160013,sum_checksum:1715160013" } } custom_properties { key: "is_external" value { int_value: 0 } } custom_properties { key: "payload_format" value { string_value: "FORMAT_TF_EXAMPLE" } } custom_properties { key: "span" value { int_value: 0 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "Examples" create_time_since_epoch: 1715160016351 last_update_time_since_epoch: 1715160016351 , artifact_type: id: 15 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } base_type: DATASET )]},) INFO:absl:MetadataStore with DB connection initialized INFO:absl:Going to run a new execution 3 INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=3, input_dict={'examples': [Artifact(artifact: id: 1 type_id: 15 uri: "pipelines/penguin-transform/CsvExampleGen/examples/1" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "file_format" value { string_value: "tfrecords_gzip" } } custom_properties { key: "input_fingerprint" value { string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1715160013,sum_checksum:1715160013" } } custom_properties { key: "is_external" value { int_value: 0 } } custom_properties { key: "payload_format" value { string_value: "FORMAT_TF_EXAMPLE" } } custom_properties { key: "span" value { int_value: 0 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "Examples" create_time_since_epoch: 1715160016351 last_update_time_since_epoch: 1715160016351 , artifact_type: id: 15 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } base_type: DATASET )]}, output_dict=defaultdict(<class 'list'>, {'statistics': [Artifact(artifact: uri: "pipelines/penguin-transform/StatisticsGen/statistics/3" , artifact_type: name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS )]}), exec_properties={'exclude_splits': '[]'}, execution_output_uri='pipelines/penguin-transform/StatisticsGen/.system/executor_execution/3/executor_output.pb', stateful_working_dir='pipelines/penguin-transform/StatisticsGen/.system/stateful_working_dir/87ad10ec-9fbd-4d39-bccf-4085439aebed', tmp_dir='pipelines/penguin-transform/StatisticsGen/.system/executor_execution/3/.temp/', pipeline_node=node_info { type { name: "tfx.components.statistics_gen.component.StatisticsGen" base_type: PROCESS } id: "StatisticsGen" } contexts { contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.StatisticsGen" } } } } inputs { inputs { key: "examples" value { channels { producer_node_query { id: "CsvExampleGen" } context_queries { type { name: 
"pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.CsvExampleGen" } } } artifact_query { type { name: "Examples" base_type: DATASET } } output_key: "examples" } min_count: 1 } } } outputs { outputs { key: "statistics" value { artifact_spec { type { name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS } } } } } parameters { parameters { key: "exclude_splits" value { field_value { string_value: "[]" } } } } upstream_nodes: "CsvExampleGen" downstream_nodes: "ExampleValidator" execution_options { caching_options { } } , pipeline_info=id: "penguin-transform" , pipeline_run_id='2024-05-08T09:20:15.209892', top_level_pipeline_run_id=None, frontend_url=None) INFO:absl:Generating statistics for split train. INFO:absl:Statistics for split train written to pipelines/penguin-transform/StatisticsGen/statistics/3/Split-train. INFO:absl:Generating statistics for split eval. INFO:absl:Statistics for split eval written to pipelines/penguin-transform/StatisticsGen/statistics/3/Split-eval. INFO:absl:Cleaning up stateless execution info. INFO:absl:Execution 3 succeeded. INFO:absl:Cleaning up stateful execution info. INFO:absl:Deleted stateful_working_dir pipelines/penguin-transform/StatisticsGen/.system/stateful_working_dir/87ad10ec-9fbd-4d39-bccf-4085439aebed INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'statistics': [Artifact(artifact: uri: "pipelines/penguin-transform/StatisticsGen/statistics/3" , artifact_type: name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS )]}) for execution 3 INFO:absl:MetadataStore with DB connection initialized INFO:absl:Component StatisticsGen is finished. INFO:absl:Component Transform is running. 
INFO:absl:Running launcher for node_info { type { name: "tfx.components.transform.component.Transform" base_type: TRANSFORM } id: "Transform" } contexts { contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.Transform" } } } } inputs { inputs { key: "examples" value { channels { producer_node_query { id: "CsvExampleGen" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.CsvExampleGen" } } } artifact_query { type { name: "Examples" base_type: DATASET } } output_key: "examples" } min_count: 1 } } inputs { key: "schema" value { channels { producer_node_query { id: "schema_importer" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.schema_importer" } } } artifact_query { type { name: "Schema" } } output_key: "result" } min_count: 1 } } } outputs { outputs { key: "post_transform_anomalies" value { artifact_spec { type { name: "ExampleAnomalies" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } } } } } outputs { key: "post_transform_schema" value { artifact_spec { type { name: "Schema" } } } } outputs { key: "post_transform_stats" value { artifact_spec { type { name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS } } } } outputs { key: "pre_transform_schema" value { artifact_spec { type { name: "Schema" } } } } outputs { key: "pre_transform_stats" value { artifact_spec { type { name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS } } } } outputs { key: "transform_graph" value { artifact_spec { type { name: "TransformGraph" } } } } outputs { key: "updated_analyzer_cache" value { artifact_spec { type { name: "TransformCache" } } } } } parameters { parameters { key: "custom_config" value { field_value { string_value: "null" } } } parameters { key: "disable_statistics" value { field_value { int_value: 0 } } } parameters { key: "force_tf_compat_v1" value { field_value { int_value: 0 } } } parameters { key: "module_path" value { field_value { string_value: "penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl" } } } } upstream_nodes: "CsvExampleGen" upstream_nodes: "schema_importer" downstream_nodes: "Trainer" execution_options { caching_options { } } INFO:absl:MetadataStore with DB connection initialized WARNING:absl:ArtifactQuery.property_predicate is not supported. WARNING:absl:ArtifactQuery.property_predicate is not supported. 
INFO:absl:[Transform] Resolved inputs: ({'schema': [Artifact(artifact: id: 2 type_id: 17 uri: "schema" custom_properties { key: "is_external" value { int_value: 1 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "Schema" create_time_since_epoch: 1715160016371 last_update_time_since_epoch: 1715160016371 , artifact_type: id: 17 name: "Schema" )], 'examples': [Artifact(artifact: id: 1 type_id: 15 uri: "pipelines/penguin-transform/CsvExampleGen/examples/1" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "file_format" value { string_value: "tfrecords_gzip" } } custom_properties { key: "input_fingerprint" value { string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1715160013,sum_checksum:1715160013" } } custom_properties { key: "is_external" value { int_value: 0 } } custom_properties { key: "payload_format" value { string_value: "FORMAT_TF_EXAMPLE" } } custom_properties { key: "span" value { int_value: 0 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "Examples" create_time_since_epoch: 1715160016351 last_update_time_since_epoch: 1715160016351 , artifact_type: id: 15 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } base_type: DATASET )]},) INFO:absl:MetadataStore with DB connection initialized INFO:absl:Going to run a new execution 4 INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=4, input_dict={'schema': [Artifact(artifact: id: 2 type_id: 17 uri: "schema" custom_properties { key: "is_external" value { int_value: 1 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "Schema" create_time_since_epoch: 1715160016371 last_update_time_since_epoch: 1715160016371 , artifact_type: id: 17 name: "Schema" )], 'examples': [Artifact(artifact: id: 1 type_id: 15 uri: "pipelines/penguin-transform/CsvExampleGen/examples/1" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "file_format" value { string_value: "tfrecords_gzip" } } custom_properties { key: "input_fingerprint" value { string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1715160013,sum_checksum:1715160013" } } custom_properties { key: "is_external" value { int_value: 0 } } custom_properties { key: "payload_format" value { string_value: "FORMAT_TF_EXAMPLE" } } custom_properties { key: "span" value { int_value: 0 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "Examples" create_time_since_epoch: 1715160016351 last_update_time_since_epoch: 1715160016351 , artifact_type: id: 15 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } base_type: DATASET )]}, output_dict=defaultdict(<class 'list'>, {'post_transform_schema': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/post_transform_schema/4" , artifact_type: name: "Schema" )], 'updated_analyzer_cache': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/updated_analyzer_cache/4" , artifact_type: name: "TransformCache" )], 'pre_transform_stats': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/pre_transform_stats/4" , artifact_type: name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: 
"split_names" value: STRING } base_type: STATISTICS )], 'transform_graph': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/transform_graph/4" , artifact_type: name: "TransformGraph" )], 'post_transform_stats': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/post_transform_stats/4" , artifact_type: name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS )], 'post_transform_anomalies': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/post_transform_anomalies/4" , artifact_type: name: "ExampleAnomalies" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } )], 'pre_transform_schema': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/pre_transform_schema/4" , artifact_type: name: "Schema" )]}), exec_properties={'force_tf_compat_v1': 0, 'module_path': 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl', 'disable_statistics': 0, 'custom_config': 'null'}, execution_output_uri='pipelines/penguin-transform/Transform/.system/executor_execution/4/executor_output.pb', stateful_working_dir='pipelines/penguin-transform/Transform/.system/stateful_working_dir/353ff412-c22b-442d-86fa-d000785c15c6', tmp_dir='pipelines/penguin-transform/Transform/.system/executor_execution/4/.temp/', pipeline_node=node_info { type { name: "tfx.components.transform.component.Transform" base_type: TRANSFORM } id: "Transform" } contexts { contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.Transform" } } } } inputs { inputs { key: "examples" value { channels { producer_node_query { id: "CsvExampleGen" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.CsvExampleGen" } } } artifact_query { type { name: "Examples" base_type: DATASET } } output_key: "examples" } min_count: 1 } } inputs { key: "schema" value { channels { producer_node_query { id: "schema_importer" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.schema_importer" } } } artifact_query { type { name: "Schema" } } output_key: "result" } min_count: 1 } } } outputs { outputs { key: "post_transform_anomalies" value { artifact_spec { type { name: "ExampleAnomalies" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } } } } } outputs { key: "post_transform_schema" value { artifact_spec { type { name: "Schema" } } } } outputs { key: "post_transform_stats" value { artifact_spec { type { name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS } } } } outputs { key: "pre_transform_schema" value { artifact_spec { type { name: 
"Schema" } } } } outputs { key: "pre_transform_stats" value { artifact_spec { type { name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS } } } } outputs { key: "transform_graph" value { artifact_spec { type { name: "TransformGraph" } } } } outputs { key: "updated_analyzer_cache" value { artifact_spec { type { name: "TransformCache" } } } } } parameters { parameters { key: "custom_config" value { field_value { string_value: "null" } } } parameters { key: "disable_statistics" value { field_value { int_value: 0 } } } parameters { key: "force_tf_compat_v1" value { field_value { int_value: 0 } } } parameters { key: "module_path" value { field_value { string_value: "penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl" } } } } upstream_nodes: "CsvExampleGen" upstream_nodes: "schema_importer" downstream_nodes: "Trainer" execution_options { caching_options { } } , pipeline_info=id: "penguin-transform" , pipeline_run_id='2024-05-08T09:20:15.209892', top_level_pipeline_run_id=None, frontend_url=None) INFO:absl:Analyze the 'train' split and transform all splits when splits_config is not set. INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl', 'preprocessing_fn': None} 'preprocessing_fn' INFO:absl:Installing 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' to a temporary directory. INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmp0mjyiork', 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'] Processing ./pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl INFO:absl:Successfully installed 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'. INFO:absl:udf_utils.get_fn {'module_file': None, 'module_path': 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl', 'stats_options_updater_fn': None} 'stats_options_updater_fn' INFO:absl:Installing 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' to a temporary directory. 
INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmp_hkp2fs0', 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'] Installing collected packages: tfx-user-code-Transform Successfully installed tfx-user-code-Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9 Processing ./pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl INFO:absl:Successfully installed 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'. INFO:absl:Installing 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' to a temporary directory. INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmprvrky6ts', 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'] Installing collected packages: tfx-user-code-Transform Successfully installed tfx-user-code-Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9 Processing ./pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl INFO:absl:Successfully installed 'pipelines/penguin-transform/_wheels/tfx_user_code_Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'. INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature island has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature sex has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. Installing collected packages: tfx-user-code-Transform Successfully installed tfx-user-code-Transform-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9 INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature island has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature sex has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature island has a shape dim { size: 1 } . Setting to DenseTensor. 
INFO:absl:Feature sex has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature island has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature sex has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature island has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature sex has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature island has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature sex has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. WARNING:absl:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.: key_value_init/LookupTableImportV2 WARNING:absl:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.: key_value_init/LookupTableImportV2 WARNING:absl:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.: key_value_init/LookupTableImportV2 WARNING:absl:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.: key_value_init/LookupTableImportV2 INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . 
Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature island has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature sex has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature island has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature sex has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. INFO:tensorflow:Assets written to: pipelines/penguin-transform/Transform/transform_graph/4/.temp_path/tftransform_tmp/c08ebbeca8814f288e2437f43b1224e5/assets INFO:tensorflow:Assets written to: pipelines/penguin-transform/Transform/transform_graph/4/.temp_path/tftransform_tmp/c08ebbeca8814f288e2437f43b1224e5/assets INFO:absl:Writing fingerprint to pipelines/penguin-transform/Transform/transform_graph/4/.temp_path/tftransform_tmp/c08ebbeca8814f288e2437f43b1224e5/fingerprint.pb WARNING:absl:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.: key_value_init/LookupTableImportV2 INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:tensorflow_text is not available. 
INFO:tensorflow:Assets written to: pipelines/penguin-transform/Transform/transform_graph/4/.temp_path/tftransform_tmp/922454c8ed954a439ec7388ab1479458/assets INFO:tensorflow:Assets written to: pipelines/penguin-transform/Transform/transform_graph/4/.temp_path/tftransform_tmp/922454c8ed954a439ec7388ab1479458/assets INFO:absl:Writing fingerprint to pipelines/penguin-transform/Transform/transform_graph/4/.temp_path/tftransform_tmp/922454c8ed954a439ec7388ab1479458/fingerprint.pb WARNING:absl:Tables initialized inside a tf.function will be re-initialized on every invocation of the function. This re-initialization can have significant impact on performance. Consider lifting them out of the graph context using `tf.init_scope`.: key_value_init/LookupTableImportV2 INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:tensorflow_text is not available. INFO:absl:Cleaning up stateless execution info. INFO:absl:Execution 4 succeeded. INFO:absl:Cleaning up stateful execution info. 
INFO:absl:Deleted stateful_working_dir pipelines/penguin-transform/Transform/.system/stateful_working_dir/353ff412-c22b-442d-86fa-d000785c15c6 INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'post_transform_schema': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/post_transform_schema/4" , artifact_type: name: "Schema" )], 'updated_analyzer_cache': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/updated_analyzer_cache/4" , artifact_type: name: "TransformCache" )], 'pre_transform_stats': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/pre_transform_stats/4" , artifact_type: name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS )], 'transform_graph': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/transform_graph/4" , artifact_type: name: "TransformGraph" )], 'post_transform_stats': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/post_transform_stats/4" , artifact_type: name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS )], 'post_transform_anomalies': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/post_transform_anomalies/4" , artifact_type: name: "ExampleAnomalies" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } )], 'pre_transform_schema': [Artifact(artifact: uri: "pipelines/penguin-transform/Transform/pre_transform_schema/4" , artifact_type: name: "Schema" )]}) for execution 4 INFO:absl:MetadataStore with DB connection initialized INFO:absl:Component Transform is finished. INFO:absl:Component ExampleValidator is running. INFO:absl:Running launcher for node_info { type { name: "tfx.components.example_validator.component.ExampleValidator" } id: "ExampleValidator" } contexts { contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.ExampleValidator" } } } } inputs { inputs { key: "schema" value { channels { producer_node_query { id: "schema_importer" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.schema_importer" } } } artifact_query { type { name: "Schema" } } output_key: "result" } min_count: 1 } } inputs { key: "statistics" value { channels { producer_node_query { id: "StatisticsGen" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.StatisticsGen" } } } artifact_query { type { name: "ExampleStatistics" base_type: STATISTICS } } output_key: "statistics" } min_count: 1 } } } outputs { outputs { key: "anomalies" value { artifact_spec { type { name: "ExampleAnomalies" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } } } } } } parameters { parameters { key: "exclude_splits" value { 
field_value { string_value: "[]" } } } } upstream_nodes: "StatisticsGen" upstream_nodes: "schema_importer" execution_options { caching_options { } } INFO:absl:MetadataStore with DB connection initialized WARNING:absl:ArtifactQuery.property_predicate is not supported. WARNING:absl:ArtifactQuery.property_predicate is not supported. INFO:absl:[ExampleValidator] Resolved inputs: ({'statistics': [Artifact(artifact: id: 3 type_id: 19 uri: "pipelines/penguin-transform/StatisticsGen/statistics/3" properties { key: "span" value { int_value: 0 } } properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "is_external" value { int_value: 0 } } custom_properties { key: "stats_dashboard_link" value { string_value: "" } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "ExampleStatistics" create_time_since_epoch: 1715160019591 last_update_time_since_epoch: 1715160019591 , artifact_type: id: 19 name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS )], 'schema': [Artifact(artifact: id: 2 type_id: 17 uri: "schema" custom_properties { key: "is_external" value { int_value: 1 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "Schema" create_time_since_epoch: 1715160016371 last_update_time_since_epoch: 1715160016371 , artifact_type: id: 17 name: "Schema" )]},) INFO:absl:MetadataStore with DB connection initialized INFO:absl:Going to run a new execution 5 INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=5, input_dict={'statistics': [Artifact(artifact: id: 3 type_id: 19 uri: "pipelines/penguin-transform/StatisticsGen/statistics/3" properties { key: "span" value { int_value: 0 } } properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "is_external" value { int_value: 0 } } custom_properties { key: "stats_dashboard_link" value { string_value: "" } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "ExampleStatistics" create_time_since_epoch: 1715160019591 last_update_time_since_epoch: 1715160019591 , artifact_type: id: 19 name: "ExampleStatistics" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } base_type: STATISTICS )], 'schema': [Artifact(artifact: id: 2 type_id: 17 uri: "schema" custom_properties { key: "is_external" value { int_value: 1 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "Schema" create_time_since_epoch: 1715160016371 last_update_time_since_epoch: 1715160016371 , artifact_type: id: 17 name: "Schema" )]}, output_dict=defaultdict(<class 'list'>, {'anomalies': [Artifact(artifact: uri: "pipelines/penguin-transform/ExampleValidator/anomalies/5" , artifact_type: name: "ExampleAnomalies" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } )]}), exec_properties={'exclude_splits': '[]'}, execution_output_uri='pipelines/penguin-transform/ExampleValidator/.system/executor_execution/5/executor_output.pb', stateful_working_dir='pipelines/penguin-transform/ExampleValidator/.system/stateful_working_dir/5c7a9496-8531-48e8-9884-9db5c9dd4bb2', tmp_dir='pipelines/penguin-transform/ExampleValidator/.system/executor_execution/5/.temp/', pipeline_node=node_info { type { name: "tfx.components.example_validator.component.ExampleValidator" } id: "ExampleValidator" } contexts { 
contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.ExampleValidator" } } } } inputs { inputs { key: "schema" value { channels { producer_node_query { id: "schema_importer" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.schema_importer" } } } artifact_query { type { name: "Schema" } } output_key: "result" } min_count: 1 } } inputs { key: "statistics" value { channels { producer_node_query { id: "StatisticsGen" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.StatisticsGen" } } } artifact_query { type { name: "ExampleStatistics" base_type: STATISTICS } } output_key: "statistics" } min_count: 1 } } } outputs { outputs { key: "anomalies" value { artifact_spec { type { name: "ExampleAnomalies" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } } } } } } parameters { parameters { key: "exclude_splits" value { field_value { string_value: "[]" } } } } upstream_nodes: "StatisticsGen" upstream_nodes: "schema_importer" execution_options { caching_options { } } , pipeline_info=id: "penguin-transform" , pipeline_run_id='2024-05-08T09:20:15.209892', top_level_pipeline_run_id=None, frontend_url=None) INFO:absl:Validating schema against the computed statistics for split train. INFO:absl:Anomalies alerts created for split train. INFO:absl:Validation complete for split train. Anomalies written to pipelines/penguin-transform/ExampleValidator/anomalies/5/Split-train. INFO:absl:Validating schema against the computed statistics for split eval. INFO:absl:Anomalies alerts created for split eval. INFO:absl:Validation complete for split eval. Anomalies written to pipelines/penguin-transform/ExampleValidator/anomalies/5/Split-eval. INFO:absl:Cleaning up stateless execution info. INFO:absl:Execution 5 succeeded. INFO:absl:Cleaning up stateful execution info. INFO:absl:Deleted stateful_working_dir pipelines/penguin-transform/ExampleValidator/.system/stateful_working_dir/5c7a9496-8531-48e8-9884-9db5c9dd4bb2 INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'anomalies': [Artifact(artifact: uri: "pipelines/penguin-transform/ExampleValidator/anomalies/5" , artifact_type: name: "ExampleAnomalies" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } )]}) for execution 5 INFO:absl:MetadataStore with DB connection initialized INFO:absl:Component ExampleValidator is finished. INFO:absl:Component Trainer is running. 
INFO:absl:Running launcher for node_info { type { name: "tfx.components.trainer.component.Trainer" base_type: TRAIN } id: "Trainer" } contexts { contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.Trainer" } } } } inputs { inputs { key: "examples" value { channels { producer_node_query { id: "CsvExampleGen" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.CsvExampleGen" } } } artifact_query { type { name: "Examples" base_type: DATASET } } output_key: "examples" } min_count: 1 } } inputs { key: "transform_graph" value { channels { producer_node_query { id: "Transform" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.Transform" } } } artifact_query { type { name: "TransformGraph" } } output_key: "transform_graph" } } } } outputs { outputs { key: "model" value { artifact_spec { type { name: "Model" base_type: MODEL } } } } outputs { key: "model_run" value { artifact_spec { type { name: "ModelRun" } } } } } parameters { parameters { key: "custom_config" value { field_value { string_value: "null" } } } parameters { key: "eval_args" value { field_value { string_value: "{\n \"num_steps\": 5\n}" } } } parameters { key: "module_path" value { field_value { string_value: "penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl" } } } parameters { key: "train_args" value { field_value { string_value: "{\n \"num_steps\": 100\n}" } } } } upstream_nodes: "CsvExampleGen" upstream_nodes: "Transform" downstream_nodes: "Pusher" execution_options { caching_options { } } INFO:absl:MetadataStore with DB connection initialized WARNING:absl:ArtifactQuery.property_predicate is not supported. WARNING:absl:ArtifactQuery.property_predicate is not supported. 
INFO:absl:[Trainer] Resolved inputs: ({'examples': [Artifact(artifact: id: 1 type_id: 15 uri: "pipelines/penguin-transform/CsvExampleGen/examples/1" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "file_format" value { string_value: "tfrecords_gzip" } } custom_properties { key: "input_fingerprint" value { string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1715160013,sum_checksum:1715160013" } } custom_properties { key: "is_external" value { int_value: 0 } } custom_properties { key: "payload_format" value { string_value: "FORMAT_TF_EXAMPLE" } } custom_properties { key: "span" value { int_value: 0 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "Examples" create_time_since_epoch: 1715160016351 last_update_time_since_epoch: 1715160016351 , artifact_type: id: 15 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } base_type: DATASET )], 'transform_graph': [Artifact(artifact: id: 7 type_id: 22 uri: "pipelines/penguin-transform/Transform/transform_graph/4" custom_properties { key: "is_external" value { int_value: 0 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "TransformGraph" create_time_since_epoch: 1715160039921 last_update_time_since_epoch: 1715160039921 , artifact_type: id: 22 name: "TransformGraph" )]},) INFO:absl:MetadataStore with DB connection initialized INFO:absl:Going to run a new execution 6 INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=6, input_dict={'examples': [Artifact(artifact: id: 1 type_id: 15 uri: "pipelines/penguin-transform/CsvExampleGen/examples/1" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "file_format" value { string_value: "tfrecords_gzip" } } custom_properties { key: "input_fingerprint" value { string_value: "split:single_split,num_files:1,total_bytes:13161,xor_checksum:1715160013,sum_checksum:1715160013" } } custom_properties { key: "is_external" value { int_value: 0 } } custom_properties { key: "payload_format" value { string_value: "FORMAT_TF_EXAMPLE" } } custom_properties { key: "span" value { int_value: 0 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "Examples" create_time_since_epoch: 1715160016351 last_update_time_since_epoch: 1715160016351 , artifact_type: id: 15 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } base_type: DATASET )], 'transform_graph': [Artifact(artifact: id: 7 type_id: 22 uri: "pipelines/penguin-transform/Transform/transform_graph/4" custom_properties { key: "is_external" value { int_value: 0 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "TransformGraph" create_time_since_epoch: 1715160039921 last_update_time_since_epoch: 1715160039921 , artifact_type: id: 22 name: "TransformGraph" )]}, output_dict=defaultdict(<class 'list'>, {'model': [Artifact(artifact: uri: "pipelines/penguin-transform/Trainer/model/6" , artifact_type: name: "Model" base_type: MODEL )], 'model_run': [Artifact(artifact: uri: "pipelines/penguin-transform/Trainer/model_run/6" , artifact_type: name: "ModelRun" )]}), exec_properties={'train_args': '{\n "num_steps": 100\n}', 'custom_config': 'null', 'module_path': 
'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl', 'eval_args': '{\n "num_steps": 5\n}'}, execution_output_uri='pipelines/penguin-transform/Trainer/.system/executor_execution/6/executor_output.pb', stateful_working_dir='pipelines/penguin-transform/Trainer/.system/stateful_working_dir/04611601-4dbf-412d-8c6c-4e4864bb335a', tmp_dir='pipelines/penguin-transform/Trainer/.system/executor_execution/6/.temp/', pipeline_node=node_info { type { name: "tfx.components.trainer.component.Trainer" base_type: TRAIN } id: "Trainer" } contexts { contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.Trainer" } } } } inputs { inputs { key: "examples" value { channels { producer_node_query { id: "CsvExampleGen" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.CsvExampleGen" } } } artifact_query { type { name: "Examples" base_type: DATASET } } output_key: "examples" } min_count: 1 } } inputs { key: "transform_graph" value { channels { producer_node_query { id: "Transform" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.Transform" } } } artifact_query { type { name: "TransformGraph" } } output_key: "transform_graph" } } } } outputs { outputs { key: "model" value { artifact_spec { type { name: "Model" base_type: MODEL } } } } outputs { key: "model_run" value { artifact_spec { type { name: "ModelRun" } } } } } parameters { parameters { key: "custom_config" value { field_value { string_value: "null" } } } parameters { key: "eval_args" value { field_value { string_value: "{\n \"num_steps\": 5\n}" } } } parameters { key: "module_path" value { field_value { string_value: "penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl" } } } parameters { key: "train_args" value { field_value { string_value: "{\n \"num_steps\": 100\n}" } } } } upstream_nodes: "CsvExampleGen" upstream_nodes: "Transform" downstream_nodes: "Pusher" execution_options { caching_options { } } , pipeline_info=id: "penguin-transform" , pipeline_run_id='2024-05-08T09:20:15.209892', top_level_pipeline_run_id=None, frontend_url=None) INFO:absl:Train on the 'train' split when train_args.splits is not set. INFO:absl:Evaluate on the 'eval' split when eval_args.splits is not set. 
INFO:absl:udf_utils.get_fn {'train_args': '{\n "num_steps": 100\n}', 'custom_config': 'null', 'module_path': 'penguin_utils@pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl', 'eval_args': '{\n "num_steps": 5\n}'} 'run_fn' INFO:absl:Installing 'pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl' to a temporary directory. INFO:absl:Executing: ['/tmpfs/src/tf_docs_env/bin/python', '-m', 'pip', 'install', '--target', '/tmpfs/tmp/tmpwu08fhdc', 'pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'] Processing ./pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl INFO:absl:Successfully installed 'pipelines/penguin-transform/_wheels/tfx_user_code_Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9-py3-none-any.whl'. INFO:absl:Training model. INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature island has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature sex has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. Installing collected packages: tfx-user-code-Trainer Successfully installed tfx-user-code-Trainer-0.0+a5e9139bd7facf5026b5306a6aea534f89db0dea58ebe1bb1fb5ebb9df5fdea9 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tfx_bsl/tfxio/tf_example_record.py:343: parse_example_dataset (from tensorflow.python.data.experimental.ops.parsing_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Dataset.map(tf.io.parse_example(...))` instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tfx_bsl/tfxio/tf_example_record.py:343: parse_example_dataset (from tensorflow.python.data.experimental.ops.parsing_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Dataset.map(tf.io.parse_example(...))` instead. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:struct2tensor is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_decision_forests is not available. INFO:tensorflow:tensorflow_text is not available. INFO:tensorflow:tensorflow_text is not available. INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature island has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature sex has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. 
INFO:absl:Model: "model" INFO:absl:__________________________________________________________________________________________________ INFO:absl: Layer (type) Output Shape Param # Connected to INFO:absl:================================================================================================== INFO:absl: culmen_length_mm (InputLay [(None, 1)] 0 [] INFO:absl: er) INFO:absl: INFO:absl: culmen_depth_mm (InputLaye [(None, 1)] 0 [] INFO:absl: r) INFO:absl: INFO:absl: flipper_length_mm (InputLa [(None, 1)] 0 [] INFO:absl: yer) INFO:absl: INFO:absl: body_mass_g (InputLayer) [(None, 1)] 0 [] INFO:absl: INFO:absl: concatenate (Concatenate) (None, 4) 0 ['culmen_length_mm[0][0]', INFO:absl: 'culmen_depth_mm[0][0]', INFO:absl: 'flipper_length_mm[0][0]', INFO:absl: 'body_mass_g[0][0]'] INFO:absl: INFO:absl: dense (Dense) (None, 8) 40 ['concatenate[0][0]'] INFO:absl: INFO:absl: dense_1 (Dense) (None, 8) 72 ['dense[0][0]'] INFO:absl: INFO:absl: dense_2 (Dense) (None, 3) 27 ['dense_1[0][0]'] INFO:absl: INFO:absl:================================================================================================== INFO:absl:Total params: 139 (556.00 Byte) INFO:absl:Trainable params: 139 (556.00 Byte) INFO:absl:Non-trainable params: 0 (0.00 Byte) INFO:absl:__________________________________________________________________________________________________ WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1715160045.949607 17085 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 100/100 [==============================] - 2s 6ms/step - loss: 0.2421 - sparse_categorical_accuracy: 0.9250 - val_loss: 0.0074 - val_sparse_categorical_accuracy: 1.0000 INFO:absl:Feature body_mass_g has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_depth_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature culmen_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature flipper_length_mm has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature island has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature sex has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Feature species has a shape dim { size: 1 } . Setting to DenseTensor. INFO:absl:Function `serve_tf_examples_fn` contains input name(s) 5332, resource with unsupported characters which will be renamed to transform_features_layer_5332, model_dense_2_biasadd_readvariableop_resource in the SavedModel. INFO:tensorflow:Assets written to: pipelines/penguin-transform/Trainer/model/6/Format-Serving/assets INFO:tensorflow:Assets written to: pipelines/penguin-transform/Trainer/model/6/Format-Serving/assets INFO:absl:Writing fingerprint to pipelines/penguin-transform/Trainer/model/6/Format-Serving/fingerprint.pb INFO:absl:Training complete. Model written to pipelines/penguin-transform/Trainer/model/6/Format-Serving. ModelRun written to pipelines/penguin-transform/Trainer/model_run/6 INFO:absl:Cleaning up stateless execution info. INFO:absl:Execution 6 succeeded. INFO:absl:Cleaning up stateful execution info. 
INFO:absl:Deleted stateful_working_dir pipelines/penguin-transform/Trainer/.system/stateful_working_dir/04611601-4dbf-412d-8c6c-4e4864bb335a INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'model': [Artifact(artifact: uri: "pipelines/penguin-transform/Trainer/model/6" , artifact_type: name: "Model" base_type: MODEL )], 'model_run': [Artifact(artifact: uri: "pipelines/penguin-transform/Trainer/model_run/6" , artifact_type: name: "ModelRun" )]}) for execution 6 INFO:absl:MetadataStore with DB connection initialized INFO:absl:Component Trainer is finished. INFO:absl:Component Pusher is running. INFO:absl:Running launcher for node_info { type { name: "tfx.components.pusher.component.Pusher" base_type: DEPLOY } id: "Pusher" } contexts { contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.Pusher" } } } } inputs { inputs { key: "model" value { channels { producer_node_query { id: "Trainer" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.Trainer" } } } artifact_query { type { name: "Model" base_type: MODEL } } output_key: "model" } } } } outputs { outputs { key: "pushed_model" value { artifact_spec { type { name: "PushedModel" base_type: MODEL } } } } } parameters { parameters { key: "custom_config" value { field_value { string_value: "null" } } } parameters { key: "push_destination" value { field_value { string_value: "{\n \"filesystem\": {\n \"base_directory\": \"serving_model/penguin-transform\"\n }\n}" } } } } upstream_nodes: "Trainer" execution_options { caching_options { } } INFO:absl:MetadataStore with DB connection initialized WARNING:absl:ArtifactQuery.property_predicate is not supported. 
INFO:absl:[Pusher] Resolved inputs: ({'model': [Artifact(artifact: id: 12 type_id: 26 uri: "pipelines/penguin-transform/Trainer/model/6" custom_properties { key: "is_external" value { int_value: 0 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "Model" create_time_since_epoch: 1715160050020 last_update_time_since_epoch: 1715160050020 , artifact_type: id: 26 name: "Model" base_type: MODEL )]},) INFO:absl:MetadataStore with DB connection initialized INFO:absl:Going to run a new execution 7 INFO:absl:Going to run a new execution: ExecutionInfo(execution_id=7, input_dict={'model': [Artifact(artifact: id: 12 type_id: 26 uri: "pipelines/penguin-transform/Trainer/model/6" custom_properties { key: "is_external" value { int_value: 0 } } custom_properties { key: "tfx_version" value { string_value: "1.15.0" } } state: LIVE type: "Model" create_time_since_epoch: 1715160050020 last_update_time_since_epoch: 1715160050020 , artifact_type: id: 26 name: "Model" base_type: MODEL )]}, output_dict=defaultdict(<class 'list'>, {'pushed_model': [Artifact(artifact: uri: "pipelines/penguin-transform/Pusher/pushed_model/7" , artifact_type: name: "PushedModel" base_type: MODEL )]}), exec_properties={'push_destination': '{\n "filesystem": {\n "base_directory": "serving_model/penguin-transform"\n }\n}', 'custom_config': 'null'}, execution_output_uri='pipelines/penguin-transform/Pusher/.system/executor_execution/7/executor_output.pb', stateful_working_dir='pipelines/penguin-transform/Pusher/.system/stateful_working_dir/b2701c8d-b1b7-4c37-a72d-ac97a07c8706', tmp_dir='pipelines/penguin-transform/Pusher/.system/executor_execution/7/.temp/', pipeline_node=node_info { type { name: "tfx.components.pusher.component.Pusher" base_type: DEPLOY } id: "Pusher" } contexts { contexts { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } contexts { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } contexts { type { name: "node" } name { field_value { string_value: "penguin-transform.Pusher" } } } } inputs { inputs { key: "model" value { channels { producer_node_query { id: "Trainer" } context_queries { type { name: "pipeline" } name { field_value { string_value: "penguin-transform" } } } context_queries { type { name: "pipeline_run" } name { field_value { string_value: "2024-05-08T09:20:15.209892" } } } context_queries { type { name: "node" } name { field_value { string_value: "penguin-transform.Trainer" } } } artifact_query { type { name: "Model" base_type: MODEL } } output_key: "model" } } } } outputs { outputs { key: "pushed_model" value { artifact_spec { type { name: "PushedModel" base_type: MODEL } } } } } parameters { parameters { key: "custom_config" value { field_value { string_value: "null" } } } parameters { key: "push_destination" value { field_value { string_value: "{\n \"filesystem\": {\n \"base_directory\": \"serving_model/penguin-transform\"\n }\n}" } } } } upstream_nodes: "Trainer" execution_options { caching_options { } } , pipeline_info=id: "penguin-transform" , pipeline_run_id='2024-05-08T09:20:15.209892', top_level_pipeline_run_id=None, frontend_url=None) WARNING:absl:Pusher is going to push the model without validation. Consider using Evaluator or InfraValidator in your pipeline. INFO:absl:Model version: 1715160050 INFO:absl:Model written to serving path serving_model/penguin-transform/1715160050. INFO:absl:Model pushed to pipelines/penguin-transform/Pusher/pushed_model/7. 
INFO:absl:Cleaning up stateless execution info. INFO:absl:Execution 7 succeeded. INFO:absl:Cleaning up stateful execution info. INFO:absl:Deleted stateful_working_dir pipelines/penguin-transform/Pusher/.system/stateful_working_dir/b2701c8d-b1b7-4c37-a72d-ac97a07c8706 INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'pushed_model': [Artifact(artifact: uri: "pipelines/penguin-transform/Pusher/pushed_model/7" , artifact_type: name: "PushedModel" base_type: MODEL )]}) for execution 7 INFO:absl:MetadataStore with DB connection initialized INFO:absl:Component Pusher is finished.
You should see "INFO:absl:Component Pusher is finished." if the pipeline finished successfully.
The Pusher component pushes the trained model to the SERVING_MODEL_DIR, which
is the serving_model/penguin-transform directory if you did not change
the variables in the previous steps. You can see the result in the file
browser in the left-side panel of Colab, or by using the following command:
# List files in created model directory.
find {SERVING_MODEL_DIR}
serving_model/penguin-transform
serving_model/penguin-transform/1715160050
serving_model/penguin-transform/1715160050/variables
serving_model/penguin-transform/1715160050/variables/variables.index
serving_model/penguin-transform/1715160050/variables/variables.data-00000-of-00001
serving_model/penguin-transform/1715160050/assets
serving_model/penguin-transform/1715160050/keras_metadata.pb
serving_model/penguin-transform/1715160050/fingerprint.pb
serving_model/penguin-transform/1715160050/saved_model.pb
You can also check the signature of the generated model using the
saved_model_cli
tool.
saved_model_cli show --dir {SERVING_MODEL_DIR}/$(ls -1 {SERVING_MODEL_DIR} | sort -nr | head -1) --tag_set serve --signature_def serving_default
2024-05-08 09:20:50.856062: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-08 09:20:50.856134: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-08 09:20:50.857634: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
The given SavedModel SignatureDef contains the following input(s):
  inputs['examples'] tensor_info:
      dtype: DT_STRING
      shape: (-1)
      name: serving_default_examples:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['output_0'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 3)
      name: StatefulPartitionedCall_1:0
Method name is: tensorflow/serving/predict
Because we defined serving_default
with our own serve_tf_examples_fn
function, the signature shows that it takes a single string tensor.
That string is a serialized tf.train.Example and will be parsed with the
tf.io.parse_example()
function as we defined earlier (learn more about tf.Examples here).
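For reference, the sketch below shows how such a serving function is typically wired in TFX: the raw feature spec from the Transform output is used to parse the serialized examples, and the saved transform graph is applied before calling the model. This is a minimal illustration, not necessarily identical to the module file we defined earlier; the helper name and the dropped 'species' label are assumptions.
import tensorflow as tf
import tensorflow_transform as tft

def _make_serving_signature(model, tf_transform_output: tft.TFTransformOutput):
  # Keep the transform graph as part of the exported model.
  model.tft_layer = tf_transform_output.transform_features_layer()

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')])
  def serve_tf_examples_fn(serialized_tf_examples):
    # Parse the serialized tf.train.Example protos with the raw feature spec.
    feature_spec = tf_transform_output.raw_feature_spec()
    # The label is not provided at serving time in this sketch.
    feature_spec.pop('species', None)
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
    # Apply the same preprocessing that was used during training.
    transformed_features = model.tft_layer(parsed_features)
    return model(transformed_features)

  return serve_tf_examples_fn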
We can load the exported model and try some inferences with a few examples.
# Find a model with the latest timestamp.
model_dirs = (item for item in os.scandir(SERVING_MODEL_DIR) if item.is_dir())
model_path = max(model_dirs, key=lambda i: int(i.name)).path
loaded_model = tf.keras.models.load_model(model_path)
inference_fn = loaded_model.signatures['serving_default']
# Prepare an example and run inference.
features = {
'culmen_length_mm': tf.train.Feature(float_list=tf.train.FloatList(value=[49.9])),
'culmen_depth_mm': tf.train.Feature(float_list=tf.train.FloatList(value=[16.1])),
'flipper_length_mm': tf.train.Feature(int64_list=tf.train.Int64List(value=[213])),
'body_mass_g': tf.train.Feature(int64_list=tf.train.Int64List(value=[5400])),
}
example_proto = tf.train.Example(features=tf.train.Features(feature=features))
examples = example_proto.SerializeToString()
result = inference_fn(examples=tf.constant([examples]))
print(result['output_0'].numpy())
[[-2.76344 -0.5130405 7.046433 ]]
The third element, which corresponds to the 'Gentoo' species, is expected to be the largest of the three.
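If you want a class label rather than the raw scores, you can take the argmax over the output. A minimal sketch, assuming the label encoding 0 = Adelie, 1 = Chinstrap, 2 = Gentoo for the species column:
# Map the raw output to a species name.
# The label order below is an assumption based on how 'species' was encoded.
species_names = ['Adelie', 'Chinstrap', 'Gentoo']
predicted_index = int(tf.argmax(result['output_0'], axis=1)[0])
print('Predicted species:', species_names[predicted_index])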
Next steps
If you want to learn more about the Transform component, see the Transform Component guide. You can find more resources at https://www.tensorflow.org/tfx/tutorials.
Please see Understanding TFX Pipelines to learn more about various concepts in TFX.