Using Other ML Frameworks in TFX

TFX as a platform is framework-neutral and can be used with ML frameworks other than TensorFlow, e.g., JAX and scikit-learn.

For model developers, this means they do not need to rewrite model code implemented in another ML framework; they can reuse the bulk of their training code as-is in TFX and still benefit from the other capabilities that TFX and the rest of the TensorFlow ecosystem offer.

The TFX pipeline SDK and most TFX modules, e.g., the pipeline orchestrator, have no direct dependency on TensorFlow, but some aspects, such as data formats, are oriented towards TensorFlow. With some consideration of the needs of a particular modeling framework, a TFX pipeline can be used to train models in any other Python-based ML framework, including scikit-learn, XGBoost, and PyTorch, among others. Considerations for using the standard TFX components with other frameworks include:

  • ExampleGen outputs tf.train.Example records in TFRecord files. This is a generic representation of training data; downstream components use TFXIO to read it into memory as Arrow RecordBatches, which can be further converted into tf.data.Dataset, Tensors, or other formats (see the first sketch after this list). Payload/file formats other than tf.train.Example/TFRecord are being considered, but for TFXIO users they should be a black box.
  • Transform can be used to generate transformed training examples no matter which framework is used for training. However, if the model format is not SavedModel, users will not be able to embed the transform graph into the model. In that case, model prediction must take transformed features instead of raw features, and users can run Transform as a preprocessing step before calling model prediction at serving time (see the second sketch after this list).
  • Trainer supports GenericTraining, so users can train their models using any ML framework (a pipeline-level sketch follows this list, and a full scikit-learn example appears at the end of this page).
  • Evaluator by default only supports SavedModel, but users can provide a UDF that generates predictions for model evaluation.

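The following is a minimal sketch, not part of TFX itself, of reading the TFRecord files that ExampleGen produces into numpy arrays so that any Python framework can consume them. The feature spec, feature names, and file pattern are hypothetical; the GZIP compression matches ExampleGen's default output.

```python
import numpy as np
import tensorflow as tf


def load_examples_as_numpy(file_patterns, feature_spec):
    """Parses tf.train.Example records from TFRecord files into numpy arrays."""
    files = [f for pattern in file_patterns for f in tf.io.gfile.glob(pattern)]
    dataset = tf.data.TFRecordDataset(files, compression_type='GZIP')
    parsed = dataset.batch(1024).map(
        lambda serialized: tf.io.parse_example(serialized, feature_spec))
    columns = {name: [] for name in feature_spec}
    for batch in parsed:
        for name, tensor in batch.items():
            columns[name].append(tensor.numpy())
    return {name: np.concatenate(chunks) for name, chunks in columns.items()}


# Hypothetical usage: two dense features and a label from a train split.
feature_spec = {
    'age': tf.io.FixedLenFeature([], tf.float32),
    'income': tf.io.FixedLenFeature([], tf.float32),
    'label': tf.io.FixedLenFeature([], tf.int64),
}
data = load_examples_as_numpy(
    ['pipeline_root/CsvExampleGen/examples/1/Split-train/*'], feature_spec)
```
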
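If the trained model is not a SavedModel, Transform's output can still be applied at serving time as an explicit preprocessing step. Below is a minimal sketch assuming a `transform_output_dir` produced by the Transform component and a `framework_model` object that exposes a `predict` method taking a feature dict; both names are placeholders, and dense transformed features are assumed.

```python
import tensorflow_transform as tft


def make_serving_fn(transform_output_dir, framework_model):
    """Wraps a non-TensorFlow model so that it can be called on raw features."""
    tft_output = tft.TFTransformOutput(transform_output_dir)
    transform_layer = tft_output.transform_features_layer()

    def serve(raw_features):
        # Apply the transform graph to the raw features first, then hand the
        # transformed (dense) features to the framework-specific model.
        transformed = transform_layer(raw_features)
        return framework_model.predict(
            {name: tensor.numpy() for name, tensor in transformed.items()})

    return serve
```
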
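On the pipeline side, generic training only requires pointing Trainer at a module file that defines a `run_fn`. The sketch below uses a hypothetical module path and step counts, and assumes an `example_gen` component defined elsewhere in the pipeline.

```python
from tfx import v1 as tfx

trainer = tfx.components.Trainer(
    module_file='path/to/trainer_module.py',    # hypothetical module defining run_fn
    examples=example_gen.outputs['examples'],   # example_gen defined elsewhere
    train_args=tfx.proto.TrainArgs(num_steps=1000),
    eval_args=tfx.proto.EvalArgs(num_steps=100),
)
```
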
Training a model in a non-Python-based framework requires isolating a custom training component in a Docker container, as part of a pipeline that runs in a containerized environment such as Kubernetes; a sketch of such a component follows.
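
One way to do this is with TFX's experimental container component API, which wraps an arbitrary image as a pipeline step. In the sketch below the image name and command line are hypothetical; only the `create_container_component` call and the placeholder classes are TFX APIs.

```python
from tfx.dsl.component.experimental import container_component, placeholders
from tfx.types import standard_artifacts

# Hypothetical trainer image; its entrypoint reads examples from one URI and
# writes the trained model to another.
ContainerTrainer = container_component.create_container_component(
    name='ContainerTrainer',
    image='gcr.io/my-project/non-python-trainer:latest',
    inputs={'examples': standard_artifacts.Examples},
    outputs={'model': standard_artifacts.Model},
    command=[
        'trainer_binary',
        '--examples-uri', placeholders.InputUriPlaceholder('examples'),
        '--model-uri', placeholders.OutputUriPlaceholder('model'),
    ],
)

# Used in a pipeline like any other component (example_gen defined elsewhere).
trainer = ContainerTrainer(examples=example_gen.outputs['examples'])
```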

JAX

JAX is Autograd and XLA, brought together for high-performance machine learning research. Flax is a neural network library and ecosystem for JAX, designed for flexibility.

With jax2tf, trained JAX/Flax models can be converted into the SavedModel format, which can be used seamlessly in TFX with generic training and model evaluation. For details, check this example.
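
The sketch below shows roughly how such a conversion can look, assuming a trained Flax apply function `flax_apply_fn` and its `params` from your own training code; the input signature is hypothetical and jax2tf options are omitted for brevity.

```python
import tensorflow as tf
from jax.experimental import jax2tf


def export_saved_model(flax_apply_fn, params, export_dir):
    """Exports a trained JAX/Flax model as a TensorFlow SavedModel."""
    # Wrap the converted JAX prediction function as a TensorFlow function.
    tf_predict = tf.function(
        jax2tf.convert(lambda inputs: flax_apply_fn({'params': params}, inputs)),
        input_signature=[
            # Hypothetical input shape; adjust to the model's features.
            tf.TensorSpec(shape=[None, 4], dtype=tf.float32, name='inputs')],
        autograph=False)

    module = tf.Module()
    module.predict = tf_predict
    tf.saved_model.save(
        module, export_dir,
        signatures={'serving_default': tf_predict.get_concrete_function()})
```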

scikit-learn

Scikit-learn is a machine learning library for the Python programming language. An end-to-end example with customized training and evaluation is available in TFX-Addons.
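
A minimal `run_fn` sketch in that spirit is shown below; the feature spec, feature names, and pickle-based persistence are illustrative and may differ from the TFX-Addons example.

```python
import os
import pickle
import numpy as np
import tensorflow as tf
from sklearn.linear_model import LogisticRegression

# Hypothetical feature spec for this sketch.
_FEATURE_SPEC = {
    'age': tf.io.FixedLenFeature([], tf.float32),
    'income': tf.io.FixedLenFeature([], tf.float32),
    'label': tf.io.FixedLenFeature([], tf.int64),
}


def _as_numpy(file_patterns):
    """Reads ExampleGen's gzipped TFRecord output into feature/label arrays."""
    files = [f for pattern in file_patterns for f in tf.io.gfile.glob(pattern)]
    dataset = tf.data.TFRecordDataset(files, compression_type='GZIP')
    rows = [tf.io.parse_single_example(r, _FEATURE_SPEC) for r in dataset]
    features = np.array(
        [[row['age'].numpy(), row['income'].numpy()] for row in rows])
    labels = np.array([row['label'].numpy() for row in rows])
    return features, labels


def run_fn(fn_args):
    """Entry point invoked by the generic Trainer executor."""
    train_x, train_y = _as_numpy(fn_args.train_files)
    model = LogisticRegression(max_iter=1000)
    model.fit(train_x, train_y)

    # The model is not a SavedModel, so downstream components (Evaluator,
    # Pusher, serving) must be customized to load this format.
    tf.io.gfile.makedirs(fn_args.serving_model_dir)
    with tf.io.gfile.GFile(
            os.path.join(fn_args.serving_model_dir, 'model.pkl'), 'wb') as f:
        pickle.dump(model, f)
```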