Get Started with TensorFlow Transform

This guide introduces the basic concepts of tf.Transform and how to use them. It will:

  • Define a preprocessing function, a logical description of the pipeline that transforms the raw data into the data used to train a machine learning model.
  • Show the Apache Beam implementation used to transform data by converting the preprocessing function into a Beam pipeline.
  • Show additional usage examples.

Setup

pip install -U tensorflow_transform
pip install pyarrow
import pkg_resources
import importlib
importlib.reload(pkg_resources)
import os
import tempfile

import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam

from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils

from tfx_bsl.public import tfxio

Define a preprocessing function

The preprocessing function is the most important concept of tf.Transform. The preprocessing function is a logical description of a transformation of the dataset. The preprocessing function accepts and returns a dictionary of tensors, where a tensor means Tensor or SparseTensor. There are two kinds of functions used to define the preprocessing function:

  1. Any function that accepts and returns tensors. These add TensorFlow operations to the graph that transform raw data into transformed data.
  2. Any of the analyzers provided by tf.Transform. Analyzers also accept and return tensors, but unlike TensorFlow functions, they do not add operations to the graph. Instead, analyzers cause tf.Transform to compute a full-pass operation outside of TensorFlow. They use the input tensor values over the entire dataset to generate a constant tensor that is returned as the output. For example, tft.min computes the minimum of a tensor over the dataset. tf.Transform provides a fixed set of analyzers, but this will be extended in future versions.

Preprocessing function example

By combining analyzers and regular TensorFlow functions, users can create flexible pipelines for transforming data. The following preprocessing function transforms each of the three features in different ways, and combines two of the features:

def preprocessing_fn(inputs):
  x = inputs['x']
  y = inputs['y']
  s = inputs['s']
  x_centered = x - tft.mean(x)
  y_normalized = tft.scale_to_0_1(y)
  s_integerized = tft.compute_and_apply_vocabulary(s)
  x_centered_times_y_normalized = x_centered * y_normalized
  return {
      'x_centered': x_centered,
      'y_normalized': y_normalized,
      'x_centered_times_y_normalized': x_centered_times_y_normalized,
      's_integerized': s_integerized
  }

Here, x, y and s are Tensors that represent input features. The first new tensor that is created, x_centered, is built by applying tft.mean to x and subtracting this from x. tft.mean(x) returns a tensor representing the mean of the tensor x. x_centered is the tensor x with the mean subtracted.

The second new tensor, y_normalized, is created in a similar manner but using the convenience method tft.scale_to_0_1. This method does something similar to computing x_centered, namely computing a maximum and minimum and using these to scale y.

The tensor s_integerized shows an example of string manipulation. In this case, we take a string and map it to an integer. This uses the convenience function tft.compute_and_apply_vocabulary. This function uses an analyzer to compute the unique values taken by the input strings, and then uses TensorFlow operations to convert the input strings to indices in the table of unique values.

The final column shows that it is possible to use TensorFlow operations to create new features by combining tensors.

The preprocessing function defines a pipeline of operations on a dataset. In order to apply the pipeline, we rely on a concrete implementation of the tf.Transform API. The Apache Beam implementation provides a PTransform that applies a user's preprocessing function to data. In a typical workflow, a tf.Transform user constructs a preprocessing function, then incorporates it into a larger Beam pipeline that creates the data for training.

Batching

Batching is an important part of TensorFlow. Since one of the goals of tf.Transform is to provide a TensorFlow graph for preprocessing that can be incorporated into the serving graph (and, optionally, the training graph), batching is also an important concept in tf.Transform.

While not obvious in the example above, the user-defined preprocessing function is passed tensors representing batches and not individual instances, as happens during training and serving with TensorFlow. On the other hand, analyzers perform a computation over the entire dataset that returns a single value and not a batch of values. x is a Tensor with a shape of (batch_size,), while tft.mean(x) is a Tensor with a shape of (). The subtraction x - tft.mean(x) broadcasts, so the value of tft.mean(x) is subtracted from every element of the batch represented by x.
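The interaction between batches and analyzers can be mimicked in plain Python. The sketch below is not tf.Transform code: lists stand in for tensors, and the helper names (full_pass_mean, center_batch) are hypothetical. It only illustrates that the analyzer result is a single scalar computed over the whole dataset, which is then subtracted from every element of each batch.

```python
from statistics import fmean

# The whole dataset for feature 'x', delivered to the preprocessing step in
# batches (here: one batch of size 2 and one batch of size 1).
batches = [[1.0, 2.0], [3.0]]

def full_pass_mean(batches):
    """Analyzer stand-in: one scalar computed over every batch (shape ())."""
    return fmean(value for batch in batches for value in batch)

def center_batch(batch, mean):
    """Per-batch step: the scalar mean is subtracted from each element."""
    return [value - mean for value in batch]

mean_x = full_pass_mean(batches)
centered = [center_batch(batch, mean_x) for batch in batches]

print(mean_x)    # 2.0
print(centered)  # [[-1.0, 0.0], [1.0]]
```

The result does not depend on how the dataset is split into batches, because the mean is computed once over all of the data before any batch is transformed.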

Apache Beam Implementation

While the preprocessing function is intended as a logical description of a preprocessing pipeline that can be implemented on multiple data processing frameworks, tf.Transform provides a canonical implementation on Apache Beam. This implementation demonstrates the functionality required from an implementation. There is no formal API for this functionality, so each implementation can use an API that is idiomatic for its particular data processing framework.

The Apache Beam implementation provides two PTransforms used to process data for a preprocessing function. The following shows the usage for the composite PTransform - tft_beam.AnalyzeAndTransformDataset:

raw_data = [
    {'x': 1, 'y': 1, 's': 'hello'},
    {'x': 2, 'y': 2, 's': 'world'},
    {'x': 3, 'y': 3, 's': 'hello'}
]

raw_data_metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec({
        'y': tf.io.FixedLenFeature([], tf.float32),
        'x': tf.io.FixedLenFeature([], tf.float32),
        's': tf.io.FixedLenFeature([], tf.string),
    }))

with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
  transformed_dataset, transform_fn = (
      (raw_data, raw_data_metadata) |
      tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
transformed_data, transformed_metadata = transformed_dataset

The transformed_data content is shown below and contains the transformed columns in the same format as the raw data. In particular, the values of s_integerized are [0, 1, 0]—these values depend on how the words hello and world were mapped to integers, which is deterministic. For the column x_centered, we subtracted the mean so the values of the column x, which were [1.0, 2.0, 3.0], became [-1.0, 0.0, 1.0]. Similarly, the rest of the columns match their expected values.

transformed_data
[{'s_integerized': 0,
  'x_centered': -1.0,
  'x_centered_times_y_normalized': -0.0,
  'y_normalized': 0.0},
 {'s_integerized': 1,
  'x_centered': 0.0,
  'x_centered_times_y_normalized': 0.0,
  'y_normalized': 0.5},
 {'s_integerized': 0,
  'x_centered': 1.0,
  'x_centered_times_y_normalized': 1.0,
  'y_normalized': 1.0}]

Both raw_data and transformed_data are datasets. The next two sections show how the Beam implementation represents datasets and how to read and write data to disk. The other return value, transform_fn, represents the transformation applied to the data, covered in detail below.

The tft_beam.AnalyzeAndTransformDataset class is the composition of the two fundamental transforms provided by the implementation: tft_beam.AnalyzeDataset and tft_beam.TransformDataset. So the following two code snippets are equivalent:

my_data = (raw_data, raw_data_metadata)
with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
  transformed_data, transform_fn = (
      my_data | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
  transform_fn = my_data | tft_beam.AnalyzeDataset(preprocessing_fn)
  transformed_data = (my_data, transform_fn) | tft_beam.TransformDataset()

transform_fn is a pure function that represents an operation that is applied to each row of the dataset. In particular, the analyzer values are already computed and treated as constants. In the example, the transform_fn contains as constants the mean of column x, the min and max of column y, and the vocabulary used to map the strings to integers.

An important feature of tf.Transform is that transform_fn represents a map over rows—it is a pure function applied to each row separately. All of the computation for aggregating rows is done in AnalyzeDataset. Furthermore, the transform_fn is represented as a TensorFlow Graph which can be embedded into the serving graph.

AnalyzeAndTransformDataset is provided for optimizations in the special case where the same data is both analyzed and transformed. This is the same pattern used in scikit-learn, which provides the fit, transform, and fit_transform methods.
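The split between the analysis phase and the per-row transform can be sketched in plain Python. This is not how tf.Transform is implemented; the function names (analyze, transform, analyze_and_transform) are hypothetical stand-ins for AnalyzeDataset, TransformDataset, and AnalyzeAndTransformDataset, and the sketch assumes that vocabulary indices are assigned in order of descending frequency.

```python
from collections import Counter

raw_data = [
    {'x': 1, 'y': 1, 's': 'hello'},
    {'x': 2, 'y': 2, 's': 'world'},
    {'x': 3, 'y': 3, 's': 'hello'},
]

def analyze(rows):
    """Full pass over the data: compute constants, return a per-row function."""
    xs = [float(row['x']) for row in rows]
    ys = [float(row['y']) for row in rows]
    mean_x = sum(xs) / len(xs)
    min_y, max_y = min(ys), max(ys)
    # Assumption: ids are assigned by descending frequency of each string.
    counts = Counter(row['s'] for row in rows)
    vocab = {s: i for i, (s, _) in enumerate(counts.most_common())}

    def transform_row(row):
        # Only the constants captured above are used; no aggregation here.
        x_centered = row['x'] - mean_x
        y_normalized = (row['y'] - min_y) / (max_y - min_y)
        return {
            'x_centered': x_centered,
            'y_normalized': y_normalized,
            'x_centered_times_y_normalized': x_centered * y_normalized,
            's_integerized': vocab[row['s']],
        }

    return transform_row

def transform(rows, transform_row):
    """Map the per-row function over the dataset."""
    return [transform_row(row) for row in rows]

def analyze_and_transform(rows):
    """Composition of the two phases, analogous to fit_transform."""
    transform_row = analyze(rows)
    return transform(rows, transform_row), transform_row

transformed, transform_row = analyze_and_transform(raw_data)
for row in transformed:
    print(row)
```

Because transform_row only closes over constants, it can be applied to rows that were not part of the analyzed data, for example transform_row({'x': 4, 'y': 2, 's': 'world'}).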

Data Formats and Schema

The TFT Beam implementation accepts two different input data formats. The "instance dict" format (as seen in the example above and in simple.ipynb and simple_example.py) is intuitive and suitable for small datasets, while the TFXIO (Apache Arrow) format provides improved performance and is suitable for large datasets.

The "metadata" accompanying the PCollection tells the Beam implementation the format of the PCollection.

(raw_data, raw_data_metadata) | tft_beam.AnalyzeDataset(...)
  • If raw_data_metadata is a dataset_metadata.DatasetMetadata (see below, "The 'instance dict' format" section), then raw_data is expected to be in the "instance dict" format.
  • If raw_data_metadata is a tfxio.TensorAdapterConfig (see below, "The TFXIO format" section), then raw_data is expected to be in the TFXIO format.

The "instance dict" format

The previous code examples used this format. The metadata contains the schema that defines the layout of the data and how it is read from and written to various formats. Even this in-memory format is not self-describing and requires the schema in order to be interpreted as tensors.

Again, here is the definition of the schema for the example data:

import tensorflow_transform as tft

raw_data_metadata = tft.DatasetMetadata.from_feature_spec({
    's': tf.io.FixedLenFeature([], tf.string),
    'y': tf.io.FixedLenFeature([], tf.float32),
    'x': tf.io.FixedLenFeature([], tf.float32),
})

The Schema proto contains the information needed to parse the data from its on-disk or in-memory format into tensors. It is typically constructed by calling schema_utils.schema_from_feature_spec with a dict mapping feature keys to tf.io.FixedLenFeature, tf.io.VarLenFeature, and tf.io.SparseFeature values. See the documentation for tf.io.parse_example for more details.

Above we use tf.io.FixedLenFeature to indicate that each feature contains a fixed number of values, in this case a single scalar value. Because tf.Transform batches instances, the actual Tensor representing the feature will have shape (None,) where the unknown dimension is the batch dimension.

The TFXIO format

With this format, the data is expected to be contained in a pyarrow.RecordBatch. For tabular data, our Apache Beam implementation accepts Arrow RecordBatches that consist of columns of the following types:

  • pa.list_(<primitive>), where <primitive> is pa.int64(), pa.float32(), pa.binary() or pa.large_binary().

  • pa.large_list(<primitive>)

The toy input dataset we used above, when represented as a RecordBatch, looks like the following:

import pyarrow as pa

raw_data = [
    pa.record_batch(
        data=[
            pa.array([[1], [2], [3]], pa.list_(pa.float32())),
            pa.array([[1], [2], [3]], pa.list_(pa.float32())),
            pa.array([['hello'], ['world'], ['hello']], pa.list_(pa.binary())),
        ],
        names=['x', 'y', 's'])
]

Similar to the dataset_metadata.DatasetMetadata instance that accompanies the "instance dict" format, a tfxio.TensorAdapterConfig must accompany the RecordBatches. It consists of the Arrow schema of the RecordBatches and tfxio.TensorRepresentations, which uniquely determine how columns in RecordBatches can be interpreted as TensorFlow Tensors (including but not limited to tf.Tensor and tf.SparseTensor).

tfxio.TensorRepresentations is a type alias for Dict[str, tensorflow_metadata.proto.v0.schema_pb2.TensorRepresentation], which establishes the relationship between a Tensor that a preprocessing_fn accepts and columns in the RecordBatches. For example:

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import schema_pb2

tensor_representation = {
    'x': text_format.Parse(
        """dense_tensor { column_name: "col1" shape { dim { size: 2 } } }""",
        schema_pb2.TensorRepresentation())
}

This means that inputs['x'] in preprocessing_fn should be a dense tf.Tensor whose values come from a column named 'col1' in the input RecordBatches, and whose (batched) shape should be [batch_size, 2].
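To make the shape statement concrete, the following plain-Python sketch interprets a list-typed column as a dense batch. It is only an illustration under stated assumptions: the helper to_dense_batch and the sample values for 'col1' are hypothetical, and nested lists stand in for both the Arrow column and the resulting tensor.

```python
def to_dense_batch(column, row_size):
    """Treat a list-typed column as a dense batch of shape (batch_size, row_size)."""
    for values in column:
        if len(values) != row_size:
            raise ValueError(
                f'expected {row_size} values per row, got {len(values)}')
    batch = [list(values) for values in column]
    return batch, (len(batch), row_size)

# Hypothetical contents of the 'col1' column: one list of two floats per row.
col1 = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

x, shape = to_dense_batch(col1, row_size=2)
print(shape)  # (3, 2)
```

A row with a different number of values cannot be placed in a dense tensor of this shape, which is why the sketch raises an error in that case.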

A schema_pb2.TensorRepresentation is a Protobuf defined in TensorFlow Metadata.

Compatibility with TensorFlow

tf.Transform provides support for exporting the transform_fn as a SavedModel; see the simple tutorial for an example. Before the 0.30 release, the default behavior was to export a TF 1.x SavedModel. Starting with the 0.30 release, the default behavior is to export a TF 2.x SavedModel unless TF 2.x behaviors are explicitly disabled (by calling tf.compat.v1.disable_v2_behavior()).

If using TF 1.x concepts such as tf.estimator and tf.Sessions, you can retain the previous behavior by passing force_tf_compat_v1=True to tft_beam.Context if using tf.Transform as a standalone library or to the Transform component in TFX.

When exporting the transform_fn as a TF 2.x SavedModel, the preprocessing_fn is expected to be traceable using tf.function. Additionally, if running your pipeline remotely (for example with the DataflowRunner), ensure that the preprocessing_fn and any dependencies are packaged properly as described here.

Known issues with using tf.Transform to export a TF 2.x SavedModel are documented here.

Input and output with Apache Beam

So far, we've seen input and output data in Python lists (of RecordBatches or instance dictionaries). This is a simplification that relies on Apache Beam's ability to work with lists as well as its main representation of data, the PCollection.

A PCollection is a data representation that forms a part of a Beam pipeline. A Beam pipeline is formed by applying various PTransforms, including AnalyzeDataset and TransformDataset, and running the pipeline. A PCollection is not created in the memory of the main binary, but instead is distributed among the workers (although this section uses the in-memory execution mode).

Pre-canned PCollection Sources (TFXIO)

The RecordBatch format that our implementation accepts is a common format that other TFX libraries accept. Therefore, TFX offers convenient "sources" (a.k.a. TFXIO) that read files of various formats on disk and produce RecordBatches. They can also provide a tfxio.TensorAdapterConfig, including inferred tfxio.TensorRepresentations.

Those TFXIOs can be found in package tfx_bsl (tfx_bsl.public.tfxio).

Example: "Census Income" dataset

The following example requires both reading and writing data on disk and representing data as a PCollection (not a list); see census_example.py. Below we show how to download the data and run this example. The "Census Income" dataset is provided by the UCI Machine Learning Repository. This dataset contains both categorical and numeric data.

Here is some code to download and preview this data:

wget https://storage.googleapis.com/artifacts.tfx-oss-public.appspot.com/datasets/census/adult.data
import pandas as pd

train_data_file = "adult.data"

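Some configuration code is hidden in the notebook cell at this point. It defines the constants used in the rest of this example, such as ORDERED_CSV_COLUMNS, NUMERIC_FEATURE_KEYS, CATEGORICAL_FEATURE_KEYS, LABEL_KEY and SCHEMA.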
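Since the configuration is not shown, the sketch below gives one plausible version of the column constants. It is an assumption based on the published column layout of the UCI Adult dataset, not a copy of the hidden cell, which may differ. SCHEMA is left out because building it requires TensorFlow; it would typically be derived from a feature spec with schema_utils.schema_from_feature_spec.

```python
# Assumed column order of adult.data (UCI Adult / "Census Income" dataset).
ORDERED_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education-num',
    'marital-status', 'occupation', 'relationship', 'race', 'sex',
    'capital-gain', 'capital-loss', 'hours-per-week', 'native-country',
    'label',
]

# Assumed split of the feature columns by type.
CATEGORICAL_FEATURE_KEYS = [
    'workclass', 'education', 'marital-status', 'occupation',
    'relationship', 'race', 'sex', 'native-country',
]
NUMERIC_FEATURE_KEYS = [
    'age', 'capital-gain', 'capital-loss', 'hours-per-week', 'education-num',
]
LABEL_KEY = 'label'

# Sanity check: every key used for preprocessing must be a CSV column.
used_keys = CATEGORICAL_FEATURE_KEYS + NUMERIC_FEATURE_KEYS + [LABEL_KEY]
unknown_keys = [key for key in used_keys if key not in ORDERED_CSV_COLUMNS]
print(unknown_keys)  # []
```

In this assumed layout the 'fnlwgt' column is read from the CSV file but is not used as a feature.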

pd.read_csv(train_data_file, names = ORDERED_CSV_COLUMNS).head()

The columns of the dataset are either categorical or numeric. This dataset describes a classification problem: predicting the last column, which indicates whether the individual earns more or less than 50K per year. However, from the perspective of tf.Transform, this label is just another categorical column.

We use a pre-canned tfxio.BeamRecordCsvTFXIO to translate the CSV lines into RecordBatches. TFXIO requires two important pieces of information:

  • a TensorFlow Metadata Schema, tfmd.proto.v0.schema_pb2, that contains type and shape information about each CSV column. schema_pb2.TensorRepresentations are an optional part of the Schema; if not provided (which is the case in this example), they will be inferred from the type and shape information. One can get the Schema either by using a helper function we provide to translate from TF parsing specs (shown in this example), or by running TensorFlow Data Validation.
  • a list of column names, in the order they appear in the CSV file. Note that those names must match the feature names in the Schema.
pip install -U -q tfx_bsl
from tfx_bsl.public import tfxio
from tfx_bsl.coders.example_coder import RecordBatchToExamples

import apache_beam as beam
pipeline = beam.Pipeline()

csv_tfxio = tfxio.BeamRecordCsvTFXIO(
    physical_format='text', column_names=ORDERED_CSV_COLUMNS, schema=SCHEMA)

raw_data = (
    pipeline
    | 'ReadTrainData' >> beam.io.ReadFromText(
        train_data_file, coder=beam.coders.BytesCoder())
    | 'FixCommasTrainData' >> beam.Map(
        lambda line: line.replace(b', ', b','))
    | 'DecodeTrainData' >> csv_tfxio.BeamSource())
raw_data
<PCollection[[21]: DecodeTrainData/RawRecordToRecordBatch/CollectRecordBatchTelemetry/ProfileRecordBatches.None] at 0x7feeaa6fd5b0>
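The 'FixCommasTrainData' step above can be tried on a single line outside of Beam. The sketch below applies the same bytes replacement to a hypothetical, shortened line in the style of adult.data, where each comma is followed by a space; the helper name fix_commas is introduced here only for illustration.

```python
def fix_commas(line: bytes) -> bytes:
    """Same replacement as the 'FixCommasTrainData' step: drop the space after commas."""
    return line.replace(b', ', b',')

# Hypothetical, shortened line in the style of adult.data.
sample = b'39, State-gov, 77516, Bachelors, 13, <=50K'

fixed = fix_commas(sample)
print(fixed)             # b'39,State-gov,77516,Bachelors,13,<=50K'
print(fixed.split(b','))  # no field starts with a space
```

Without this step, every field after the first would carry a leading space when the line is split on commas.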

Note that we had to do some additional fix-ups after the CSV lines are read in. Otherwise, we could rely on the tfxio.CsvTFXIO to handle both reading the files and translating to RecordBatches:

csv_tfxio = tfxio.CsvTFXIO(train_data_file,
                           telemetry_descriptors=[], #???
                           column_names=ORDERED_CSV_COLUMNS,
                           schema=SCHEMA)

p2 = beam.Pipeline()
raw_data_2 = p2 | 'TFXIORead' >> csv_tfxio.BeamSource()

Preprocessing for this dataset is similar to the previous example, except that the preprocessing function is generated programmatically instead of specifying each column manually. In the preprocessing function below, NUMERIC_FEATURE_KEYS and CATEGORICAL_FEATURE_KEYS are lists that contain the names of the numeric and categorical columns:

NUM_OOV_BUCKETS = 1

def preprocessing_fn(inputs):
  """Preprocess input columns into transformed columns."""
  # Since we are modifying some features and leaving others unchanged, we
  # start by setting `outputs` to a copy of `inputs`.
  outputs = inputs.copy()

  # Scale numeric columns to have range [0, 1].
  for key in NUMERIC_FEATURE_KEYS:
    outputs[key] = tft.scale_to_0_1(outputs[key])

  # For all categorical columns except the label column, we generate a
  # vocabulary and apply it, converting the feature from a string to an
  # integer id. The vocabulary is also written to a file named after the key.
  for key in CATEGORICAL_FEATURE_KEYS:
    outputs[key] = tft.compute_and_apply_vocabulary(
        tf.strings.strip(inputs[key]),
        num_oov_buckets=NUM_OOV_BUCKETS,
        vocab_filename=key)

  # For the label column we provide the mapping from string to index.
  with tf.init_scope():
    # `init_scope` - Only initialize the table once.
    initializer = tf.lookup.KeyValueTensorInitializer(
        keys=['>50K', '<=50K'],
        values=tf.cast(tf.range(2), tf.int64),
        key_dtype=tf.string,
        value_dtype=tf.int64)
    table = tf.lookup.StaticHashTable(initializer, default_value=-1)

  outputs[LABEL_KEY] = table.lookup(outputs[LABEL_KEY])

  return outputs

One difference from the previous example is that the mapping from string to index is specified manually for the label column: '>50K' is mapped to 0 and '<=50K' is mapped to 1. This is done because it is useful to know which index in the trained model corresponds to which label.
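The two kinds of lookup used in this preprocessing function can be mimicked with plain Python dictionaries. This is a sketch under stated assumptions, not the TensorFlow implementation: the helper names and the sample vocabulary are hypothetical, zlib.crc32 is only a stand-in for the hash that TensorFlow uses, and the sketch assumes that out-of-vocabulary values receive ids starting at the vocabulary size.

```python
import zlib

# Fixed label mapping, mirroring the keys and values given to the static table.
LABEL_TABLE = {'>50K': 0, '<=50K': 1}

def lookup_label(label, default_value=-1):
    """Static table stand-in: unknown labels map to the default value."""
    return LABEL_TABLE.get(label, default_value)

def apply_vocabulary(value, vocab, num_oov_buckets=1):
    """Vocabulary stand-in: unseen values fall into one of the OOV buckets."""
    if value in vocab:
        return vocab[value]
    # Assumption: OOV ids occupy [len(vocab), len(vocab) + num_oov_buckets).
    bucket = zlib.crc32(value.encode('utf-8')) % num_oov_buckets
    return len(vocab) + bucket

# Hypothetical vocabulary for a categorical column.
vocab = {'Private': 0, 'State-gov': 1}

print(lookup_label('>50K'))                    # 0
print(lookup_label('unknown'))                 # -1
print(apply_vocabulary('State-gov', vocab))    # 1
print(apply_vocabulary('Never-seen', vocab))   # 2
```

With a single out-of-vocabulary bucket, every unseen value maps to the same id, which here is the size of the vocabulary.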

The raw_data variable represents a PCollection of pyarrow.RecordBatches. The accompanying tfxio.TensorAdapterConfig is given by csv_tfxio.TensorAdapterConfig(), and is inferred from SCHEMA (and ultimately, in this example, from the TF parsing specs).

The final stage writes the transformed data to disk, and has a similar form to reading the raw data. The schema used to do this is part of the output of tft_beam.AnalyzeAndTransformDataset, which infers a schema for the output data. The code to write to disk is shown below. The schema is part of the metadata, but the two can be used interchangeably in the tf.Transform API (for example, the metadata can be passed to tft.coders.ExampleProtoCoder). Be aware that this writes to a different format: instead of textio.WriteToText, the code uses Beam's built-in support for the TFRecord format and encodes the data as Example protos. This is a better format to use for training. The path ending in 'transformed.tfrecord' provides the base filename for the individual shards that are written.

raw_dataset = (raw_data, csv_tfxio.TensorAdapterConfig())
working_dir = tempfile.mkdtemp()
with tft_beam.Context(temp_dir=working_dir):
  transformed_dataset, transform_fn = (
      raw_dataset | tft_beam.AnalyzeAndTransformDataset(
          preprocessing_fn, output_record_batches=True))
output_dir = tempfile.mkdtemp()
transformed_data, _ = transformed_dataset

_ = (
    transformed_data
    | 'EncodeTrainData' >>
    beam.FlatMapTuple(lambda batch, _: RecordBatchToExamples(batch))
    | 'WriteTrainData' >> beam.io.WriteToTFRecord(
        os.path.join(output_dir, 'transformed.tfrecord')))

In addition to the training data, transform_fn is also written out with the metadata:

_ = (
    transform_fn
    | 'WriteTransformFn' >> tft_beam.WriteTransformFn(output_dir))

Run the entire Beam pipeline with pipeline.run().wait_until_finish(). Up until this point, the Beam pipeline represents a deferred, distributed computation. It provides instructions for what will be done, but the instructions have not been executed. This final call executes the specified pipeline.

result = pipeline.run().wait_until_finish()
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphiyrst4f/tftransform_tmp/c633cd0eb0c14a2bba2bc6f7ba556ce3/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphiyrst4f/tftransform_tmp/c633cd0eb0c14a2bba2bc6f7ba556ce3/assets
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphiyrst4f/tftransform_tmp/9080e8c73e2443fea34d6505feed4129/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmphiyrst4f/tftransform_tmp/9080e8c73e2443fea34d6505feed4129/assets
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:struct2tensor is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_decision_forests is not available.
INFO:tensorflow:tensorflow_text is not available.
INFO:tensorflow:tensorflow_text is not available.
WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.

After running the pipeline, the output directory contains two artifacts:

  • The transformed data, and the metadata describing it.
  • The tf.saved_model containing the resulting preprocessing_fn.
ls {output_dir}
transform_fn  transformed.tfrecord-00000-of-00001  transformed_metadata

To see how to use these artifacts, refer to the Advanced preprocessing tutorial.