
tft.make_and_track_object

Keeps track of the object created by invoking trackable_factory_callable.

This API is only for use when Transform APIs are run with TF2 behaviors enabled and tft_beam.Context.force_tf_compat_v1 is set to False.

Use this API to track TF Trackable objects created in the preprocessing_fn, such as TF Hub modules or tf.data.Dataset instances. Tracking ensures these objects are serialized correctly when the transform function is exported to a SavedModel.
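The snippet below illustrates, in plain TensorFlow rather than tf.Transform, what "tracking" buys you at export time. This is a minimal sketch under stated assumptions: the `Holder` class and its `offset` attribute are hypothetical names invented for illustration. A tf.Variable assigned as an attribute of a tf.Module is tracked, so tf.saved_model.save serializes it and the restored function can still use it.

```python
import tempfile

import tensorflow as tf


class Holder(tf.Module):
  """Hypothetical module: Trackable attributes are saved along with it."""

  def __init__(self):
    super().__init__()
    # Assigning a tf.Variable as an attribute makes the module track it.
    self.offset = tf.Variable(10, dtype=tf.int64)

  @tf.function(input_signature=[tf.TensorSpec([], tf.int64)])
  def __call__(self, x):
    return x + self.offset


export_dir = tempfile.mkdtemp()
tf.saved_model.save(Holder(), export_dir)

# The tracked variable was serialized, so the restored function still works.
restored = tf.saved_model.load(export_dir)
result = int(restored(tf.constant(1, tf.int64)))
print(result)  # 11
```

An object created inside a traced function but never attached to anything the exporter can reach would not be saved this way; make_and_track_object exists to attach such objects for you inside a preprocessing_fn.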

Args:
  trackable_factory_callable: A callable that creates and returns a Trackable object.

Example:

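import tempfile

import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam

def preprocessing_fn(inputs):
  # The factory is invoked under tf.init_scope, so the dataset is created
  # eagerly and tracked for export to the SavedModel.
  dataset = tft.make_and_track_object(
      lambda: tf.data.Dataset.from_tensor_slices([1, 2, 3]))
  # Iterate the dataset eagerly; the value read becomes a graph constant.
  with tf.init_scope():
    dataset_list = list(dataset.as_numpy_iterator())
  return {'x_0': dataset_list[0] + inputs['x']}

# In-memory input data and its schema.
raw_data = [dict(x=1), dict(x=2), dict(x=3)]
feature_spec = dict(x=tf.io.FixedLenFeature([], tf.int64))
raw_data_metadata = tft.tf_metadata.dataset_metadata.DatasetMetadata(
    tft.tf_metadata.schema_utils.schema_from_feature_spec(feature_spec))

# make_and_track_object requires force_tf_compat_v1=False.
with tft_beam.Context(temp_dir=tempfile.mkdtemp(),
                      force_tf_compat_v1=False):
  transformed_dataset, transform_fn = (
      (raw_data, raw_data_metadata)
      | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
transformed_data, transformed_metadata = transformed_dataset
print(transformed_data)
# [{'x_0': 2}, {'x_0': 3}, {'x_0': 4}]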

Returns:
  The object returned when trackable_factory_callable is invoked. The object creation is lifted out to the eager context using tf.init_scope.
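To make the tf.init_scope behavior concrete, here is a minimal sketch in plain TensorFlow. The helper `call_factory_eagerly` and the `observed` dict are hypothetical names introduced for illustration; this is not tf.Transform's actual implementation. It shows that code running during tracing of a tf.function is not eager, while a factory invoked under tf.init_scope runs in the outer eager context, which is why the returned dataset can be iterated with as_numpy_iterator.

```python
import tensorflow as tf


def call_factory_eagerly(factory):
  """Hypothetical helper (not part of tf.Transform): invokes `factory`
  under tf.init_scope so the object is created in the outer eager context."""
  with tf.init_scope():
    created_eagerly = tf.executing_eagerly()
    obj = factory()
  return obj, created_eagerly


observed = {}


@tf.function
def traced_fn(x):
  # Python side effects like this run once, while the function is traced.
  observed['inside_graph_eager'] = tf.executing_eagerly()
  dataset, observed['factory_eager'] = call_factory_eagerly(
      lambda: tf.data.Dataset.from_tensor_slices([1, 2, 3]))
  # Reading the dataset also needs the eager context.
  with tf.init_scope():
    first = next(iter(dataset.as_numpy_iterator()))
  return x + first


result = traced_fn(tf.constant(1))
print(int(result))  # 2
print(observed)
```

Taking a callable rather than an already-built object lets the creation itself be deferred into that lifted scope.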