TensorFlow 2.x in TFX

TensorFlow 2.0 was released in 2019, with tight integration of Keras, eager execution by default, and Pythonic function execution, among other new features and improvements.

This guide provides a comprehensive technical overview of TF 2.x in TFX.

Which version to use?

TFX is compatible with TensorFlow 2.x, and the high-level APIs that existed in TensorFlow 1.x (particularly Estimators) continue to work.

Start new projects in TensorFlow 2.x

Since TensorFlow 2.x retains the high-level capabilities of TensorFlow 1.x, there is no advantage to using the older version on new projects, even if you don't plan to use the new features.

Therefore, if you are starting a new TFX project, we recommend that you use TensorFlow 2.x. You may want to update your code later as full support for Keras and other new features becomes available, and the scope of those changes will be much smaller if you start with TensorFlow 2.x than if you upgrade from TensorFlow 1.x in the future.

Converting existing projects to TensorFlow 2.x

Code written for TensorFlow 1.x is largely compatible with TensorFlow 2.x and will continue to work in TFX.

However, if you'd like to take advantage of improvements and new features as they become available in TF 2.x, you can follow the instructions for migrating to TF 2.x.

Estimator

The Estimator API has been retained in TensorFlow 2.x, but is not the focus of new features and development. Code written in TensorFlow 1.x or 2.x using Estimators will continue to work as expected in TFX.

Here is an end-to-end TFX example using a pure Estimator: Taxi example (Estimator)

Keras with model_to_estimator

Keras models can be wrapped with the tf.keras.estimator.model_to_estimator function, which allows them to work as if they were Estimators. To use this:

  1. Build a Keras model.
  2. Pass the compiled model into model_to_estimator.
  3. Use the result of model_to_estimator in Trainer, the way you would typically use an Estimator.
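import tensorflow as tf

# Build a Keras model.
def _keras_model_builder():
  """Creates and compiles a Keras model."""
  # ... define `inputs` and `output` with Keras layers ...

  model = tf.keras.Model(inputs=inputs, outputs=output)
  # model_to_estimator expects a compiled model; the optimizer, loss, and
  # metrics below are illustrative, so choose the ones that fit your task.
  model.compile(
      optimizer='adam',
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'])

  return model


# Write a typical trainer function
def trainer_fn(trainer_fn_args, schema):
  """Build the estimator, using model_to_estimator."""
  # ... build `run_config` (a tf.estimator.RunConfig) and the input functions ...

  # Model to estimator
  estimator = tf.keras.estimator.model_to_estimator(
      keras_model=_keras_model_builder(), config=run_config)

  return {
      'estimator': estimator,
      # ... plus 'train_spec', 'eval_spec', and 'eval_input_receiver_fn' ...
  }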

Other than the user module file of Trainer, the rest of the pipeline remains unchanged. Here is an end-to-end TFX example using Keras with model_to_estimator: Iris example (model_to_estimator)

Native Keras (i.e. Keras without model_to_estimator)

Examples and Colab

Here are several examples with native Keras:

We also have a per-component Keras Colab.

TFX Components

The following sections explain how related TFX components support native Keras.

Transform

Transform currently has experimental support for Keras models.

The Transform component itself can be used for native Keras without change. The preprocessing_fn definition remains the same, using TensorFlow and tf.Transform ops.

The serving function and eval function change for native Keras; the Trainer and Evaluator sections below cover the details.

Trainer

To configure native Keras, set the GenericExecutor on the Trainer component in place of the default Estimator-based executor. For details, see the Trainer component documentation.

Keras Module file with Transform

The training module file must contain a run_fn, which is called by the GenericExecutor. A typical Keras run_fn looks like this:

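import tensorflow as tf
import tensorflow_transform as tft
from tfx.components.trainer.executor import TrainerFnArgs

def run_fn(fn_args: TrainerFnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

  # Train and eval files contain transformed examples.
  # _input_fn reads the dataset based on the transformed feature_spec from tft.
  train_dataset = _input_fn(fn_args.train_files, tf_transform_output, 40)
  eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output, 40)

  model = _build_keras_model()  # Defined elsewhere in the module file.
  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  # Export with a serving signature that accepts serialized tf.Examples.
  signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(model, tf_transform_output).get_concrete_function(
              tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')),
  }
  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)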
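The run_fn above calls an _input_fn helper that this guide does not define. The sketch below is one minimal way to write it, assuming the examples are GZIP-compressed TFRecord files of tf.Example protos (the format ExampleGen writes). To keep it runnable without tf.Transform, this hypothetical version takes a feature spec and label key directly; inside run_fn you would derive them from tf_transform_output (for example, tf_transform_output.transformed_feature_spec().copy()). The toy feature names 'x' and 'label' are assumptions for illustration.

```python
import os
import tempfile

import tensorflow as tf


def _gzip_reader_fn(filenames):
  """Reads GZIP-compressed TFRecord files."""
  return tf.data.TFRecordDataset(filenames, compression_type='GZIP')


def _input_fn(file_pattern, feature_spec, label_key, batch_size=40):
  """Returns a dataset of (features_dict, label) batches.

  Hypothetical simplified signature: with Transform, feature_spec would come
  from tf_transform_output.transformed_feature_spec().copy().
  num_epochs is left at its default (None), so the dataset repeats forever.
  """
  return tf.data.experimental.make_batched_features_dataset(
      file_pattern=file_pattern,
      batch_size=batch_size,
      features=feature_spec,
      reader=_gzip_reader_fn,
      label_key=label_key)


# Usage with toy data: two examples with a float feature 'x' and an int 'label'.
def _make_example(x, label):
  return tf.train.Example(features=tf.train.Features(feature={
      'x': tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
      'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
  }))


path = os.path.join(tempfile.mkdtemp(), 'examples.tfrecord.gz')
with tf.io.TFRecordWriter(path, options='GZIP') as writer:
  for x, label in [(1.0, 0), (2.0, 1)]:
    writer.write(_make_example(x, label).SerializeToString())

feature_spec = {
    'x': tf.io.FixedLenFeature([1], tf.float32),
    'label': tf.io.FixedLenFeature([1], tf.int64),
}
dataset = _input_fn(path, feature_spec, label_key='label', batch_size=2)
# Records are shuffled by default, so the order within a batch can vary.
features, labels = next(iter(dataset))
```

Because the dataset repeats indefinitely, run_fn has to bound training and validation with steps_per_epoch and validation_steps, as it does above. The label is popped out of the features dict, so Keras receives (features, label) pairs.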

In the run_fn above, a serving signature is needed when exporting the trained model so that the model can accept raw examples for prediction. A typical serving function looks like this:

import tensorflow as tf


def _get_serve_tf_examples_fn(model, tf_transform_output):
  """Returns a function that parses a serialized tf.Example."""

  # The layer is added as an attribute to the model in order to make sure that
  # the model assets are handled correctly when exporting.
  model.tft_layer = tf_transform_output.transform_features_layer()

  # tf.function is required so that run_fn can call get_concrete_function().
  @tf.function
  def serve_tf_examples_fn(serialized_tf_examples):
    """Returns the output to be used in the serving signature."""
    feature_spec = tf_transform_output.raw_feature_spec()
    # Serving requests carry no label; _LABEL_KEY is defined in the module file.
    feature_spec.pop(_LABEL_KEY)
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)

    transformed_features = model.tft_layer(parsed_features)

    return model(transformed_features)

  return serve_tf_examples_fn

In the serving function above, the tf.Transform transformations are applied to the raw data at inference time using the tft.TransformFeaturesLayer layer. The _serving_input_receiver_fn that Estimators required is no longer needed with Keras.

Keras Module file without Transform

This is similar to the module file shown above, but without the transformations:

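import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx.components.trainer.executor import TrainerFnArgs
from tfx.utils import io_utils

# _get_raw_feature_spec, _LABEL_KEY, _input_fn, and _build_keras_model are
# defined elsewhere in the module file.


def _get_serve_tf_examples_fn(model, schema):

  # tf.function is required so that run_fn can call get_concrete_function().
  @tf.function
  def serve_tf_examples_fn(serialized_tf_examples):
    feature_spec = _get_raw_feature_spec(schema)
    # Serving requests carry no label, so drop it before parsing.
    feature_spec.pop(_LABEL_KEY)
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
    return model(parsed_features)

  return serve_tf_examples_fn


def run_fn(fn_args: TrainerFnArgs):
  schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema())

  # Train and eval files contain raw examples.
  # _input_fn reads the dataset based on the raw feature_spec from the schema.
  train_dataset = _input_fn(fn_args.train_files, schema, 40)
  eval_dataset = _input_fn(fn_args.eval_files, schema, 40)

  model = _build_keras_model()
  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(model, schema).get_concrete_function(
              tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')),
  }
  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)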
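To see the serving-signature pattern work outside a pipeline, here is a minimal sketch with a toy two-feature model. RAW_FEATURE_SPEC, _build_keras_model, and _make_serialized_example are hypothetical stand-ins for the module file's schema-derived feature spec and model; the Dense layer is initialized with ones and zeros so the output is simply x1 + x2 and easy to check. In the module file, the resulting concrete function is what goes into signatures['serving_default'] before model.save.

```python
import tensorflow as tf

# Hypothetical raw feature spec standing in for _get_raw_feature_spec(schema);
# the label is deliberately absent because serving requests do not carry it.
RAW_FEATURE_SPEC = {
    'x1': tf.io.FixedLenFeature([1], tf.float32),
    'x2': tf.io.FixedLenFeature([1], tf.float32),
}


def _build_keras_model():
  """Toy model computing x1 + x2 (ones kernel, zero bias)."""
  inputs = {name: tf.keras.Input(shape=(1,), name=name)
            for name in RAW_FEATURE_SPEC}
  concat = tf.keras.layers.Concatenate()([inputs['x1'], inputs['x2']])
  output = tf.keras.layers.Dense(
      1, kernel_initializer='ones', bias_initializer='zeros')(concat)
  return tf.keras.Model(inputs=inputs, outputs=output)


def _get_serve_tf_examples_fn(model, feature_spec):
  """Same shape as the module-file version: parse, then call the model."""

  @tf.function
  def serve_tf_examples_fn(serialized_tf_examples):
    parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
    return model(parsed_features)

  return serve_tf_examples_fn


def _make_serialized_example(x1, x2):
  """Builds the kind of serialized tf.Example a serving client would send."""
  return tf.train.Example(features=tf.train.Features(feature={
      'x1': tf.train.Feature(float_list=tf.train.FloatList(value=[x1])),
      'x2': tf.train.Feature(float_list=tf.train.FloatList(value=[x2])),
  })).SerializeToString()


model = _build_keras_model()
# The concrete function accepts a batch of serialized examples named 'examples'.
serving_fn = _get_serve_tf_examples_fn(
    model, RAW_FEATURE_SPEC).get_concrete_function(
        tf.TensorSpec(shape=[None], dtype=tf.string, name='examples'))

predictions = serving_fn(tf.constant([
    _make_serialized_example(1.0, 2.0),
    _make_serialized_example(0.5, 0.25),
]))
```

The key design point is that parsing happens inside the exported function, so clients send raw serialized tf.Example strings and never need to know the model's input layout.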

Distribution Strategy

At this time, TFX only supports single-worker strategies (e.g., MirroredStrategy, OneDeviceStrategy).

To use a distribution strategy, create an appropriate tf.distribute.Strategy and move the creation and compiling of the Keras model inside a strategy scope.

For example, replace the model = _build_keras_model() line above with:

  mirrored_strategy = tf.distribute.MirroredStrategy()
  with mirrored_strategy.scope():
    model = _build_keras_model()

  # Rest of the code can be unchanged.
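A self-contained sketch of the same pattern, assuming a toy regression model in place of the module file's _build_keras_model (which is hypothetical here). Both model creation and compile() happen inside strategy.scope(), so the model variables and optimizer state are created under the strategy; fit() itself can be called outside the scope.

```python
import numpy as np
import tensorflow as tf


def _build_keras_model():
  """Hypothetical stand-in for the module file's builder: build and compile."""
  model = tf.keras.Sequential([
      tf.keras.Input(shape=(2,)),
      tf.keras.layers.Dense(1),
  ])
  model.compile(optimizer='sgd', loss='mse')
  return model


# Uses all visible GPUs; with none visible, it falls back to the CPU.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Variables and optimizer state are created under the strategy.
  model = _build_keras_model()

# Toy training data: the target is the sum of the two inputs.
x = np.random.rand(32, 2).astype('float32')
y = x.sum(axis=1, keepdims=True)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)

# Keras splits each batch across the strategy's replicas during fit().
history = model.fit(dataset, epochs=1, verbose=0)
```

On a machine without GPUs, num_replicas_in_sync is 1; with N visible GPUs it is N, and each global batch of 8 is split across them.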

To verify the device (CPU/GPU) used by MirroredStrategy, enable INFO-level TensorFlow logging:

import logging
logging.getLogger("tensorflow").setLevel(logging.INFO)

You should then see "Using MirroredStrategy with devices (...)" in the log.

Evaluator

In TFMA v0.2x, ModelValidator and Evaluator have been combined into a single new Evaluator component. The new Evaluator can both evaluate a single model and validate the current model against previous models. With this change, the Pusher component now consumes a blessing result from the Evaluator instead of the ModelValidator.

The new Evaluator supports Keras models as well as Estimator models. The _eval_input_receiver_fn and the separate eval SavedModel that were previously required are no longer needed with Keras, since the Evaluator now uses the same SavedModel that is used for serving.

See Evaluator for more information.