Gets the current worker index and the total number of workers.

Call this method from a worker, inside a tf.function that runs in the worker context. In practice, call it from within the dataset_fn(context) function.

Currently, context is ignored as it is not populated by the ParameterServerStrategyV2. However, context should still be provided for compatibility with future API changes.

paths = [...list of dataset files]

def dataset_fn(context: Optional[distribute_lib.InputContext] = None):
  # Distributed dataset_fn.

  # Dataset of file paths; each worker keeps only its own shard below.
  ds_path = tf.data.Dataset.from_tensor_slices(paths)

  if context is not None:
    current_worker = keras.get_worker_idx_and_num_workers(context)
    assert current_worker.num_workers > 1, "Not distributed dataset reading"
    ds_path = ds_path.shard(
        num_shards=current_worker.num_workers,
        index=current_worker.worker_index)

  # Load the examples from "ds_path", for example with
  # tf.data.experimental.CsvDataset.

  def read_csv_file(path):
    csv_columns = [ ... ]
    return tf.data.experimental.CsvDataset(path, csv_columns, header=False)

  ds_columns = ds_path.interleave(read_csv_file)

  def extract_label(*columns):
    # The last column is the label; all preceding columns are features.
    return columns[0:-1], columns[-1]

  return ds_columns.map(extract_label)


context: Distribution strategy input context.

Return the index of the worker (tensor) and the total number of workers (integer).
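
To illustrate what the sharding step in the example does to the file list, here is a minimal pure-Python sketch. It is not TF-DF code. It assumes the documented behavior of tf.data.Dataset.shard, which keeps the elements whose position modulo num_shards equals index. The helper shard_paths and the file names are hypothetical, and the worker index is a plain integer here, whereas the real function returns it as a tensor.

```python
from typing import List, Sequence


def shard_paths(paths: Sequence[str], num_shards: int, index: int) -> List[str]:
    """Hypothetical helper mimicking Dataset.shard on a plain list of paths."""
    if not 0 <= index < num_shards:
        raise ValueError("index must satisfy 0 <= index < num_shards")
    # Keep every num_shards-th path, starting at position `index`.
    return [p for i, p in enumerate(paths) if i % num_shards == index]


# Hypothetical file names, for illustration only.
example_paths = [f"part-{i}.csv" for i in range(5)]
num_workers = 2

# Each worker reads a disjoint subset; together they cover every file once.
shards = [shard_paths(example_paths, num_workers, w) for w in range(num_workers)]
```

With five files and two workers, worker 0 receives three files and worker 1 receives two. A file list shorter than the number of workers would leave some workers without any data to read.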