Module: tf.data.experimental.service

API for using the tf.data service.

This module contains:

  1. tf.data server implementations for running the tf.data service.
  2. APIs for registering datasets with the tf.data service and reading from the registered datasets.

The tf.data service provides the following benefits:

  • Horizontal scaling of tf.data input pipeline processing to solve input bottlenecks.
  • Data coordination for distributed training. Coordinated reads enable all replicas to train on similar-length examples across each global training step, improving step times in synchronous training.
  • Dynamic balancing of data across training replicas.

The following example brings up a dispatcher and a single worker in the same process, then distributes a simple dataset through the service:

dispatcher = tf.data.experimental.service.DispatchServer()
dispatcher_address = dispatcher.target.split("://")[1]
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address=dispatcher_address))
dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(tf.data.experimental.service.distribute(
    processing_mode="parallel_epochs", service=dispatcher.target))
print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Setup

This section goes over how to set up the tf.data service.

Run tf.data servers

The tf.data service consists of one dispatch server and n worker servers. tf.data servers should be brought up alongside your training jobs, then brought down when the jobs are finished. Use tf.data.experimental.service.DispatchServer to start a dispatch server, and tf.data.experimental.service.WorkerServer to start worker servers. Servers can be run in the same process for testing purposes, or scaled up on separate machines.
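
As a minimal sketch of what the server binaries might contain (the port is arbitrary and the single-process layout is for illustration only; in practice the dispatcher and each worker typically run in separate binaries):

import tensorflow as tf

# Dispatcher process: serve on a well-known port (5050 is an arbitrary
# choice for this sketch).
dispatcher = tf.data.experimental.service.DispatchServer(
    tf.data.experimental.service.DispatcherConfig(port=5050))

# Worker process: connect to the dispatcher's address.
worker = tf.data.experimental.service.WorkerServer(
    tf.data.experimental.service.WorkerConfig(
        dispatcher_address="localhost:5050"))

# In a dedicated server binary, block until the server shuts down.
dispatcher.join()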

See https://github.com/tensorflow/ecosystem/tree/master/data_service for an example of using Google Kubernetes Engine (GKE) to manage the tf.data service. Note that the server implementation in tf_std_data_server.py is not GKE-specific, and can be used to run the tf.data service in other contexts.

Custom ops

If your dataset uses custom ops, these ops need to be made available to tf.data servers by calling tf.load_op_library from the dispatcher and worker processes at startup.
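
For example, a minimal sketch of loading such a library at server startup; the library path here is hypothetical:

import tensorflow as tf

# Load the shared library containing the custom ops so the server can
# execute graphs that reference them. The path is a placeholder.
tf.load_op_library("/path/to/my_custom_ops.so")

dispatcher = tf.data.experimental.service.DispatchServer()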

Usage

Users interact with the tf.data service by programmatically registering their datasets with it, then creating datasets that read from the registered data. The register_dataset function registers a dataset, and the from_dataset_id function creates a new dataset that reads from a registered dataset. The distribute function wraps register_dataset and from_dataset_id into a single convenient transformation that registers its input dataset and then reads from it, enabling the tf.data service to be used with a one-line code change. However, distribute assumes that the dataset is created and consumed by the same entity, which is not always valid or desirable. In particular, in scenarios such as distributed training, it can be preferable to decouple the creation and consumption of the dataset (via register_dataset and from_dataset_id, respectively) to avoid having to create the dataset on each of the training workers.
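
For example, here is a minimal sketch of the decoupled flow, reusing the in-process dispatcher from the example above; in a real deployment the registering and consuming code would run in different processes:

dataset = tf.data.Dataset.range(10)

# Producer side: register the dataset with the service once.
dataset_id = tf.data.experimental.service.register_dataset(
    service=dispatcher.target, dataset=dataset)

# Consumer side: read from the registered dataset by its id.
consumed = tf.data.experimental.service.from_dataset_id(
    processing_mode="parallel_epochs",
    service=dispatcher.target,
    dataset_id=dataset_id,
    element_spec=dataset.element_spec)
print(list(consumed.as_numpy_iterator()))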

Example

distribute

To use the distribute transformation, apply it after the prefix of your input pipeline that you would like to have executed by the tf.data service (typically at the end of the pipeline).

dataset = ...  # Define your dataset here.
# Move dataset processing from the local machine to the tf.data service.
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs",
        service=FLAGS.tf_data_service_address))
# Any transformations added after `distribute` run locally on the consumer.
dataset = dataset.prefetch(1)