tfdf.keras.get_worker_idx_and_num_workers

Gets the current worker index and the total number of workers.

This method should be called by a worker, inside a tf.function that runs in the worker context. In practice, call it from the dataset_fn(context) method.

Currently, context is ignored as it is not populated by the ParameterServerStrategyV2. However, context should still be provided for compatibility with future API changes.

    paths = [...list of dataset files]

    def dataset_fn(context: Optional[distribute_lib.InputContext] = None):
      # Distributed dataset_fn.

      ds_path = tf.data.Dataset.from_tensor_slices(paths)

      if context is not None:
        current_worker = keras.get_worker_idx_and_num_workers(context)
        assert current_worker.num_workers > 1, "Not distributed dataset reading"
        ds_path = ds_path.shard(
            num_shards=current_worker.num_workers,
            index=current_worker.worker_index)

      # Load the examples from "ds_path", for example with
      # tf.data.experimental.CsvDataset.

      def read_csv_file(path):
        csv_columns = [ ... ]
        return tf.data.experimental.CsvDataset(path, csv_columns, header=False)

      ds_columns = ds_path.interleave(read_csv_file)

      def extract_label(*columns):
        return columns[0:-1], columns[-1]

      return ds_columns.map(extract_label).batch(batch_size)

context: Distribution strategy input context.

Returns the index of the worker (a tensor) and the total number of workers (an integer).
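To make the effect of the shard(num_shards=..., index=...) call in the example more concrete, here is a minimal pure-Python sketch of the same partitioning rule: element i of the path list goes to the worker for which i % num_workers == worker_index. The helper shard_paths and the file names are hypothetical and exist only for illustration. In the real API the worker index is a tensor, not a Python integer, and the sharding is done by tf.data.

```python
from typing import List, Sequence


def shard_paths(paths: Sequence[str], num_workers: int, worker_index: int) -> List[str]:
    """Hypothetical helper mimicking the selection rule of Dataset.shard.

    Keeps the paths whose position i satisfies i % num_workers == worker_index.
    """
    if not 0 <= worker_index < num_workers:
        raise ValueError("worker_index must be in [0, num_workers)")
    return [p for i, p in enumerate(paths) if i % num_workers == worker_index]


# Made-up file names, standing in for the list of dataset files.
paths = [f"part-{i}.csv" for i in range(5)]

# One shard per worker; together the shards cover every file exactly once.
shards = [shard_paths(paths, num_workers=2, worker_index=w) for w in range(2)]
```

With five files and two workers, the shards are disjoint but not the same size, so workers may read different numbers of examples when the file count is not a multiple of the worker count.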