A Component-by-Component Introduction to TensorFlow Extended (TFX)
This Colab-based tutorial will interactively walk through each built-in component of TensorFlow Extended (TFX).
It covers every step in an end-to-end machine learning pipeline, from data ingestion to pushing a model to serving.
When you're done, the contents of this notebook can be automatically exported as TFX pipeline source code, which you can orchestrate with Apache Airflow and Apache Beam.
Background
This notebook demonstrates how to use TFX in a Jupyter/Colab environment. Here, we walk through the Chicago Taxi example in an interactive notebook.
Working in an interactive notebook is a useful way to become familiar with the structure of a TFX pipeline. It is also a lightweight environment for developing your own pipelines. Be aware, though, that interactive notebooks differ from production deployments in how they are orchestrated and in how they access metadata artifacts.
Orchestration
In a production deployment of TFX, you will use an orchestrator such as Apache Airflow, Kubeflow Pipelines, or Apache Beam to orchestrate a pre-defined pipeline graph of TFX components. In an interactive notebook, the notebook itself is the orchestrator, running each TFX component as you execute the notebook cells.
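To make the difference concrete, the sketch below shows the core job of an orchestrator. Given a pre-defined graph, it runs each component only after all of its upstream components have finished. The component names mirror this notebook. The dependency dict and the run_pipeline helper are hypothetical illustrations and are not the Airflow, Kubeflow Pipelines, or Beam APIs.

```python
from collections import deque

# Hypothetical dependency graph mirroring the components in this notebook:
# each key lists the components whose outputs it consumes.
PIPELINE = {
    "ExampleGen": [],
    "StatisticsGen": ["ExampleGen"],
    "SchemaGen": ["StatisticsGen"],
    "ExampleValidator": ["StatisticsGen", "SchemaGen"],
    "Transform": ["ExampleGen", "SchemaGen"],
}


def run_pipeline(graph, run_component):
    """Run every component after all of its upstream components (Kahn's algorithm)."""
    # Number of unfinished upstream components for each component.
    remaining = {name: len(deps) for name, deps in graph.items()}
    # Reverse edges: which components consume each component's outputs.
    downstream = {name: [] for name in graph}
    for name, deps in graph.items():
        for dep in deps:
            downstream[dep].append(name)
    ready = deque(name for name, count in remaining.items() if count == 0)
    order = []
    while ready:
        name = ready.popleft()
        run_component(name)
        order.append(name)
        for child in downstream[name]:
            remaining[child] -= 1
            if remaining[child] == 0:
                ready.append(child)
    if len(order) != len(graph):
        raise ValueError("pipeline graph contains a cycle")
    return order


if __name__ == "__main__":
    print(run_pipeline(PIPELINE, lambda name: print("running", name)))
```

In an interactive notebook you take the place of run_pipeline yourself, by executing the component cells in a valid order.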
Metadata
In a production deployment of TFX, you will access metadata through the ML Metadata (MLMD) API. MLMD stores metadata properties in a database such as MySQL or SQLite, and stores the metadata payloads in a persistent store such as your filesystem. In an interactive notebook, the properties (in an ephemeral SQLite database) and the payloads (as files) are both kept in a temporary directory under /tmp on the Jupyter notebook or Colab server.
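The split between small, queryable properties and large payloads can be illustrated with a toy store built on the standard library. Properties go into a SQLite table, and each payload is written as a file whose URI the table records. This is a hypothetical sketch: the ToyMetadataStore class and its table layout are invented for illustration and are not the ML Metadata API or its schema.

```python
import os
import sqlite3
import tempfile


class ToyMetadataStore:
    """Hypothetical stand-in for MLMD: properties in SQLite, payloads on disk."""

    def __init__(self, root):
        self.root = root
        self.db = sqlite3.connect(os.path.join(root, "metadata.sqlite"))
        self.db.execute(
            "CREATE TABLE IF NOT EXISTS artifacts ("
            "id INTEGER PRIMARY KEY, type TEXT, uri TEXT)")

    def publish(self, artifact_type, payload):
        """Write the payload to a file, then record its properties in SQLite."""
        cur = self.db.execute(
            "INSERT INTO artifacts (type, uri) VALUES (?, '')", (artifact_type,))
        artifact_id = cur.lastrowid
        uri = os.path.join(self.root, artifact_type, str(artifact_id))
        os.makedirs(os.path.dirname(uri), exist_ok=True)
        with open(uri, "wb") as f:
            f.write(payload)
        self.db.execute(
            "UPDATE artifacts SET uri = ? WHERE id = ?", (uri, artifact_id))
        self.db.commit()
        return artifact_id

    def load(self, artifact_id):
        """Look up the URI in SQLite, then read the payload from the filesystem."""
        (uri,) = self.db.execute(
            "SELECT uri FROM artifacts WHERE id = ?", (artifact_id,)).fetchone()
        with open(uri, "rb") as f:
            return f.read()


if __name__ == "__main__":
    store = ToyMetadataStore(tempfile.mkdtemp(prefix="toy-mlmd-"))
    schema_id = store.publish("Schema", b"feature { name: 'fare' }")
    print(schema_id, store.load(schema_id))
```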
Setup
First, we install and import the necessary packages, set up paths, and download data.
Upgrade Pip
To avoid upgrading Pip on your own system when running locally, first check that we're running in Colab. Local systems can, of course, be upgraded separately.
try:
  # Colab provides the google.colab package; the import fails elsewhere.
  import google.colab
  !pip install --upgrade pip
except ImportError:
  pass
Install TFX
!pip install -q -U --use-feature=2020-resolver tfx
Did you restart the runtime?
If you are using Google Colab, the first time that you run the cell above, you must restart the runtime (Runtime > Restart runtime ...). This is because of the way that Colab loads packages.
Import packages
We import necessary packages, including standard TFX component classes.
import os
import pprint
import tempfile
import urllib
import absl
import tensorflow as tf
import tensorflow_model_analysis as tfma
tf.get_logger().propagate = False
pp = pprint.PrettyPrinter()
import tfx
from tfx.components import CsvExampleGen
from tfx.components import Evaluator
from tfx.components import ExampleValidator
from tfx.components import Pusher
from tfx.components import ResolverNode
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.components import Trainer
from tfx.components import Transform
from tfx.dsl.experimental import latest_blessed_model_resolver
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.proto.evaluator_pb2 import SingleSlicingSpec
from tfx.utils.dsl_utils import external_input
from tfx.types import Channel
from tfx.types.standard_artifacts import Model
from tfx.types.standard_artifacts import ModelBlessing
%load_ext tfx.orchestration.experimental.interactive.notebook_extensions.skip
WARNING:absl:RuntimeParameter is only supported on Cloud-based DAG runner currently.
Let's check the library versions.
print('TensorFlow version: {}'.format(tf.__version__))
print('TFX version: {}'.format(tfx.__version__))
TensorFlow version: 2.3.1
TFX version: 0.25.0
Set up pipeline paths
# This is the root directory for your TFX pip package installation.
_tfx_root = tfx.__path__[0]
# This is the directory containing the TFX Chicago Taxi Pipeline example.
_taxi_root = os.path.join(_tfx_root, 'examples/chicago_taxi_pipeline')
# This is the path where your model will be pushed for serving.
_serving_model_dir = os.path.join(
    tempfile.mkdtemp(), 'serving_model/taxi_simple')
# Set up logging.
absl.logging.set_verbosity(absl.logging.INFO)
Download example data
We download the example dataset for use in our TFX pipeline.
The dataset we're using is the Taxi Trips dataset released by the City of Chicago. The columns in this dataset are:
pickup_community_area | fare | trip_start_month
trip_start_hour | trip_start_day | trip_start_timestamp
pickup_latitude | pickup_longitude | dropoff_latitude
dropoff_longitude | trip_miles | pickup_census_tract
dropoff_census_tract | payment_type | company
trip_seconds | dropoff_community_area | tips
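With this dataset, we will build a model that predicts whether a trip's tip exceeds 20% of its fare. This binary label is derived from the tips and fare columns in the Transform step.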
_data_root = tempfile.mkdtemp(prefix='tfx-data')
DATA_PATH = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/chicago_taxi_pipeline/data/simple/data.csv'
_data_filepath = os.path.join(_data_root, "data.csv")
urllib.request.urlretrieve(DATA_PATH, _data_filepath)
('/tmp/tfx-dataq0c865u1/data.csv', <http.client.HTTPMessage at 0x7f475c9de780>)
Take a quick look at the CSV file.
!head {_data_filepath}
pickup_community_area,fare,trip_start_month,trip_start_hour,trip_start_day,trip_start_timestamp,pickup_latitude,pickup_longitude,dropoff_latitude,dropoff_longitude,trip_miles,pickup_census_tract,dropoff_census_tract,payment_type,company,trip_seconds,dropoff_community_area,tips
,12.45,5,19,6,1400269500,,,,,0.0,,,Credit Card,Chicago Elite Cab Corp. (Chicago Carriag,0,,0.0
,0,3,19,5,1362683700,,,,,0,,,Unknown,Chicago Elite Cab Corp.,300,,0
60,27.05,10,2,3,1380593700,41.836150155,-87.648787952,,,12.6,,,Cash,Taxi Affiliation Services,1380,,0.0
10,5.85,10,1,2,1382319000,41.985015101,-87.804532006,,,0.0,,,Cash,Taxi Affiliation Services,180,,0.0
14,16.65,5,7,5,1369897200,41.968069,-87.721559063,,,0.0,,,Cash,Dispatch Taxi Affiliation,1080,,0.0
13,16.45,11,12,3,1446554700,41.983636307,-87.723583185,,,6.9,,,Cash,,780,,0.0
16,32.05,12,1,1,1417916700,41.953582125,-87.72345239,,,15.4,,,Cash,,1200,,0.0
30,38.45,10,10,5,1444301100,41.839086906,-87.714003807,,,14.6,,,Cash,,2580,,0.0
11,14.65,1,1,3,1358213400,41.978829526,-87.771166703,,,5.81,,,Cash,,1080,,0.0
Disclaimer: This site provides applications using data that has been modified for use from its original source, www.cityofchicago.org, the official website of the City of Chicago. The City of Chicago makes no claims as to the content, accuracy, timeliness, or completeness of any of the data provided at this site. The data provided at this site is subject to change at any time. It is understood that the data provided at this site is being used at one’s own risk.
Create the InteractiveContext
Last, we create an InteractiveContext, which will allow us to run TFX components interactively in this notebook.
# Here, we create an InteractiveContext using default parameters. This will
# use a temporary directory with an ephemeral ML Metadata database instance.
# To use your own pipeline root or database, the optional properties
# `pipeline_root` and `metadata_connection_config` may be passed to
# InteractiveContext. Calls to InteractiveContext are no-ops outside of the
# notebook.
context = InteractiveContext()
WARNING:absl:InteractiveContext pipeline_root argument not provided: using temporary directory /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p as root for pipeline outputs. WARNING:absl:InteractiveContext metadata_connection_config not provided: using SQLite ML Metadata database at /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/metadata.sqlite.
Run TFX components interactively
In the cells that follow, we create TFX components one-by-one, run each of them, and visualize their output artifacts.
ExampleGen
The ExampleGen component is usually at the start of a TFX pipeline. It will:
- Split data into training and evaluation sets (by default, 2/3 training + 1/3 eval)
- Convert data into the tf.Example format
- Copy data into the pipeline root directory (here, the temporary directory chosen by the InteractiveContext) for other components to access
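The default split is defined in terms of hash buckets, two for train and one for eval. As a result, the assignment is deterministic: the same example always lands in the same split. The sketch below illustrates the idea. Two parts of it are assumptions for illustration rather than ExampleGen's actual hashing: the use of hashlib.md5 over each record's bytes, and the assign_split helper itself.

```python
import hashlib

# Mirrors the default 2:1 train/eval split: 2 hash buckets for train, 1 for eval.
SPLIT_BUCKETS = [("train", 2), ("eval", 1)]


def assign_split(record: bytes, split_buckets=SPLIT_BUCKETS) -> str:
    """Deterministically map a record to a split by hashing it into a bucket."""
    total = sum(count for _, count in split_buckets)
    # Assumption: md5 stands in for whatever fingerprint ExampleGen really uses.
    bucket = int.from_bytes(hashlib.md5(record).digest()[:8], "big") % total
    for name, count in split_buckets:
        if bucket < count:
            return name
        bucket -= count
    raise AssertionError("unreachable: bucket is always < total")


if __name__ == "__main__":
    records = [f"trip-{i}".encode() for i in range(3000)]
    splits = [assign_split(r) for r in records]
    print({name: splits.count(name) for name, _ in SPLIT_BUCKETS})
```

A side effect of hashing the record contents is that exact duplicate records always end up in the same split.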
ExampleGen takes as input the path to your data source. In our case, this is the _data_root path that contains the downloaded CSV.
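# Pass the data directory directly. The log below was produced by the older
# form, CsvExampleGen(input=external_input(_data_root)), which TFX 0.25 deprecates.
example_gen = CsvExampleGen(input_base=_data_root)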
context.run(example_gen)
WARNING:tensorflow:From <ipython-input-1-2e0190c2dd16>:1: external_input (from tfx.utils.dsl_utils) is deprecated and will be removed in a future version. Instructions for updating: external_input is deprecated, directly pass the uri to ExampleGen. Warning:absl:The "input" argument to the CsvExampleGen component has been deprecated by "input_base". Please update your usage as support for this argument will be removed soon. INFO:absl:Running driver for CsvExampleGen INFO:absl:MetadataStore with DB connection initialized INFO:absl:select span and version = (0, None) INFO:absl:latest span and version = (0, None) INFO:absl:Running executor for CsvExampleGen INFO:absl:Generating examples. WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features. INFO:absl:Processing input csv data /tmp/tfx-dataq0c865u1/* to TFExample. WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be. INFO:absl:Examples generated. INFO:absl:Running publisher for CsvExampleGen INFO:absl:MetadataStore with DB connection initialized
Let's examine the output artifacts of ExampleGen. This component produces a single examples artifact that holds two splits, training examples and evaluation examples:
artifact = example_gen.outputs['examples'].get()[0]
print(artifact.split_names, artifact.uri)
["train", "eval"] /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/CsvExampleGen/examples/1
We can also take a look at the first three training examples:
# Get the URI of the output artifact representing the training examples, which is a directory
train_uri = os.path.join(example_gen.outputs['examples'].get()[0].uri, 'train')
# Get the list of files in this directory (all compressed TFRecord files)
tfrecord_filenames = [os.path.join(train_uri, name)
                      for name in os.listdir(train_uri)]
# Create a `TFRecordDataset` to read these files
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
# Iterate over the first 3 records and decode them.
for tfrecord in dataset.take(3):
  serialized_example = tfrecord.numpy()
  example = tf.train.Example()
  example.ParseFromString(serialized_example)
  pp.pprint(example)
features { feature { key: "company" value { bytes_list { value: "Chicago Elite Cab Corp. (Chicago Carriag" } } } feature { key: "dropoff_census_tract" value { int64_list { } } } feature { key: "dropoff_community_area" value { int64_list { } } } feature { key: "dropoff_latitude" value { float_list { } } } feature { key: "dropoff_longitude" value { float_list { } } } feature { key: "fare" value { float_list { value: 12.449999809265137 } } } feature { key: "payment_type" value { bytes_list { value: "Credit Card" } } } feature { key: "pickup_census_tract" value { int64_list { } } } feature { key: "pickup_community_area" value { int64_list { } } } feature { key: "pickup_latitude" value { float_list { } } } feature { key: "pickup_longitude" value { float_list { } } } feature { key: "tips" value { float_list { value: 0.0 } } } feature { key: "trip_miles" value { float_list { value: 0.0 } } } feature { key: "trip_seconds" value { int64_list { value: 0 } } } feature { key: "trip_start_day" value { int64_list { value: 6 } } } feature { key: "trip_start_hour" value { int64_list { value: 19 } } } feature { key: "trip_start_month" value { int64_list { value: 5 } } } feature { key: "trip_start_timestamp" value { int64_list { value: 1400269500 } } } } features { feature { key: "company" value { bytes_list { value: "Taxi Affiliation Services" } } } feature { key: "dropoff_census_tract" value { int64_list { } } } feature { key: "dropoff_community_area" value { int64_list { } } } feature { key: "dropoff_latitude" value { float_list { } } } feature { key: "dropoff_longitude" value { float_list { } } } feature { key: "fare" value { float_list { value: 27.049999237060547 } } } feature { key: "payment_type" value { bytes_list { value: "Cash" } } } feature { key: "pickup_census_tract" value { int64_list { } } } feature { key: "pickup_community_area" value { int64_list { value: 60 } } } feature { key: "pickup_latitude" value { float_list { value: 41.836151123046875 } } } feature { key: 
"pickup_longitude" value { float_list { value: -87.64878845214844 } } } feature { key: "tips" value { float_list { value: 0.0 } } } feature { key: "trip_miles" value { float_list { value: 12.600000381469727 } } } feature { key: "trip_seconds" value { int64_list { value: 1380 } } } feature { key: "trip_start_day" value { int64_list { value: 3 } } } feature { key: "trip_start_hour" value { int64_list { value: 2 } } } feature { key: "trip_start_month" value { int64_list { value: 10 } } } feature { key: "trip_start_timestamp" value { int64_list { value: 1380593700 } } } } features { feature { key: "company" value { bytes_list { } } } feature { key: "dropoff_census_tract" value { int64_list { } } } feature { key: "dropoff_community_area" value { int64_list { } } } feature { key: "dropoff_latitude" value { float_list { } } } feature { key: "dropoff_longitude" value { float_list { } } } feature { key: "fare" value { float_list { value: 16.450000762939453 } } } feature { key: "payment_type" value { bytes_list { value: "Cash" } } } feature { key: "pickup_census_tract" value { int64_list { } } } feature { key: "pickup_community_area" value { int64_list { value: 13 } } } feature { key: "pickup_latitude" value { float_list { value: 41.98363494873047 } } } feature { key: "pickup_longitude" value { float_list { value: -87.72357940673828 } } } feature { key: "tips" value { float_list { value: 0.0 } } } feature { key: "trip_miles" value { float_list { value: 6.900000095367432 } } } feature { key: "trip_seconds" value { int64_list { value: 780 } } } feature { key: "trip_start_day" value { int64_list { value: 3 } } } feature { key: "trip_start_hour" value { int64_list { value: 12 } } } feature { key: "trip_start_month" value { int64_list { value: 11 } } } feature { key: "trip_start_timestamp" value { int64_list { value: 1446554700 } } } }
Now that ExampleGen has finished ingesting the data, the next step is data analysis.
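Before moving on, it may help to see what the GZIP-compressed TFRecord files read above contain. As we understand the format, each record is framed by four fields: an 8-byte little-endian length, a 4-byte masked CRC32C of that length, the payload bytes, and a 4-byte masked CRC32C of the payload. The whole stream is then gzip-compressed. The standard-library sketch below writes and reads that framing. It is an illustration only; in practice, use tf.data.TFRecordDataset as the cell above does.

```python
import gzip
import io
import struct


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli polynomial); slow but dependency-free."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits, add a constant, keep 32 bits."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_tfrecords(records) -> bytes:
    """Frame each record, then gzip the whole stream (as with compression_type="GZIP")."""
    raw = io.BytesIO()
    for record in records:
        length = struct.pack("<Q", len(record))
        raw.write(length)
        raw.write(struct.pack("<I", masked_crc(length)))
        raw.write(record)
        raw.write(struct.pack("<I", masked_crc(record)))
    return gzip.compress(raw.getvalue())


def read_tfrecords(blob: bytes):
    """Yield payloads from a gzip-compressed TFRecord byte string, checking CRCs."""
    stream = io.BytesIO(gzip.decompress(blob))
    while True:
        header = stream.read(8)
        if not header:
            return
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupt length header")
        payload = stream.read(length)
        (payload_crc,) = struct.unpack("<I", stream.read(4))
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupt record payload")
        yield payload


if __name__ == "__main__":
    blob = write_tfrecords([b"first", b"second"])
    print(list(read_tfrecords(blob)))
```

In the files written by ExampleGen, each payload is a serialized tf.train.Example, which is why the cell above calls ParseFromString on every record.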
StatisticsGen
The StatisticsGen component computes statistics over your dataset for data analysis, as well as for use in downstream components. It uses the TensorFlow Data Validation library.
StatisticsGen takes as input the dataset we just ingested using ExampleGen.
statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples'])
context.run(statistics_gen)
INFO:absl:Excluding no splits because exclude_splits is not set. INFO:absl:Running driver for StatisticsGen INFO:absl:MetadataStore with DB connection initialized INFO:absl:Running executor for StatisticsGen INFO:absl:Generating statistics for split train. INFO:absl:Statistics for split train written to /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/StatisticsGen/statistics/2/train. INFO:absl:Generating statistics for split eval. INFO:absl:Statistics for split eval written to /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/StatisticsGen/statistics/2/eval. INFO:absl:Running publisher for StatisticsGen INFO:absl:MetadataStore with DB connection initialized
After StatisticsGen finishes running, we can visualize the output statistics. Try playing with the different plots!
context.show(statistics_gen.outputs['statistics'])
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_data_validation/utils/stats_util.py:247: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version. Instructions for updating: Use eager execution and: `tf.data.TFRecordDataset(path)`
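To give a feel for what these statistics contain, the sketch below computes a few per-feature numbers over four rows from the CSV preview earlier in this notebook: present and missing counts, plus min, max, and mean for numeric columns. It is a hypothetical, standard-library illustration. TensorFlow Data Validation computes much richer statistics, such as histograms, and feature_stats is not its API.

```python
import csv
import io

# Four rows copied from the CSV preview earlier, keeping only five columns.
SAMPLE_CSV = """pickup_community_area,fare,trip_miles,payment_type,tips
,12.45,0.0,Credit Card,0.0
,0,0,Unknown,0
60,27.05,12.6,Cash,0.0
10,5.85,0.0,Cash,0.0
"""


def feature_stats(csv_text):
    """Per-feature counts, plus min/max/mean for numeric columns.

    Empty CSV cells are treated as missing values.
    """
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    stats = {}
    for name in rows[0]:
        present = [row[name] for row in rows if row[name] != ""]
        entry = {"num_present": len(present),
                 "num_missing": len(rows) - len(present)}
        try:
            numbers = [float(v) for v in present]
        except ValueError:
            # Not numeric: report the distinct string values instead.
            entry["unique_values"] = sorted(set(present))
        else:
            if numbers:
                entry.update(min=min(numbers), max=max(numbers),
                             mean=sum(numbers) / len(numbers))
        stats[name] = entry
    return stats


if __name__ == "__main__":
    for name, entry in feature_stats(SAMPLE_CSV).items():
        print(name, entry)
```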
SchemaGen
The SchemaGen component generates a schema based on your data statistics. (A schema defines the expected bounds, types, and properties of the features in your dataset.) It also uses the TensorFlow Data Validation library.
SchemaGen takes as input the statistics that we generated with StatisticsGen. Because exclude_splits is not set, it uses the statistics of every split, and the log below shows it processing both train and eval.
schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=False)
context.run(schema_gen)
INFO:absl:Excluding no splits because exclude_splits is not set. INFO:absl:Running driver for SchemaGen INFO:absl:MetadataStore with DB connection initialized INFO:absl:Running executor for SchemaGen INFO:absl:Processing schema from statistics for split train. INFO:absl:Processing schema from statistics for split eval. INFO:absl:Schema written to /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/SchemaGen/schema/3/schema.pbtxt. INFO:absl:Running publisher for SchemaGen INFO:absl:MetadataStore with DB connection initialized
After SchemaGen finishes running, we can visualize the generated schema as a table.
context.show(schema_gen.outputs['schema'])
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_data_validation/utils/display_util.py:151: FutureWarning: Passing a negative integer is deprecated in version 1.0 and will not be supported in future version. Instead, use None to not limit the column width. pd.set_option('max_colwidth', -1)
Each feature in your dataset shows up as a row in the schema table, alongside its properties. The schema also captures all the values that a categorical feature takes on, denoted as its domain.
To learn more about schemas, see the SchemaGen documentation.
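The idea behind schema inference can be sketched in a few lines. First, guess each feature's type from the observed values. Then, for string features, record the set of observed values as the domain. The infer_type and infer_schema helpers and their dict output are hypothetical. SchemaGen itself produces a Schema protocol buffer through TensorFlow Data Validation and also infers properties such as presence and valency. The rows below come from the CSV preview earlier in this notebook.

```python
import csv
import io

# Four rows from the CSV preview earlier, keeping only four columns.
SAMPLE_CSV = """trip_start_hour,fare,payment_type,company
19,12.45,Credit Card,Chicago Elite Cab Corp. (Chicago Carriag
19,0,Unknown,Chicago Elite Cab Corp.
2,27.05,Cash,Taxi Affiliation Services
1,5.85,Cash,Taxi Affiliation Services
"""


def infer_type(values):
    """Pick the narrowest type (INT, then FLOAT, else BYTES) that parses every value."""
    for type_name, parse in (("INT", int), ("FLOAT", float)):
        try:
            for value in values:
                parse(value)
            return type_name
        except ValueError:
            continue
    return "BYTES"


def infer_schema(csv_text):
    """Map each feature to its inferred type, plus a domain for string features."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    schema = {}
    for name in rows[0]:
        values = [row[name] for row in rows if row[name] != ""]
        feature = {"type": infer_type(values)}
        if feature["type"] == "BYTES":
            # String features get a domain: every value seen in the data.
            feature["domain"] = sorted(set(values))
        schema[name] = feature
    return schema


if __name__ == "__main__":
    for name, feature in infer_schema(SAMPLE_CSV).items():
        print(name, feature)
```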
ExampleValidator
The ExampleValidator component detects anomalies in your data, based on the expectations defined by the schema. It also uses the TensorFlow Data Validation library.
ExampleValidator takes as input the statistics from StatisticsGen and the schema from SchemaGen.
example_validator = ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])
context.run(example_validator)
INFO:absl:Excluding no splits because exclude_splits is not set. INFO:absl:Running driver for ExampleValidator INFO:absl:MetadataStore with DB connection initialized INFO:absl:Running executor for ExampleValidator INFO:absl:Validating schema against the computed statistics for split train. INFO:absl:Validation complete for split train. Anomalies written to /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/ExampleValidator/anomalies/4/train. INFO:absl:Validating schema against the computed statistics for split eval. INFO:absl:Validation complete for split eval. Anomalies written to /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/ExampleValidator/anomalies/4/eval. INFO:absl:Running publisher for ExampleValidator INFO:absl:MetadataStore with DB connection initialized
After ExampleValidator finishes running, we can visualize the anomalies as a table.
context.show(example_validator.outputs['anomalies'])
In the anomalies table, we can see that there are no anomalies. This is what we'd expect, since this is the first dataset that we've analyzed and the schema is tailored to it. You should review this schema: anything unexpected means an anomaly in the data. Once reviewed, the schema can be used to guard future data, and anomalies produced here can be used to debug model performance, understand how your data evolves over time, and identify data errors.
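Validation is the mirror image of inference: compare new data against the schema and report anything that falls outside it. The sketch below checks a batch against a small hand-written schema dict, with a type per feature and a domain for string features. Several parts are invented for this example: the schema's payment_type domain is a subset chosen by hand, and the batch values, the find_anomalies helper, and its messages are all hypothetical. None of this is the ExampleValidator or TensorFlow Data Validation API.

```python
# Hypothetical schema: a type per feature, and a domain for string features.
SCHEMA = {
    "fare": {"type": "FLOAT"},
    "trip_start_hour": {"type": "INT"},
    "payment_type": {"type": "BYTES", "domain": ["Cash", "Credit Card", "Unknown"]},
}

PARSERS = {"INT": int, "FLOAT": float, "BYTES": str}


def find_anomalies(rows, schema):
    """Return sorted (feature, description) pairs for values that violate the schema."""
    anomalies = set()
    for row in rows:
        for name, spec in schema.items():
            value = row.get(name)
            if value is None:
                anomalies.add((name, "feature missing from example"))
                continue
            try:
                PARSERS[spec["type"]](value)
            except ValueError:
                anomalies.add((name, "expected %s, got %r" % (spec["type"], value)))
                continue
            if "domain" in spec and value not in spec["domain"]:
                anomalies.add((name, "unexpected string value %r" % value))
        for name in row:
            if name not in schema:
                anomalies.add((name, "feature not in schema"))
    return sorted(anomalies)


if __name__ == "__main__":
    new_batch = [  # hypothetical incoming data
        {"fare": "9.25", "trip_start_hour": "14", "payment_type": "Cash"},
        {"fare": "free", "trip_start_hour": "3", "payment_type": "Mobile"},
    ]
    for anomaly in find_anomalies(new_batch, SCHEMA):
        print(anomaly)
```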
Transform
The Transform component performs feature engineering for both training and serving. It uses the TensorFlow Transform library.
Transform takes as input the data from ExampleGen, the schema from SchemaGen, and a module that contains user-defined Transform code.
Let's see an example of user-defined Transform code below (for an introduction to the TensorFlow Transform APIs, see the tutorial). First, we define a few constants for feature engineering:
_taxi_constants_module_file = 'taxi_constants.py'
%%writefile {_taxi_constants_module_file}
# Categorical features are assumed to each have a maximum value in the dataset.
MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 12]
CATEGORICAL_FEATURE_KEYS = [
    'trip_start_hour', 'trip_start_day', 'trip_start_month',
    'pickup_census_tract', 'dropoff_census_tract', 'pickup_community_area',
    'dropoff_community_area'
]
DENSE_FLOAT_FEATURE_KEYS = ['trip_miles', 'fare', 'trip_seconds']
# Number of buckets used by tf.transform for encoding each feature.
FEATURE_BUCKET_COUNT = 10
BUCKET_FEATURE_KEYS = [
    'pickup_latitude', 'pickup_longitude', 'dropoff_latitude',
    'dropoff_longitude'
]
# Number of vocabulary terms used for encoding VOCAB_FEATURES by tf.transform
VOCAB_SIZE = 1000
# Count of out-of-vocab buckets in which unrecognized VOCAB_FEATURES are hashed.
OOV_SIZE = 10
VOCAB_FEATURE_KEYS = [
    'payment_type',
    'company',
]
# Keys
LABEL_KEY = 'tips'
FARE_KEY = 'fare'
def transformed_name(key):
  return key + '_xf'
Writing taxi_constants.py
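Some of these constants drive two of the encodings in the transform module that follows. VOCAB_SIZE and OOV_SIZE control how string features become integer ids, and FEATURE_BUCKET_COUNT controls how the latitude and longitude features are bucketized. The standard-library sketch below approximates both ideas. The vocabulary keeps the top_k most frequent values, and unseen values are hashed into num_oov_buckets extra ids. Bucket boundaries are taken at quantiles of the data. Several details are assumptions chosen for illustration: the helpers, the CRC32 hash, the alphabetical tie-breaking, and the quantile rule. tft.compute_and_apply_vocabulary and tft.bucketize compute these inside a Beam pipeline and may differ in such details. The sample latitudes are pickup latitudes from the CSV preview, rounded to three decimals.

```python
import bisect
import zlib
from collections import Counter


def build_vocabulary(values, top_k):
    """Most frequent values first; ties broken alphabetically (an assumption)."""
    counts = Counter(values)
    ranked = sorted(counts, key=lambda v: (-counts[v], v))
    return {value: index for index, value in enumerate(ranked[:top_k])}


def apply_vocabulary(value, vocab, num_oov_buckets):
    """Known values map to their rank; unknown ones hash into ids after the vocab."""
    if value in vocab:
        return vocab[value]
    return len(vocab) + zlib.crc32(value.encode("utf-8")) % num_oov_buckets


def quantile_boundaries(values, num_buckets):
    """num_buckets - 1 cut points taken at evenly spaced ranks of the sorted data."""
    ordered = sorted(values)
    return [ordered[len(ordered) * i // num_buckets] for i in range(1, num_buckets)]


def bucketize(value, boundaries):
    """Bucket index in [0, len(boundaries)]; a value equal to a cut point goes up."""
    return bisect.bisect_right(boundaries, value)


if __name__ == "__main__":
    vocab = build_vocabulary(["Cash", "Cash", "Credit Card", "Unknown"], top_k=2)
    print(vocab, apply_vocabulary("Unknown", vocab, num_oov_buckets=10))
    lats = [41.836, 41.985, 41.968, 41.984, 41.954, 41.839, 41.979]
    bounds = quantile_boundaries(lats, num_buckets=3)
    print(bounds, [bucketize(x, bounds) for x in lats])
```

With top_k=2, "Unknown" falls outside the vocabulary and receives an out-of-vocabulary id, which mirrors what OOV_SIZE is for in the real module.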
Next, we write a preprocessing_fn that takes in raw data as input and returns transformed features that our model can train on:
_taxi_transform_module_file = 'taxi_transform.py'
%%writefile {_taxi_transform_module_file}
import tensorflow as tf
import tensorflow_transform as tft
import taxi_constants
_DENSE_FLOAT_FEATURE_KEYS = taxi_constants.DENSE_FLOAT_FEATURE_KEYS
_VOCAB_FEATURE_KEYS = taxi_constants.VOCAB_FEATURE_KEYS
_VOCAB_SIZE = taxi_constants.VOCAB_SIZE
_OOV_SIZE = taxi_constants.OOV_SIZE
_FEATURE_BUCKET_COUNT = taxi_constants.FEATURE_BUCKET_COUNT
_BUCKET_FEATURE_KEYS = taxi_constants.BUCKET_FEATURE_KEYS
_CATEGORICAL_FEATURE_KEYS = taxi_constants.CATEGORICAL_FEATURE_KEYS
_FARE_KEY = taxi_constants.FARE_KEY
_LABEL_KEY = taxi_constants.LABEL_KEY
_transformed_name = taxi_constants.transformed_name
def preprocessing_fn(inputs):
  """tf.transform's callback function for preprocessing inputs.
  Args:
    inputs: map from feature keys to raw not-yet-transformed features.
  Returns:
    Map from string feature key to transformed feature operations.
  """
  outputs = {}
  for key in _DENSE_FLOAT_FEATURE_KEYS:
    # Fill missing values with 0, then scale to a z-score (mean 0, std 1).
    outputs[_transformed_name(key)] = tft.scale_to_z_score(
        _fill_in_missing(inputs[key]))
  for key in _VOCAB_FEATURE_KEYS:
    # Build a vocabulary for this feature.
    outputs[_transformed_name(key)] = tft.compute_and_apply_vocabulary(
        _fill_in_missing(inputs[key]),
        top_k=_VOCAB_SIZE,
        num_oov_buckets=_OOV_SIZE)
  for key in _BUCKET_FEATURE_KEYS:
    outputs[_transformed_name(key)] = tft.bucketize(
        _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT)
  for key in _CATEGORICAL_FEATURE_KEYS:
    outputs[_transformed_name(key)] = _fill_in_missing(inputs[key])
  # Was this passenger a big tipper?
  taxi_fare = _fill_in_missing(inputs[_FARE_KEY])
  tips = _fill_in_missing(inputs[_LABEL_KEY])
  outputs[_transformed_name(_LABEL_KEY)] = tf.where(
      tf.math.is_nan(taxi_fare),
      tf.cast(tf.zeros_like(taxi_fare), tf.int64),
      # Test if the tip was > 20% of the fare.
      tf.cast(
          tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64))
  return outputs
def _fill_in_missing(x):
  """Replace missing values in a SparseTensor.
  Fills in missing values of `x` with '' or 0, and converts to a dense tensor.
  Args:
    x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1
      in the second dimension.
  Returns:
    A rank 1 tensor where missing values of `x` have been filled in.
  """
  default_value = '' if x.dtype == tf.string else 0
  return tf.squeeze(
      tf.sparse.to_dense(
          tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]),
          default_value),
      axis=1)
Writing taxi_transform.py
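For intuition about what preprocessing_fn computes, here is a plain-Python analogue of three of its steps. First, _fill_in_missing replaces missing numbers with 0. Second, tft.scale_to_z_score subtracts the mean and divides by the standard deviation. Third, the big-tipper label is 1 when the tip is greater than 20% of the fare. This is an illustration only, and two behaviors are assumptions of this sketch rather than documented tft behavior: it uses the population standard deviation, and it only centers the values when the standard deviation is 0. In TFX, the mean and variance come from a full analysis pass over the training split, as the log below notes ("Analyze the 'train' split"). The missing fare in the usage example is hypothetical.

```python
import math
import statistics


def fill_in_missing(values, default=0.0):
    """Replace missing entries (None) with a default, like _fill_in_missing."""
    return [default if v is None else v for v in values]


def scale_to_z_score(values):
    """Center on the mean and divide by the population standard deviation."""
    mean = statistics.mean(values)
    std = statistics.pstdev(values)
    if std == 0:
        # This sketch's choice: with no spread, just center the values.
        return [v - mean for v in values]
    return [(v - mean) / std for v in values]


def big_tipper(fare, tips):
    """1 if the tip is strictly more than 20% of the fare, else 0 (0 for NaN fares)."""
    if math.isnan(fare):
        return 0
    return int(tips > fare * 0.2)


if __name__ == "__main__":
    fares = fill_in_missing([12.45, None, 27.05, 5.85])  # None is a hypothetical gap
    print(scale_to_z_score(fares))
    print([big_tipper(f, t) for f, t in [(10.0, 2.5), (10.0, 2.0)]])
```

Because the comparison is strict, a tip of exactly 20% of the fare is labeled 0, matching tf.greater in the module above.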
Now, we pass this feature engineering code to the Transform component and run it to transform the data.
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath(_taxi_transform_module_file))
context.run(transform)
INFO:absl:Running driver for Transform INFO:absl:MetadataStore with DB connection initialized INFO:absl:Running executor for Transform INFO:absl:Analyze the 'train' split and transform all splits when splits_config is not set. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tfx/components/transform/executor.py:528: Schema (from tensorflow_transform.tf_metadata.dataset_schema) is deprecated and will be removed in a future version. Instructions for updating: Schema is a deprecated, use schema_utils.schema_from_feature_spec to create a `Schema` INFO:absl:Feature company has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature payment_type has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature dropoff_census_tract has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature dropoff_community_area has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature dropoff_latitude has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature dropoff_longitude has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature fare has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature pickup_census_tract has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature pickup_community_area has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature pickup_latitude has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature pickup_longitude has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature tips has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_miles has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_seconds has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_start_day has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_start_hour has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_start_month has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_start_timestamp has no shape. Setting to VarLenSparseTensor. 
Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_transform/tf_utils.py:250: Tensor.experimental_ref (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version. Instructions for updating: Use ref() instead. Warning:tensorflow:TFT beam APIs accept both the TFXIO format and the instance dict format now. There is no need to set use_tfxio any more and it will be removed soon. Warning:root:This output type hint will be ignored and not used for type-checking purposes. Typically, output type hints for a PTransform are single (or nested) types wrapped by a PCollection, PDone, or None. Got: Tuple[Dict[str, Union[NoneType, _Dataset]], Union[Dict[str, Dict[str, PCollection]], NoneType]] instead. WARNING:root:This output type hint will be ignored and not used for type-checking purposes.
Typically, output type hints for a PTransform are single (or nested) types wrapped by a PCollection, PDone, or None. Got: Tuple[Dict[str, Union[NoneType, _Dataset]], Union[Dict[str, Dict[str, PCollection]], NoneType]] instead. Warning:tensorflow:Tensorflow version (2.3.1) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:201: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. WARNING:tensorflow:Issue encountered when serializing tft_mapper_use. Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore. 'Counter' object has no attribute 'name' INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Transform/transform_graph/5/.temp_path/tftransform_tmp/2b84bdf4729f420da7592003458afbea/saved_model.pb INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. WARNING:tensorflow:Issue encountered when serializing tft_mapper_use. Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore. 'Counter' object has no attribute 'name' INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Transform/transform_graph/5/.temp_path/tftransform_tmp/41fa9edf8d3540c4bb4163ed11673ec1/saved_model.pb INFO:absl:Feature company has no shape. Setting to VarLenSparseTensor. 
INFO:absl:Feature payment_type has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature dropoff_census_tract has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature dropoff_community_area has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature dropoff_latitude has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature dropoff_longitude has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature fare has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature pickup_census_tract has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature pickup_community_area has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature pickup_latitude has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature pickup_longitude has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature tips has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_miles has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_seconds has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_start_day has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_start_hour has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_start_month has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_start_timestamp has no shape. Setting to VarLenSparseTensor. Warning:tensorflow:Tensorflow version (2.3.1) found. Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended. 
Warning:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'> WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'> WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'> WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'> WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'> WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'> INFO:absl:Feature company has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature payment_type has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature dropoff_census_tract has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature dropoff_community_area has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature dropoff_latitude has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature dropoff_longitude has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature fare has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature pickup_census_tract has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature pickup_community_area has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature pickup_latitude has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature pickup_longitude has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature tips has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_miles has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_seconds has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_start_day has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_start_hour has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_start_month has no shape. Setting to VarLenSparseTensor. INFO:absl:Feature trip_start_timestamp has no shape. Setting to VarLenSparseTensor. Warning:tensorflow:Tensorflow version (2.3.1) found. 
Note that Tensorflow Transform support for TF 2.0 is currently in beta, and features such as tf.function may not work as intended. Warning:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'> WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'> WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'> WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'> WARNING:apache_beam.typehints.typehints:Ignoring send_type hint: <class 'NoneType'> WARNING:apache_beam.typehints.typehints:Ignoring return_type hint: <class 'NoneType'> INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:Assets added to graph. INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Transform/transform_graph/5/.temp_path/tftransform_tmp/af6fc404ec6441879c1fa72d5be16c8b/assets INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Transform/transform_graph/5/.temp_path/tftransform_tmp/af6fc404ec6441879c1fa72d5be16c8b/saved_model.pb WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef" value: "\n\013\n\tConst_2:0\022-vocab_compute_and_apply_vocabulary_vocabulary" Warning:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef" value: "\n\013\n\tConst_4:0\022/vocab_compute_and_apply_vocabulary_1_vocabulary" INFO:tensorflow:Saver not created because there are no variables in the graph to restore WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef" value: "\n\013\n\tConst_2:0\022-vocab_compute_and_apply_vocabulary_vocabulary" Warning:tensorflow:Expected binary or unicode string, got type_url: 
"type.googleapis.com/tensorflow.AssetFileDef" value: "\n\013\n\tConst_4:0\022/vocab_compute_and_apply_vocabulary_1_vocabulary" INFO:tensorflow:Saver not created because there are no variables in the graph to restore WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef" value: "\n\013\n\tConst_2:0\022-vocab_compute_and_apply_vocabulary_vocabulary" Warning:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef" value: "\n\013\n\tConst_4:0\022/vocab_compute_and_apply_vocabulary_1_vocabulary" INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:absl:Running publisher for Transform INFO:absl:MetadataStore with DB connection initialized
Let's examine the output artifacts of Transform. This component produces two main types of outputs:
transform_graph is the graph that can perform the preprocessing operations (this graph will be included in the serving and evaluation models).
transformed_examples represents the preprocessed training and evaluation data.
The listing below also shows a third channel, updated_analyzer_cache, which holds cached analyzer results that later runs can reuse.
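Why ship the transform graph with the model? Transform works in two phases: a full pass over the data computes constants (for example, a mean and standard deviation, or a vocabulary), and a per-example function then applies those frozen constants. Because the per-example function needs no access to the dataset, the same logic can run again at serving time, which avoids training/serving skew. The sketch below is a pure-Python analogy of that pattern, not the tf.Transform API. The helpers analyze and transform, the choice of z-scoring plus a vocabulary, and the toy rows are all hypothetical, chosen to mirror the kind of float fare_xf and integer company_xf features printed further down.

```python
# Pure-Python analogy of the analyze-then-transform pattern (hypothetical
# helpers, NOT the tf.Transform API).
import statistics
from collections import Counter


def analyze(rows, numeric_key, vocab_key):
    """Full pass over the data: compute z-score stats and a vocabulary."""
    values = [row[numeric_key] for row in rows]
    mean = statistics.mean(values)
    std = statistics.pstdev(values)  # population standard deviation
    counts = Counter(row[vocab_key] for row in rows)
    # Most frequent token gets index 0; ties are broken alphabetically
    # (an assumption made for this sketch).
    ordered = sorted(counts, key=lambda token: (-counts[token], token))
    vocab = {token: index for index, token in enumerate(ordered)}
    return {"mean": mean, "std": std, "vocab": vocab}


def transform(row, constants, numeric_key, vocab_key):
    """Per-example step: applies the frozen constants, no dataset access."""
    std = constants["std"] or 1.0  # guard against zero variance
    vocab = constants["vocab"]
    return {
        numeric_key + "_xf": (row[numeric_key] - constants["mean"]) / std,
        # Unseen tokens share one out-of-vocabulary index after the vocabulary.
        vocab_key + "_xf": vocab.get(row[vocab_key], len(vocab)),
    }


# Toy training rows, for illustration only.
train_rows = [
    {"fare": 5.0, "company": "A"},
    {"fare": 15.0, "company": "B"},
    {"fare": 10.0, "company": "A"},
]
constants = analyze(train_rows, "fare", "company")
# The same transform() would be reused unchanged for a serving request.
print(transform({"fare": 10.0, "company": "Z"}, constants, "fare", "company"))
# {'fare_xf': 0.0, 'company_xf': 2}
```

The key design point is that transform only reads the constants dictionary, so exporting those constants together with the function (which is what embedding the transform graph in the serving model achieves in TFX) is enough to reproduce training-time preprocessing.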
transform.outputs
{'transform_graph': Channel( type_name: TransformGraph artifacts: [Artifact(artifact: id: 5 type_id: 13 uri: "/tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Transform/transform_graph/5" custom_properties { key: "name" value { string_value: "transform_graph" } } custom_properties { key: "producer_component" value { string_value: "Transform" } } custom_properties { key: "state" value { string_value: "published" } } state: LIVE , artifact_type: id: 13 name: "TransformGraph" )] ), 'transformed_examples': Channel( type_name: Examples artifacts: [Artifact(artifact: id: 6 type_id: 5 uri: "/tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Transform/transformed_examples/5" properties { key: "split_names" value { string_value: "[\"train\", \"eval\"]" } } custom_properties { key: "name" value { string_value: "transformed_examples" } } custom_properties { key: "producer_component" value { string_value: "Transform" } } custom_properties { key: "state" value { string_value: "published" } } state: LIVE , artifact_type: id: 5 name: "Examples" properties { key: "span" value: INT } properties { key: "split_names" value: STRING } properties { key: "version" value: INT } )] ), 'updated_analyzer_cache': Channel( type_name: TransformCache artifacts: [Artifact(artifact: id: 7 type_id: 14 uri: "/tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Transform/updated_analyzer_cache/5" custom_properties { key: "name" value { string_value: "updated_analyzer_cache" } } custom_properties { key: "producer_component" value { string_value: "Transform" } } custom_properties { key: "state" value { string_value: "published" } } state: LIVE , artifact_type: id: 14 name: "TransformCache" )] )}
Take a peek at the transform_graph artifact. It points to a directory containing three subdirectories.
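# Get the URI of the transform_graph artifact (a directory) and list its contents
transform_graph_uri = transform.outputs['transform_graph'].get()[0].uri
os.listdir(transform_graph_uri)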
['transform_fn', 'transformed_metadata', 'metadata']
The transformed_metadata subdirectory contains the schema of the preprocessed data. The transform_fn subdirectory contains the actual preprocessing graph. The metadata subdirectory contains the schema of the original data.
We can also take a look at the first three transformed examples:
# Get the URI of the output artifact representing the transformed examples, which is a directory
train_uri = os.path.join(transform.outputs['transformed_examples'].get()[0].uri, 'train')
# Get the list of files in this directory (all compressed TFRecord files)
tfrecord_filenames = [os.path.join(train_uri, name)
                      for name in os.listdir(train_uri)]
# Create a `TFRecordDataset` to read these files
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
# Iterate over the first 3 records and decode them.
for tfrecord in dataset.take(3):
  serialized_example = tfrecord.numpy()
  example = tf.train.Example()
  example.ParseFromString(serialized_example)
  pp.pprint(example)
features { feature { key: "company_xf" value { int64_list { value: 8 } } } feature { key: "dropoff_census_tract_xf" value { int64_list { value: 0 } } } feature { key: "dropoff_community_area_xf" value { int64_list { value: 0 } } } feature { key: "dropoff_latitude_xf" value { int64_list { value: 0 } } } feature { key: "dropoff_longitude_xf" value { int64_list { value: 9 } } } feature { key: "fare_xf" value { float_list { value: 0.06106060370802879 } } } feature { key: "payment_type_xf" value { int64_list { value: 1 } } } feature { key: "pickup_census_tract_xf" value { int64_list { value: 0 } } } feature { key: "pickup_community_area_xf" value { int64_list { value: 0 } } } feature { key: "pickup_latitude_xf" value { int64_list { value: 0 } } } feature { key: "pickup_longitude_xf" value { int64_list { value: 9 } } } feature { key: "tips_xf" value { int64_list { value: 0 } } } feature { key: "trip_miles_xf" value { float_list { value: -0.15886740386486053 } } } feature { key: "trip_seconds_xf" value { float_list { value: -0.7118487358093262 } } } feature { key: "trip_start_day_xf" value { int64_list { value: 6 } } } feature { key: "trip_start_hour_xf" value { int64_list { value: 19 } } } feature { key: "trip_start_month_xf" value { int64_list { value: 5 } } } } features { feature { key: "company_xf" value { int64_list { value: 0 } } } feature { key: "dropoff_census_tract_xf" value { int64_list { value: 0 } } } feature { key: "dropoff_community_area_xf" value { int64_list { value: 0 } } } feature { key: "dropoff_latitude_xf" value { int64_list { value: 0 } } } feature { key: "dropoff_longitude_xf" value { int64_list { value: 9 } } } feature { key: "fare_xf" value { float_list { value: 1.2521241903305054 } } } feature { key: "payment_type_xf" value { int64_list { value: 0 } } } feature { key: "pickup_census_tract_xf" value { int64_list { value: 0 } } } feature { key: "pickup_community_area_xf" value { int64_list { value: 60 } } } feature { key: "pickup_latitude_xf" value 
{ int64_list { value: 0 } } } feature { key: "pickup_longitude_xf" value { int64_list { value: 3 } } } feature { key: "tips_xf" value { int64_list { value: 0 } } } feature { key: "trip_miles_xf" value { float_list { value: 0.532160758972168 } } } feature { key: "trip_seconds_xf" value { float_list { value: 0.5509493350982666 } } } feature { key: "trip_start_day_xf" value { int64_list { value: 3 } } } feature { key: "trip_start_hour_xf" value { int64_list { value: 2 } } } feature { key: "trip_start_month_xf" value { int64_list { value: 10 } } } } features { feature { key: "company_xf" value { int64_list { value: 48 } } } feature { key: "dropoff_census_tract_xf" value { int64_list { value: 0 } } } feature { key: "dropoff_community_area_xf" value { int64_list { value: 0 } } } feature { key: "dropoff_latitude_xf" value { int64_list { value: 0 } } } feature { key: "dropoff_longitude_xf" value { int64_list { value: 9 } } } feature { key: "fare_xf" value { float_list { value: 0.3873794972896576 } } } feature { key: "payment_type_xf" value { int64_list { value: 0 } } } feature { key: "pickup_census_tract_xf" value { int64_list { value: 0 } } } feature { key: "pickup_community_area_xf" value { int64_list { value: 13 } } } feature { key: "pickup_latitude_xf" value { int64_list { value: 9 } } } feature { key: "pickup_longitude_xf" value { int64_list { value: 0 } } } feature { key: "tips_xf" value { int64_list { value: 0 } } } feature { key: "trip_miles_xf" value { float_list { value: 0.21955278515815735 } } } feature { key: "trip_seconds_xf" value { float_list { value: 0.0019067145185545087 } } } feature { key: "trip_start_day_xf" value { int64_list { value: 3 } } } feature { key: "trip_start_hour_xf" value { int64_list { value: 12 } } } feature { key: "trip_start_month_xf" value { int64_list { value: 11 } } } }
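The files read above are GZIP-compressed TFRecord files. TFRecord is a simple framing format: each record is a little-endian uint64 length, a masked CRC-32C of those 8 length bytes, the payload, and a masked CRC-32C of the payload. As an illustration of what tf.data.TFRecordDataset handles for you, the stdlib-only sketch below writes and reads that framing. The function names are hypothetical, and the sketch stops at raw bytes: turning a payload into a tf.train.Example still requires protobuf parsing, as in the cell above. For real pipelines, prefer TFRecordDataset.

```python
# Stdlib-only sketch of TFRecord framing (illustrative; hypothetical names).
import gzip
import io
import struct


def _make_crc32c_table():
    """Lookup table for CRC-32C (Castagnoli, reflected polynomial 0x82F63B78)."""
    table = []
    for byte in range(256):
        crc = byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data):
    crc = 0xFFFFFFFF
    for b in data:
        crc = _CRC32C_TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_records(records, fileobj):
    for data in records:
        header = struct.pack("<Q", len(data))
        fileobj.write(header)
        fileobj.write(struct.pack("<I", masked_crc(header)))
        fileobj.write(data)
        fileobj.write(struct.pack("<I", masked_crc(data)))


def _read_exact(fileobj, size):
    data = fileobj.read(size)
    if len(data) != size:
        raise ValueError("truncated TFRecord file")
    return data


def read_records(fileobj):
    """Yield each record payload, verifying both checksums."""
    while True:
        header = fileobj.read(8)
        if not header:
            return  # clean end of file
        if len(header) != 8:
            raise ValueError("truncated TFRecord file")
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", _read_exact(fileobj, 4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupt length field")
        data = _read_exact(fileobj, length)
        (data_crc,) = struct.unpack("<I", _read_exact(fileobj, 4))
        if data_crc != masked_crc(data):
            raise ValueError("corrupt record payload")
        yield data


# Round trip through GZIP, matching compression_type="GZIP" above.
buffer = io.BytesIO()
with gzip.GzipFile(fileobj=buffer, mode="wb") as gz:
    write_records([b"first", b"second"], gz)
buffer.seek(0)
with gzip.GzipFile(fileobj=buffer, mode="rb") as gz:
    print(list(read_records(gz)))  # [b'first', b'second']
```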
After the Transform component has transformed your data into features, the next step is to train a model.
Trainer
The Trainer component will train a model that you define in TensorFlow (either using the Estimator API or the Keras API with model_to_estimator).
Trainer takes as input the schema from SchemaGen, the transformed data and graph from Transform, training parameters, as well as a module that contains user-defined model code.
Let's see an example of user-defined model code below (for an introduction to the TensorFlow Estimator APIs, see the tutorial):
_taxi_trainer_module_file = 'taxi_trainer.py'
%%writefile {_taxi_trainer_module_file}
import tensorflow as tf
import tensorflow_model_analysis as tfma
import tensorflow_transform as tft
from tensorflow_transform.tf_metadata import schema_utils
from tfx_bsl.tfxio import dataset_options
import taxi_constants
_DENSE_FLOAT_FEATURE_KEYS = taxi_constants.DENSE_FLOAT_FEATURE_KEYS
_VOCAB_FEATURE_KEYS = taxi_constants.VOCAB_FEATURE_KEYS
_VOCAB_SIZE = taxi_constants.VOCAB_SIZE
_OOV_SIZE = taxi_constants.OOV_SIZE
_FEATURE_BUCKET_COUNT = taxi_constants.FEATURE_BUCKET_COUNT
_BUCKET_FEATURE_KEYS = taxi_constants.BUCKET_FEATURE_KEYS
_CATEGORICAL_FEATURE_KEYS = taxi_constants.CATEGORICAL_FEATURE_KEYS
_MAX_CATEGORICAL_FEATURE_VALUES = taxi_constants.MAX_CATEGORICAL_FEATURE_VALUES
_LABEL_KEY = taxi_constants.LABEL_KEY
_transformed_name = taxi_constants.transformed_name
def _transformed_names(keys):
  return [_transformed_name(key) for key in keys]
# tf.Transform considers these features as "raw"
def _get_raw_feature_spec(schema):
  return schema_utils.schema_as_feature_spec(schema).feature_spec
def _build_estimator(config, hidden_units=None, warm_start_from=None):
  """Build an estimator for predicting the tipping behavior of taxi riders.
  Args:
    config: tf.estimator.RunConfig defining the runtime environment for the
      estimator (including model_dir).
    hidden_units: [int], the layer sizes of the DNN (input layer first)
    warm_start_from: Optional directory to warm start from.
  Returns:
    A tf.estimator.DNNLinearCombinedClassifier that feeds the categorical
    columns to the linear part and the numeric columns to the DNN part.
  """
  real_valued_columns = [
      tf.feature_column.numeric_column(key, shape=())
      for key in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS)
  ]
  categorical_columns = [
      tf.feature_column.categorical_column_with_identity(
          key, num_buckets=_VOCAB_SIZE + _OOV_SIZE, default_value=0)
      for key in _transformed_names(_VOCAB_FEATURE_KEYS)
  ]
  categorical_columns += [
      tf.feature_column.categorical_column_with_identity(
          key, num_buckets=_FEATURE_BUCKET_COUNT, default_value=0)
      for key in _transformed_names(_BUCKET_FEATURE_KEYS)
  ]
  categorical_columns += [
      tf.feature_column.categorical_column_with_identity(  # pylint: disable=g-complex-comprehension
          key,
          num_buckets=num_buckets,
          default_value=0) for key, num_buckets in zip(
              _transformed_names(_CATEGORICAL_FEATURE_KEYS),
              _MAX_CATEGORICAL_FEATURE_VALUES)
  ]
  return tf.estimator.DNNLinearCombinedClassifier(
      config=config,
      linear_feature_columns=categorical_columns,
      dnn_feature_columns=real_valued_columns,
      dnn_hidden_units=hidden_units or [100, 70, 50, 25],
      warm_start_from=warm_start_from)
def _example_serving_receiver_fn(tf_transform_graph, schema):
  """Build the serving inputs.
  Args:
    tf_transform_graph: A TFTransformOutput.
    schema: the schema of the input data.
  Returns:
    Tensorflow graph which parses examples, applying tf-transform to them.
  """
  raw_feature_spec = _get_raw_feature_spec(schema)
  # The label is not available at serving time, so drop it from the spec.
  raw_feature_spec.pop(_LABEL_KEY)
  raw_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      raw_feature_spec, default_batch_size=None)
  serving_input_receiver = raw_input_fn()
  transformed_features = tf_transform_graph.transform_raw_features(
      serving_input_receiver.features)
  return tf.estimator.export.ServingInputReceiver(
      transformed_features, serving_input_receiver.receiver_tensors)
def _eval_input_receiver_fn(tf_transform_graph, schema):
  """Build everything needed for TensorFlow Model Analysis (TFMA) to run the model.
  Args:
    tf_transform_graph: A TFTransformOutput.
    schema: the schema of the input data.
  Returns:
    EvalInputReceiver function, which contains:
      - Tensorflow graph which parses raw untransformed features, applies the
        tf-transform preprocessing operators.
      - Set of raw, untransformed features.
      - Label against which predictions will be compared.
  """
  # Notice that the inputs are raw features, not transformed features here.
  raw_feature_spec = _get_raw_feature_spec(schema)
  serialized_tf_example = tf.compat.v1.placeholder(
      dtype=tf.string, shape=[None], name='input_example_tensor')
  # Add a parse_example operator to the tensorflow graph, which will parse
  # raw, untransformed, tf examples.
  features = tf.io.parse_example(serialized_tf_example, raw_feature_spec)
  # Now that we have our raw examples, process them through the tf-transform
  # function computed during the preprocessing step.
  transformed_features = tf_transform_graph.transform_raw_features(
      features)
  # The key name MUST be 'examples'.
  receiver_tensors = {'examples': serialized_tf_example}
  # NOTE: Model is driven by transformed features (since training works on the
  # materialized output of TFT), but slicing will happen on raw features.
  features.update(transformed_features)
  return tfma.export.EvalInputReceiver(
      features=features,
      receiver_tensors=receiver_tensors,
      labels=transformed_features[_transformed_name(_LABEL_KEY)])
def _input_fn(file_pattern, data_accessor, tf_transform_output, batch_size=200):
  """Generates features and label for tuning/training.
  Args:
    file_pattern: List of paths or patterns of input tfrecord files.
    data_accessor: DataAccessor for converting input to RecordBatch.
    tf_transform_output: A TFTransformOutput.
    batch_size: representing the number of consecutive elements of returned
      dataset to combine in a single batch
  Returns:
    A dataset that contains (features, indices) tuple where features is a
    dictionary of Tensors, and indices is a single Tensor of label indices.
  """
  return data_accessor.tf_dataset_factory(
      file_pattern,
      dataset_options.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)),
      tf_transform_output.transformed_metadata.schema)
# TFX will call this function
def trainer_fn(trainer_fn_args, schema):
  """Build the estimator using the high level API.
  Args:
    trainer_fn_args: Holds args used to train the model as name/value pairs.
    schema: Holds the schema of the training examples.
  Returns:
    A dict of the following:
      - estimator: The estimator that will be used for training and eval.
      - train_spec: Spec for training.
      - eval_spec: Spec for eval.
      - eval_input_receiver_fn: Input function for eval.
  """
  # Number of nodes in the first layer of the DNN
  first_dnn_layer_size = 100
  num_dnn_layers = 4
  dnn_decay_factor = 0.7
  train_batch_size = 40
  eval_batch_size = 40
  tf_transform_graph = tft.TFTransformOutput(trainer_fn_args.transform_output)
  train_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.train_files,
      trainer_fn_args.data_accessor,
      tf_transform_graph,
      batch_size=train_batch_size)
  eval_input_fn = lambda: _input_fn(  # pylint: disable=g-long-lambda
      trainer_fn_args.eval_files,
      trainer_fn_args.data_accessor,
      tf_transform_graph,
      batch_size=eval_batch_size)
  train_spec = tf.estimator.TrainSpec(
      train_input_fn,
      max_steps=trainer_fn_args.train_steps)
  serving_receiver_fn = lambda: _example_serving_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_graph, schema)
  exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn)
  eval_spec = tf.estimator.EvalSpec(
      eval_input_fn,
      steps=trainer_fn_args.eval_steps,
      exporters=[exporter],
      name='chicago-taxi-eval')
  run_config = tf.estimator.RunConfig(
      save_checkpoints_steps=999, keep_checkpoint_max=1)
  run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir)
  estimator = _build_estimator(
      # Construct layer sizes with exponential decay
      hidden_units=[
          max(2, int(first_dnn_layer_size * dnn_decay_factor**i))
          for i in range(num_dnn_layers)
      ],
      config=run_config,
      warm_start_from=trainer_fn_args.base_model)
  # Create an input receiver for TFMA processing
  receiver_fn = lambda: _eval_input_receiver_fn(  # pylint: disable=g-long-lambda
      tf_transform_graph, schema)
  return {
      'estimator': estimator,
      'train_spec': train_spec,
      'eval_spec': eval_spec,
      'eval_input_receiver_fn': receiver_fn
  }
Writing taxi_trainer.py
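trainer_fn sizes the DNN layers with exponential decay: layer i gets int(first_dnn_layer_size * dnn_decay_factor**i) units, floored at 2. The arithmetic can be checked in isolation with the small helper below. The name decayed_layer_sizes is hypothetical and not part of TFX; it simply repeats the list comprehension from trainer_fn.

```python
def decayed_layer_sizes(first_layer_size, num_layers, decay_factor, min_size=2):
    """Layer widths that shrink geometrically, floored at min_size.

    Mirrors the hidden_units comprehension in trainer_fn (hypothetical helper).
    """
    return [
        max(min_size, int(first_layer_size * decay_factor**i))
        for i in range(num_layers)
    ]


# Same values as trainer_fn: first_dnn_layer_size=100, num_dnn_layers=4,
# dnn_decay_factor=0.7.
sizes = decayed_layer_sizes(100, 4, 0.7)
print(sizes)
```

Two details are worth noticing. int() truncates rather than rounds, so floating-point error in the product can make a width one smaller than the exact value would suggest. And because trainer_fn always passes hidden_units, the [100, 70, 50, 25] default inside _build_estimator is never used in this pipeline.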
Now, we pass in this model code to the Trainer component and run it to train the model.
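The training log for the next cell records checkpoints being saved at steps 999, 1998, 2997, and so on. That schedule follows from save_checkpoints_steps=999 in the module's RunConfig combined with train_args=trainer_pb2.TrainArgs(num_steps=10000) in the Trainer call. A minimal sketch of that arithmetic, using a hypothetical helper named periodic_checkpoint_steps, is:

```python
def periodic_checkpoint_steps(max_steps, save_every):
    """Global steps at which a periodic checkpoint falls (hypothetical helper).

    Models only the every-`save_every`-steps saves, not an initial save at
    step 0 or any final save when training stops.
    """
    if save_every <= 0:
        raise ValueError("save_every must be positive")
    return list(range(save_every, max_steps + 1, save_every))


steps = periodic_checkpoint_steps(10000, 999)
print(steps[:3], len(steps))  # [999, 1998, 2997] 10
```

The sketch leaves out the initial save at step 0 that the log also shows, and an Estimator may write one more checkpoint when training stops. Since keep_checkpoint_max=1, only the most recent checkpoint is kept on disk, and evaluation is additionally throttled by time (the log reports skipping checkpoint evaluations because of a 600-second throttle), so not every checkpoint is evaluated.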
trainer = Trainer(
    module_file=os.path.abspath(_taxi_trainer_module_file),
    transformed_examples=transform.outputs['transformed_examples'],
    schema=schema_gen.outputs['schema'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=trainer_pb2.TrainArgs(num_steps=10000),
    eval_args=trainer_pb2.EvalArgs(num_steps=5000))
context.run(trainer)
INFO:absl:Running driver for Trainer INFO:absl:MetadataStore with DB connection initialized INFO:absl:Running executor for Trainer INFO:absl:Train on the 'train' split when train_args.splits is not set. INFO:absl:Evaluate on the 'eval' split when eval_args.splits is not set. WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE WARNING:absl:Examples artifact does not have payload_format custom property. Falling back to FORMAT_TF_EXAMPLE INFO:tensorflow:Using config: {'_model_dir': '/tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 999, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:absl:Training model. INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Running training and evaluation locally (non-distributed). INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 999 or save_checkpoints_secs None. 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:absl:Feature company_xf has a shape . Setting to DenseTensor. INFO:absl:Feature dropoff_census_tract_xf has a shape . Setting to DenseTensor. INFO:absl:Feature dropoff_community_area_xf has a shape . Setting to DenseTensor. INFO:absl:Feature dropoff_latitude_xf has a shape . Setting to DenseTensor. INFO:absl:Feature dropoff_longitude_xf has a shape . Setting to DenseTensor. INFO:absl:Feature fare_xf has a shape . Setting to DenseTensor. INFO:absl:Feature payment_type_xf has a shape . Setting to DenseTensor. INFO:absl:Feature pickup_census_tract_xf has a shape . Setting to DenseTensor. INFO:absl:Feature pickup_community_area_xf has a shape . Setting to DenseTensor. INFO:absl:Feature pickup_latitude_xf has a shape . Setting to DenseTensor. INFO:absl:Feature pickup_longitude_xf has a shape . Setting to DenseTensor. INFO:absl:Feature tips_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_miles_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_seconds_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_start_day_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_start_hour_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_start_month_xf has a shape . Setting to DenseTensor. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:1481: Layer.add_variable (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version. 
Instructions for updating: Please use `layer.add_weight` method instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/adagrad.py:83: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6974667, step = 0 INFO:tensorflow:global_step/sec: 89.1231 INFO:tensorflow:loss = 0.5732481, step = 100 (1.123 sec) INFO:tensorflow:global_step/sec: 113.474 INFO:tensorflow:loss = 0.5176774, step = 200 (0.881 sec) INFO:tensorflow:global_step/sec: 114.63 INFO:tensorflow:loss = 0.5562607, step = 300 (0.872 sec) INFO:tensorflow:global_step/sec: 113.5 INFO:tensorflow:loss = 0.53873014, step = 400 (0.881 sec) INFO:tensorflow:global_step/sec: 111.799 INFO:tensorflow:loss = 0.44045067, step = 500 (0.894 sec) INFO:tensorflow:global_step/sec: 113.062 INFO:tensorflow:loss = 0.57715523, step = 600 (0.885 sec) INFO:tensorflow:global_step/sec: 112.051 INFO:tensorflow:loss = 0.46877265, step = 700 (0.892 sec) INFO:tensorflow:global_step/sec: 112.388 INFO:tensorflow:loss = 0.51078546, step = 800 (0.890 sec) INFO:tensorflow:global_step/sec: 111.415 INFO:tensorflow:loss = 0.4560148, step = 900 (0.898 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 999... 
INFO:tensorflow:Saving checkpoints for 999 into /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/saver.py:971: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version. Instructions for updating: Use standard file APIs to delete files with this prefix. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 999... INFO:absl:Feature company_xf has a shape . Setting to DenseTensor. INFO:absl:Feature dropoff_census_tract_xf has a shape . Setting to DenseTensor. INFO:absl:Feature dropoff_community_area_xf has a shape . Setting to DenseTensor. INFO:absl:Feature dropoff_latitude_xf has a shape . Setting to DenseTensor. INFO:absl:Feature dropoff_longitude_xf has a shape . Setting to DenseTensor. INFO:absl:Feature fare_xf has a shape . Setting to DenseTensor. INFO:absl:Feature payment_type_xf has a shape . Setting to DenseTensor. INFO:absl:Feature pickup_census_tract_xf has a shape . Setting to DenseTensor. INFO:absl:Feature pickup_community_area_xf has a shape . Setting to DenseTensor. INFO:absl:Feature pickup_latitude_xf has a shape . Setting to DenseTensor. INFO:absl:Feature pickup_longitude_xf has a shape . Setting to DenseTensor. INFO:absl:Feature tips_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_miles_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_seconds_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_start_day_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_start_hour_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_start_month_xf has a shape . Setting to DenseTensor. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2020-11-26T10:14:00Z INFO:tensorflow:Graph was finalized. 
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt-999 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [500/5000] INFO:tensorflow:Evaluation [1000/5000] INFO:tensorflow:Evaluation [1500/5000] INFO:tensorflow:Evaluation [2000/5000] INFO:tensorflow:Evaluation [2500/5000] INFO:tensorflow:Evaluation [3000/5000] INFO:tensorflow:Evaluation [3500/5000] INFO:tensorflow:Evaluation [4000/5000] INFO:tensorflow:Evaluation [4500/5000] INFO:tensorflow:Evaluation [5000/5000] INFO:tensorflow:Inference Time : 46.48948s INFO:tensorflow:Finished evaluation at 2020-11-26-10:14:46 INFO:tensorflow:Saving dict for global step 999: accuracy = 0.77121, accuracy_baseline = 0.77121, auc = 0.9247545, auc_precision_recall = 0.66692716, average_loss = 0.46251872, global_step = 999, label/mean = 0.22879, loss = 0.46251866, precision = 0.0, prediction/mean = 0.2535291, recall = 0.0 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 999: /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt-999 INFO:tensorflow:global_step/sec: 2.043 INFO:tensorflow:loss = 0.3749024, step = 1000 (48.948 sec) INFO:tensorflow:global_step/sec: 109.709 INFO:tensorflow:loss = 0.4592988, step = 1100 (0.912 sec) INFO:tensorflow:global_step/sec: 112.449 INFO:tensorflow:loss = 0.37520203, step = 1200 (0.890 sec) INFO:tensorflow:global_step/sec: 109.812 INFO:tensorflow:loss = 0.38372535, step = 1300 (0.910 sec) INFO:tensorflow:global_step/sec: 113.247 INFO:tensorflow:loss = 0.4365883, step = 1400 (0.883 sec) INFO:tensorflow:global_step/sec: 114.602 INFO:tensorflow:loss = 0.42457563, step = 1500 (0.873 sec) INFO:tensorflow:global_step/sec: 115.115 INFO:tensorflow:loss = 0.38821325, step = 1600 (0.869 sec) INFO:tensorflow:global_step/sec: 114.471 INFO:tensorflow:loss = 0.4025802, step = 1700 (0.874 sec) 
INFO:tensorflow:global_step/sec: 114.604 INFO:tensorflow:loss = 0.50898427, step = 1800 (0.873 sec) INFO:tensorflow:global_step/sec: 115.29 INFO:tensorflow:loss = 0.39604157, step = 1900 (0.868 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1998... INFO:tensorflow:Saving checkpoints for 1998 into /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1998... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). INFO:tensorflow:global_step/sec: 100.896 INFO:tensorflow:loss = 0.4355023, step = 2000 (0.991 sec) INFO:tensorflow:global_step/sec: 114.377 INFO:tensorflow:loss = 0.4315641, step = 2100 (0.874 sec) INFO:tensorflow:global_step/sec: 113.536 INFO:tensorflow:loss = 0.34944555, step = 2200 (0.881 sec) INFO:tensorflow:global_step/sec: 113.708 INFO:tensorflow:loss = 0.43738118, step = 2300 (0.879 sec) INFO:tensorflow:global_step/sec: 115.111 INFO:tensorflow:loss = 0.47878623, step = 2400 (0.869 sec) INFO:tensorflow:global_step/sec: 112.33 INFO:tensorflow:loss = 0.383889, step = 2500 (0.891 sec) INFO:tensorflow:global_step/sec: 113.464 INFO:tensorflow:loss = 0.35396174, step = 2600 (0.881 sec) INFO:tensorflow:global_step/sec: 112.511 INFO:tensorflow:loss = 0.3468891, step = 2700 (0.889 sec) INFO:tensorflow:global_step/sec: 113.298 INFO:tensorflow:loss = 0.38933015, step = 2800 (0.883 sec) INFO:tensorflow:global_step/sec: 114.633 INFO:tensorflow:loss = 0.34905306, step = 2900 (0.872 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2997... INFO:tensorflow:Saving checkpoints for 2997 into /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2997... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). 
INFO:tensorflow:global_step/sec: 100.777 INFO:tensorflow:loss = 0.42350882, step = 3000 (0.992 sec) INFO:tensorflow:global_step/sec: 113.785 INFO:tensorflow:loss = 0.31562096, step = 3100 (0.879 sec) INFO:tensorflow:global_step/sec: 113.34 INFO:tensorflow:loss = 0.42950398, step = 3200 (0.882 sec) INFO:tensorflow:global_step/sec: 114.675 INFO:tensorflow:loss = 0.47113928, step = 3300 (0.872 sec) INFO:tensorflow:global_step/sec: 114.508 INFO:tensorflow:loss = 0.41468287, step = 3400 (0.873 sec) INFO:tensorflow:global_step/sec: 115.75 INFO:tensorflow:loss = 0.31615996, step = 3500 (0.864 sec) INFO:tensorflow:global_step/sec: 114.195 INFO:tensorflow:loss = 0.4714353, step = 3600 (0.876 sec) INFO:tensorflow:global_step/sec: 114.797 INFO:tensorflow:loss = 0.30707058, step = 3700 (0.871 sec) INFO:tensorflow:global_step/sec: 114.767 INFO:tensorflow:loss = 0.3783667, step = 3800 (0.871 sec) INFO:tensorflow:global_step/sec: 114.016 INFO:tensorflow:loss = 0.34137535, step = 3900 (0.877 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3996... INFO:tensorflow:Saving checkpoints for 3996 into /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3996... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). 
INFO:tensorflow:global_step/sec: 100.197 INFO:tensorflow:loss = 0.45167357, step = 4000 (0.998 sec) INFO:tensorflow:global_step/sec: 114.295 INFO:tensorflow:loss = 0.36682805, step = 4100 (0.875 sec) INFO:tensorflow:global_step/sec: 112.865 INFO:tensorflow:loss = 0.37716842, step = 4200 (0.886 sec) INFO:tensorflow:global_step/sec: 114.446 INFO:tensorflow:loss = 0.34928334, step = 4300 (0.874 sec) INFO:tensorflow:global_step/sec: 113.651 INFO:tensorflow:loss = 0.46821588, step = 4400 (0.880 sec) INFO:tensorflow:global_step/sec: 111.894 INFO:tensorflow:loss = 0.42282277, step = 4500 (0.894 sec) INFO:tensorflow:global_step/sec: 113.436 INFO:tensorflow:loss = 0.34230947, step = 4600 (0.883 sec) INFO:tensorflow:global_step/sec: 112.371 INFO:tensorflow:loss = 0.40548354, step = 4700 (0.889 sec) INFO:tensorflow:global_step/sec: 113.867 INFO:tensorflow:loss = 0.3861308, step = 4800 (0.878 sec) INFO:tensorflow:global_step/sec: 112.406 INFO:tensorflow:loss = 0.3875846, step = 4900 (0.890 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 4995... INFO:tensorflow:Saving checkpoints for 4995 into /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 4995... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). 
INFO:tensorflow:global_step/sec: 99.7081 INFO:tensorflow:loss = 0.44678396, step = 5000 (1.003 sec) INFO:tensorflow:global_step/sec: 112.912 INFO:tensorflow:loss = 0.39096385, step = 5100 (0.886 sec) INFO:tensorflow:global_step/sec: 113.011 INFO:tensorflow:loss = 0.3414794, step = 5200 (0.885 sec) INFO:tensorflow:global_step/sec: 114.988 INFO:tensorflow:loss = 0.3385579, step = 5300 (0.869 sec) INFO:tensorflow:global_step/sec: 113.642 INFO:tensorflow:loss = 0.32704902, step = 5400 (0.880 sec) INFO:tensorflow:global_step/sec: 113.67 INFO:tensorflow:loss = 0.39469784, step = 5500 (0.880 sec) INFO:tensorflow:global_step/sec: 114.248 INFO:tensorflow:loss = 0.33442765, step = 5600 (0.875 sec) INFO:tensorflow:global_step/sec: 114.85 INFO:tensorflow:loss = 0.3017143, step = 5700 (0.871 sec) INFO:tensorflow:global_step/sec: 114.052 INFO:tensorflow:loss = 0.33482796, step = 5800 (0.877 sec) INFO:tensorflow:global_step/sec: 113.413 INFO:tensorflow:loss = 0.42112976, step = 5900 (0.882 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5994... INFO:tensorflow:Saving checkpoints for 5994 into /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5994... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). 
INFO:tensorflow:global_step/sec: 98.5739 INFO:tensorflow:loss = 0.3670102, step = 6000 (1.015 sec) INFO:tensorflow:global_step/sec: 114.786 INFO:tensorflow:loss = 0.33206683, step = 6100 (0.871 sec) INFO:tensorflow:global_step/sec: 112.258 INFO:tensorflow:loss = 0.3649478, step = 6200 (0.891 sec) INFO:tensorflow:global_step/sec: 113.265 INFO:tensorflow:loss = 0.34351104, step = 6300 (0.884 sec) INFO:tensorflow:global_step/sec: 113.556 INFO:tensorflow:loss = 0.33417398, step = 6400 (0.880 sec) INFO:tensorflow:global_step/sec: 115.112 INFO:tensorflow:loss = 0.40121427, step = 6500 (0.869 sec) INFO:tensorflow:global_step/sec: 112.671 INFO:tensorflow:loss = 0.34591922, step = 6600 (0.888 sec) INFO:tensorflow:global_step/sec: 114.971 INFO:tensorflow:loss = 0.34938964, step = 6700 (0.870 sec) INFO:tensorflow:global_step/sec: 114.675 INFO:tensorflow:loss = 0.37972444, step = 6800 (0.872 sec) INFO:tensorflow:global_step/sec: 114.236 INFO:tensorflow:loss = 0.37977144, step = 6900 (0.875 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6993... INFO:tensorflow:Saving checkpoints for 6993 into /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6993... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). 
INFO:tensorflow:global_step/sec: 99.2694 INFO:tensorflow:loss = 0.35126358, step = 7000 (1.007 sec) INFO:tensorflow:global_step/sec: 112.529 INFO:tensorflow:loss = 0.37536883, step = 7100 (0.889 sec) INFO:tensorflow:global_step/sec: 114.702 INFO:tensorflow:loss = 0.38491726, step = 7200 (0.872 sec) INFO:tensorflow:global_step/sec: 114.154 INFO:tensorflow:loss = 0.34859103, step = 7300 (0.876 sec) INFO:tensorflow:global_step/sec: 113.878 INFO:tensorflow:loss = 0.2851309, step = 7400 (0.878 sec) INFO:tensorflow:global_step/sec: 115.684 INFO:tensorflow:loss = 0.33018833, step = 7500 (0.864 sec) INFO:tensorflow:global_step/sec: 115.05 INFO:tensorflow:loss = 0.2911943, step = 7600 (0.869 sec) INFO:tensorflow:global_step/sec: 114.228 INFO:tensorflow:loss = 0.34445795, step = 7700 (0.875 sec) INFO:tensorflow:global_step/sec: 112.949 INFO:tensorflow:loss = 0.3777662, step = 7800 (0.885 sec) INFO:tensorflow:global_step/sec: 114.814 INFO:tensorflow:loss = 0.3525071, step = 7900 (0.871 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 7992... INFO:tensorflow:Saving checkpoints for 7992 into /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 7992... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). 
INFO:tensorflow:global_step/sec: 100.595 INFO:tensorflow:loss = 0.31738895, step = 8000 (0.995 sec) INFO:tensorflow:global_step/sec: 115.258 INFO:tensorflow:loss = 0.32818204, step = 8100 (0.867 sec) INFO:tensorflow:global_step/sec: 114.606 INFO:tensorflow:loss = 0.3743958, step = 8200 (0.872 sec) INFO:tensorflow:global_step/sec: 114.643 INFO:tensorflow:loss = 0.34712324, step = 8300 (0.872 sec) INFO:tensorflow:global_step/sec: 115.182 INFO:tensorflow:loss = 0.35043937, step = 8400 (0.868 sec) INFO:tensorflow:global_step/sec: 114.304 INFO:tensorflow:loss = 0.37599805, step = 8500 (0.875 sec) INFO:tensorflow:global_step/sec: 112.461 INFO:tensorflow:loss = 0.34265774, step = 8600 (0.889 sec) INFO:tensorflow:global_step/sec: 114.598 INFO:tensorflow:loss = 0.35895196, step = 8700 (0.873 sec) INFO:tensorflow:global_step/sec: 114.452 INFO:tensorflow:loss = 0.33054015, step = 8800 (0.874 sec) INFO:tensorflow:global_step/sec: 115.018 INFO:tensorflow:loss = 0.30547625, step = 8900 (0.869 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8991... INFO:tensorflow:Saving checkpoints for 8991 into /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8991... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). 
INFO:tensorflow:global_step/sec: 101.705 INFO:tensorflow:loss = 0.3225851, step = 9000 (0.983 sec) INFO:tensorflow:global_step/sec: 114.576 INFO:tensorflow:loss = 0.34777793, step = 9100 (0.873 sec) INFO:tensorflow:global_step/sec: 115.056 INFO:tensorflow:loss = 0.3121646, step = 9200 (0.869 sec) INFO:tensorflow:global_step/sec: 114.867 INFO:tensorflow:loss = 0.29584178, step = 9300 (0.871 sec) INFO:tensorflow:global_step/sec: 112.136 INFO:tensorflow:loss = 0.34916228, step = 9400 (0.892 sec) INFO:tensorflow:global_step/sec: 112.682 INFO:tensorflow:loss = 0.37707886, step = 9500 (0.887 sec) INFO:tensorflow:global_step/sec: 111.734 INFO:tensorflow:loss = 0.3262028, step = 9600 (0.895 sec) INFO:tensorflow:global_step/sec: 110.884 INFO:tensorflow:loss = 0.4361206, step = 9700 (0.902 sec) INFO:tensorflow:global_step/sec: 111.58 INFO:tensorflow:loss = 0.24492769, step = 9800 (0.896 sec) INFO:tensorflow:global_step/sec: 113.033 INFO:tensorflow:loss = 0.320422, step = 9900 (0.885 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 9990... INFO:tensorflow:Saving checkpoints for 9990 into /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 9990... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10000... INFO:tensorflow:Saving checkpoints for 10000 into /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10000... INFO:tensorflow:Skip the current checkpoint eval due to throttle secs (600 secs). INFO:absl:Feature company_xf has a shape . Setting to DenseTensor. INFO:absl:Feature dropoff_census_tract_xf has a shape . Setting to DenseTensor. INFO:absl:Feature dropoff_community_area_xf has a shape . 
Setting to DenseTensor. INFO:absl:Feature dropoff_latitude_xf has a shape . Setting to DenseTensor. INFO:absl:Feature dropoff_longitude_xf has a shape . Setting to DenseTensor. INFO:absl:Feature fare_xf has a shape . Setting to DenseTensor. INFO:absl:Feature payment_type_xf has a shape . Setting to DenseTensor. INFO:absl:Feature pickup_census_tract_xf has a shape . Setting to DenseTensor. INFO:absl:Feature pickup_community_area_xf has a shape . Setting to DenseTensor. INFO:absl:Feature pickup_latitude_xf has a shape . Setting to DenseTensor. INFO:absl:Feature pickup_longitude_xf has a shape . Setting to DenseTensor. INFO:absl:Feature tips_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_miles_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_seconds_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_start_day_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_start_hour_xf has a shape . Setting to DenseTensor. INFO:absl:Feature trip_start_month_xf has a shape . Setting to DenseTensor. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2020-11-26T10:16:08Z INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt-10000 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 
INFO:tensorflow:Evaluation [500/5000] INFO:tensorflow:Evaluation [1000/5000] INFO:tensorflow:Evaluation [1500/5000] INFO:tensorflow:Evaluation [2000/5000] INFO:tensorflow:Evaluation [2500/5000] INFO:tensorflow:Evaluation [3000/5000] INFO:tensorflow:Evaluation [3500/5000] INFO:tensorflow:Evaluation [4000/5000] INFO:tensorflow:Evaluation [4500/5000] INFO:tensorflow:Evaluation [5000/5000] INFO:tensorflow:Inference Time : 44.32158s INFO:tensorflow:Finished evaluation at 2020-11-26-10:16:53 INFO:tensorflow:Saving dict for global step 10000: accuracy = 0.785505, accuracy_baseline = 0.771185, auc = 0.9335164, auc_precision_recall = 0.703921, average_loss = 0.34539613, global_step = 10000, label/mean = 0.228815, loss = 0.34539708, precision = 0.6922148, prediction/mean = 0.23044644, recall = 0.112689294 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10000: /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt-10000 INFO:tensorflow:Performing the final export in the end of training. WARNING:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef" value: "\n\013\n\tConst_2:0\022-vocab_compute_and_apply_vocabulary_vocabulary" Warning:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef" value: "\n\013\n\tConst_4:0\022/vocab_compute_and_apply_vocabulary_1_vocabulary" INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. 
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt-10000 INFO:tensorflow:Assets added to graph. INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/export/chicago-taxi/temp-1606385813/assets INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/export/chicago-taxi/temp-1606385813/saved_model.pb INFO:tensorflow:Loss for final step: 0.373849. INFO:absl:Training complete. Model written to /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir. ModelRun written to /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6 INFO:absl:Exporting eval_savedmodel for TFMA. Warning:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef" value: "\n\013\n\tConst_2:0\022-vocab_compute_and_apply_vocabulary_vocabulary" Warning:tensorflow:Expected binary or unicode string, got type_url: "type.googleapis.com/tensorflow.AssetFileDef" value: "\n\013\n\tConst_4:0\022/vocab_compute_and_apply_vocabulary_1_vocabulary" INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. 
INFO:tensorflow:Signatures INCLUDED in export for Classify: None INFO:tensorflow:Signatures INCLUDED in export for Regress: None INFO:tensorflow:Signatures INCLUDED in export for Predict: None INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: ['eval'] WARNING:tensorflow:Export includes no default signature! INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/serving_model_dir/model.ckpt-10000 INFO:tensorflow:Assets added to graph. INFO:tensorflow:Assets written to: /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/eval_model_dir/temp-1606385814/assets INFO:tensorflow:SavedModel written to: /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/eval_model_dir/temp-1606385814/saved_model.pb INFO:absl:Exported eval_savedmodel to /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model_run/6/eval_model_dir. WARNING:absl:Support for estimator-based executor and model export will be deprecated soon. Please use export structure <ModelExportPath>/serving_model_dir/saved_model.pb" INFO:absl:Serving model copied to: /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model/6/serving_model_dir. WARNING:absl:Support for estimator-based executor and model export will be deprecated soon. Please use export structure <ModelExportPath>/eval_model_dir/saved_model.pb" INFO:absl:Eval model copied to: /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model/6/eval_model_dir. INFO:absl:Running publisher for Trainer INFO:absl:MetadataStore with DB connection initialized
Analyze Training with TensorBoard
Optionally, we can connect TensorBoard to the Trainer to analyze our model's training curves.
# Get the URI of the output artifact representing the training logs, which is a directory
model_run_dir = trainer.outputs['model_run'].get()[0].uri
%load_ext tensorboard
%tensorboard --logdir {model_run_dir}
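TensorBoard scans the directory passed to `--logdir` recursively. It treats every subdirectory that holds event files as a separate run. The helper below is a minimal stdlib sketch that lists those runs so you can see what TensorBoard will pick up from `model_run_dir`. It assumes TensorBoard's convention that event files contain `tfevents` in their name (for example `events.out.tfevents.<timestamp>.<host>`). The function name `find_tensorboard_runs` is hypothetical.

```python
import os
from typing import Dict, List


def find_tensorboard_runs(logdir: str) -> Dict[str, List[str]]:
    """Map each directory under logdir (relative path) to its event files.

    Assumption: event files are recognized by 'tfevents' in the basename,
    which is the naming convention TensorBoard relies on.
    """
    runs: Dict[str, List[str]] = {}
    for dirpath, _dirnames, filenames in os.walk(logdir):
        events = sorted(f for f in filenames if "tfevents" in f)
        if events:
            # TensorBoard labels each run by its path relative to logdir.
            runs[os.path.relpath(dirpath, logdir)] = events
    return runs
```

In the notebook you could call `find_tensorboard_runs(model_run_dir)`. For an Estimator-based Trainer, you would typically expect training events under `serving_model_dir` and evaluation events in an `eval` subdirectory.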
Evaluator
The Evaluator component computes model performance metrics over the evaluation set. It uses the TensorFlow Model Analysis library. The Evaluator can also optionally validate that a newly trained model is better than the previous model. This is useful in a production pipeline setting where you may automatically train and validate a model every day. In this notebook, we only train one model, so the Evaluator will automatically label the model as "good".
Evaluator will take as input the data from ExampleGen, the trained model from Trainer, and a slicing configuration. The slicing configuration allows you to slice your metrics on feature values (e.g. how does your model perform on taxi trips that start at 8am versus 8pm?). See an example of this configuration below:
eval_config = tfma.EvalConfig(
model_specs=[
# Using signature 'eval' implies the use of an EvalSavedModel. To use
# a serving model, remove the signature so that it defaults to
# 'serving_default', and add a label_key.
tfma.ModelSpec(signature_name='eval')
],
metrics_specs=[
tfma.MetricsSpec(
# The metrics added here are in addition to those saved with the
# model (assuming either a keras model or EvalSavedModel is used).
# Any metrics added into the saved model (for example using
# model.compile(..., metrics=[...]), etc) will be computed
# automatically.
metrics=[
tfma.MetricConfig(class_name='ExampleCount')
],
# To add validation thresholds for metrics saved with the model,
# add them keyed by metric name to the thresholds map.
thresholds={
'accuracy': tfma.MetricThreshold(
value_threshold=tfma.GenericValueThreshold(
lower_bound={'value': 0.5}),
change_threshold=tfma.GenericChangeThreshold(
direction=tfma.MetricDirection.HIGHER_IS_BETTER,
absolute={'value': -1e-10}))
}
)
],
slicing_specs=[
# An empty slice spec means the overall slice, i.e. the whole dataset.
tfma.SlicingSpec(),
# Data can be sliced along a feature column. In this case, data is
# sliced along feature column trip_start_hour.
tfma.SlicingSpec(feature_keys=['trip_start_hour'])
])
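To make the two `accuracy` thresholds concrete, here is a simplified pure-Python model of how they combine. It is not TFMA's implementation. It assumes that both comparisons are inclusive and that the change threshold is skipped when no baseline model is available, which is what happens on a first run. The function `passes_accuracy_thresholds` is hypothetical, and its defaults mirror the values in `eval_config`.

```python
from typing import Optional


def passes_accuracy_thresholds(
    candidate: float,
    baseline: Optional[float] = None,
    lower_bound: float = 0.5,
    min_absolute_change: float = -1e-10,
) -> bool:
    """Hypothetical sketch of the 'accuracy' thresholds in eval_config.

    Assumptions: bounds are inclusive, and the change threshold is ignored
    when there is no baseline model to compare against.
    """
    # value_threshold: the candidate must reach the absolute lower bound.
    if candidate < lower_bound:
        return False
    # change_threshold is ignored when there is no baseline (first run).
    if baseline is None:
        return True
    # HIGHER_IS_BETTER with absolute=-1e-10: candidate - baseline may be
    # zero (a tie passes) but must not be meaningfully negative.
    return candidate - baseline >= min_absolute_change
```

The tiny negative value `-1e-10` lets a candidate that exactly ties the baseline pass, while any real drop in accuracy fails validation.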
Next, we give this configuration to Evaluator and run it.
# Use TFMA to compute evaluation statistics over features of a model and
# validate them against a baseline.
# The model resolver is only required if performing model validation in addition
# to evaluation. In this case we validate against the latest blessed model. If
# no model has been blessed before (as in this case) the evaluator will make our
# candidate the first blessed model.
model_resolver = ResolverNode(
instance_name='latest_blessed_model_resolver',
resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
model=Channel(type=Model),
model_blessing=Channel(type=ModelBlessing))
context.run(model_resolver)
evaluator = Evaluator(
examples=example_gen.outputs['examples'],
model=trainer.outputs['model'],
baseline_model=model_resolver.outputs['model'],
# Change threshold will be ignored if there is no baseline (first run).
eval_config=eval_config)
context.run(evaluator)
WARNING:absl:`instance_name` is deprecated, please set node id directly using`with_id()` or `.id` setter. INFO:absl:Running driver for ResolverNode.latest_blessed_model_resolver INFO:absl:MetadataStore with DB connection initialized INFO:absl:Running publisher for ResolverNode.latest_blessed_model_resolver INFO:absl:MetadataStore with DB connection initialized INFO:absl:Running driver for Evaluator INFO:absl:MetadataStore with DB connection initialized INFO:absl:Running executor for Evaluator WARNING:absl:"maybe_add_baseline" and "maybe_remove_baseline" are deprecated, please use "has_baseline" instead. INFO:absl:Request was made to ignore the baseline ModelSpec and any change thresholds. This is likely because a baseline model was not provided: updated_config= model_specs { signature_name: "eval" } slicing_specs { } slicing_specs { feature_keys: "trip_start_hour" } metrics_specs { metrics { class_name: "ExampleCount" } thresholds { key: "accuracy" value { value_threshold { lower_bound { value: 0.5 } } } } } INFO:absl:Using /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model/6/eval_model_dir as model. INFO:absl:The 'example_splits' parameter is not set, using 'eval' split. INFO:absl:Evaluating model. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_model_analysis/eval_saved_model/load.py:169: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. 
INFO:tensorflow:Restoring parameters from /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model/6/eval_model_dir/variables/variables WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_model_analysis/eval_saved_model/graph_ref.py:189: get_tensor_from_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or tf.compat.v1.saved_model.get_tensor_from_tensor_info. INFO:absl:Evaluation complete. Results written to /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Evaluator/evaluation/8. INFO:absl:Checking validation results. INFO:absl:Blessing result True written to /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Evaluator/blessing/8. INFO:absl:Running publisher for Evaluator INFO:absl:MetadataStore with DB connection initialized
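The resolver's job, as described in the code comments above, is to find the latest blessed model to use as a baseline, or nothing at all on a first run. The sketch below models that selection in plain Python. `ModelRecord` and `latest_blessed` are hypothetical stand-ins, not TFX classes. "Latest" is approximated by the largest artifact id, which is a simplification of what `LatestBlessedModelResolver` actually queries from ML Metadata.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelRecord:
    """Hypothetical stand-in for a Model artifact plus its blessing result."""
    model_id: int
    uri: str
    blessed: bool


def latest_blessed(models: List[ModelRecord]) -> Optional[ModelRecord]:
    """Return the most recent blessed model, or None if none exists.

    'Most recent' is approximated here by the largest model_id.
    """
    blessed = [m for m in models if m.blessed]
    if not blessed:
        # No baseline: the Evaluator then skips change thresholds.
        return None
    return max(blessed, key=lambda m: m.model_id)
```

On this notebook's first run there are no blessed models yet, so the resolver yields no baseline. This matches the Evaluator log message about ignoring the baseline ModelSpec and change thresholds.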
Now let's examine the output artifacts of Evaluator.
evaluator.outputs
{'evaluation': Channel( type_name: ModelEvaluation artifacts: [Artifact(artifact: id: 10 type_id: 20 uri: "/tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Evaluator/evaluation/8" custom_properties { key: "name" value { string_value: "evaluation" } } custom_properties { key: "producer_component" value { string_value: "Evaluator" } } custom_properties { key: "state" value { string_value: "published" } } state: LIVE , artifact_type: id: 20 name: "ModelEvaluation" )] ), 'blessing': Channel( type_name: ModelBlessing artifacts: [Artifact(artifact: id: 11 type_id: 21 uri: "/tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Evaluator/blessing/8" custom_properties { key: "blessed" value { int_value: 1 } } custom_properties { key: "current_model" value { string_value: "/tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Trainer/model/6" } } custom_properties { key: "current_model_id" value { int_value: 8 } } custom_properties { key: "name" value { string_value: "blessing" } } custom_properties { key: "producer_component" value { string_value: "Evaluator" } } custom_properties { key: "state" value { string_value: "published" } } state: LIVE , artifact_type: id: 21 name: "ModelBlessing" )] )}
Using the evaluation output, we can show the default visualization of global metrics on the entire evaluation set.
context.show(evaluator.outputs['evaluation'])
To see the visualization for sliced evaluation metrics, we can directly call the TensorFlow Model Analysis library.
import tensorflow_model_analysis as tfma
# Get the TFMA output result path and load the result.
PATH_TO_RESULT = evaluator.outputs['evaluation'].get()[0].uri
tfma_result = tfma.load_eval_result(PATH_TO_RESULT)
# Show data sliced along feature column trip_start_hour.
tfma.view.render_slicing_metrics(
tfma_result, slicing_column='trip_start_hour')
This visualization shows the same metrics, but computed at every feature value of trip_start_hour instead of on the entire evaluation set.
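Conceptually, slicing groups the evaluation examples by a feature value and recomputes each metric within every group, alongside the overall (empty `SlicingSpec`) slice. The sketch below does this for accuracy in plain Python. The `(trip_start_hour, label, prediction)` row layout, the 0/1 labels, and the `"Overall"` key name are illustrative assumptions, not TFMA's data format.

```python
from collections import defaultdict
from typing import Dict, Iterable, Tuple


def sliced_accuracy(rows: Iterable[Tuple[int, int, int]]) -> Dict[str, float]:
    """Accuracy overall and per trip_start_hour value.

    Each row is a hypothetical (trip_start_hour, label, prediction) triple.
    """
    correct: Dict[str, int] = defaultdict(int)
    total: Dict[str, int] = defaultdict(int)
    for hour, label, pred in rows:
        hit = int(label == pred)
        # Every example counts toward the overall slice and its hour slice.
        for key in ("Overall", f"trip_start_hour:{hour}"):
            correct[key] += hit
            total[key] += 1
    return {key: correct[key] / total[key] for key in total}
```

For example, `sliced_accuracy([(8, 1, 1), (8, 0, 1), (20, 0, 0), (20, 1, 1)])` gives 0.75 overall, 0.5 for hour 8, and 1.0 for hour 20. A model can look fine overall while underperforming on particular slices.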
TensorFlow Model Analysis supports many other visualizations, such as Fairness Indicators and plotting a time series of model performance. To learn more, see the tutorial.
Since we added thresholds to our config, validation output is also available. The presence of a blessing artifact indicates that our model passed validation. Since this is the first validation being performed, the candidate is automatically blessed.
blessing_uri = evaluator.outputs.blessing.get()[0].uri
!ls -l {blessing_uri}
total 0 -rw-rw-r-- 1 kbuilder kbuilder 0 Nov 26 10:17 BLESSED
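The listing shows that the blessing artifact is simply a directory containing an empty `BLESSED` file. A downstream script could therefore check the outcome without TFX, as in this minimal sketch. The helper name `is_blessed` is hypothetical, and the check relies only on the `BLESSED` marker seen in the listing.

```python
from pathlib import Path


def is_blessed(blessing_uri: str) -> bool:
    """Return True if the blessing directory contains a BLESSED marker file."""
    return (Path(blessing_uri) / "BLESSED").is_file()
```

In the notebook, `is_blessed(blessing_uri)` would return `True` for the run shown above.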
We can also verify the success by loading the validation result record:
PATH_TO_RESULT = evaluator.outputs['evaluation'].get()[0].uri
print(tfma.load_validation_result(PATH_TO_RESULT))
validation_ok: true validation_details { slicing_details { slicing_spec { } num_matching_slices: 25 } }
Pusher
The Pusher component is usually at the end of a TFX pipeline. It checks whether a model has passed validation, and if so, exports the model to _serving_model_dir.
pusher = Pusher(
model=trainer.outputs['model'],
model_blessing=evaluator.outputs['blessing'],
push_destination=pusher_pb2.PushDestination(
filesystem=pusher_pb2.PushDestination.Filesystem(
base_directory=_serving_model_dir)))
context.run(pusher)
INFO:absl:Running driver for Pusher INFO:absl:MetadataStore with DB connection initialized INFO:absl:Running executor for Pusher INFO:absl:Model version: 1606385828 INFO:absl:Model written to serving path /tmp/tmpdyqisj4r/serving_model/taxi_simple/1606385828. INFO:absl:Model pushed to /tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Pusher/pushed_model/9. INFO:absl:Running publisher for Pusher INFO:absl:MetadataStore with DB connection initialized
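The Pusher log shows the model written to a numbered subdirectory of the push destination (`.../taxi_simple/1606385828`), where the number is the model version. By default, TensorFlow Serving loads the numerically largest version it finds under a model's base path. The sketch below reproduces that selection with `pathlib` so you can see which export would be served. The function name `latest_model_version` is hypothetical, and non-numeric entries are assumed to be ignored.

```python
from pathlib import Path
from typing import Optional


def latest_model_version(model_base_dir: str) -> Optional[Path]:
    """Return the subdirectory whose name is the largest integer version.

    Non-numeric entries and plain files are skipped.
    """
    base = Path(model_base_dir)
    versions = [p for p in base.iterdir() if p.is_dir() and p.name.isdigit()]
    if not versions:
        return None
    return max(versions, key=lambda p: int(p.name))
```

With the push destination used here, `latest_model_version(_serving_model_dir)` would point at the `1606385828` directory. After later pushes, it would point at the newest one.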
Let's examine the output artifacts of Pusher.
pusher.outputs
{'pushed_model': Channel( type_name: PushedModel artifacts: [Artifact(artifact: id: 12 type_id: 23 uri: "/tmp/tfx-interactive-2020-11-26T10_13_22.805724-fe5em05p/Pusher/pushed_model/9" custom_properties { key: "name" value { string_value: "pushed_model" } } custom_properties { key: "producer_component" value { string_value: "Pusher" } } custom_properties { key: "pushed" value { int_value: 1 } } custom_properties { key: "pushed_destination" value { string_value: "/tmp/tmpdyqisj4r/serving_model/taxi_simple/1606385828" } } custom_properties { key: "pushed_version" value { string_value: "1606385828" } } custom_properties { key: "state" value { string_value: "published" } } state: LIVE , artifact_type: id: 23 name: "PushedModel" )] )}
In particular, the Pusher will export your model in the SavedModel format, which looks like this:
push_uri = pusher.outputs['pushed_model'].get()[0].uri
model = tf.saved_model.load(push_uri)
for item in model.signatures.items():
pp.pprint(item)
('regression', <ConcreteFunction pruned(inputs) at 0x7F46B81E1630>) ('classification', <ConcreteFunction pruned(inputs) at 0x7F46B8375828>) ('predict', <ConcreteFunction pruned(examples) at 0x7F46D1325F28>) ('serving_default', <ConcreteFunction pruned(inputs) at 0x7F46A0410F28>)
We've finished our tour of the built-in TFX components!