tff.simulation.baselines.BaselineTaskDatasets

A convenience class for a task's data and preprocessing logic.

Args
train_data A tff.simulation.datasets.ClientData for training.
test_data A tff.simulation.datasets.ClientData or a tf.data.Dataset for computing test metrics.
validation_data An optional tff.simulation.datasets.ClientData or a tf.data.Dataset for computing validation metrics.
train_preprocess_fn An optional callable accepting and returning a tf.data.Dataset, used to perform dataset preprocessing for training. If None, train preprocessing is the identity map.
eval_preprocess_fn An optional callable accepting and returning a tf.data.Dataset, used to perform evaluation (e.g. validation, testing) preprocessing. If None, evaluation preprocessing is the identity map.

Raises
ValueError If train_data and test_data have different element types after preprocessing with train_preprocess_fn and eval_preprocess_fn, or if validation_data is not None and has a different element type than the test data.

Attributes
train_data A tff.simulation.datasets.ClientData for training.
test_data The test data for the baseline task. Can be a tff.simulation.datasets.ClientData or a tf.data.Dataset.
validation_data The validation data for the baseline task. Can be one of tff.simulation.datasets.ClientData, tf.data.Dataset, or None if the task does not have a validation dataset.
train_preprocess_fn A callable accepting and returning tf.data.Dataset instances, used for preprocessing train datasets. Set to None if no train preprocessing occurs for the task.
eval_preprocess_fn A callable accepting and returning tf.data.Dataset instances, used for preprocessing evaluation datasets. Set to None if no eval preprocessing occurs for the task.
element_type_structure A nested structure of tf.TensorSpec objects defining the type of the elements contained in datasets associated with this task.

Methods

get_centralized_test_data

Returns a tf.data.Dataset of test data for the task.

If the baseline task has centralized data, then this method will return the centralized data after applying preprocessing. If the test data is federated, then this method will first amalgamate the client datasets into a single dataset, then apply preprocessing.
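
The two cases above can be sketched in plain Python. This is only an illustration under stated assumptions, not the library's implementation: `centralize_test_data` is a hypothetical helper, a dict mapping client IDs to lists of examples stands in for a tff.simulation.datasets.ClientData, and a plain list stands in for a tf.data.Dataset.

```python
from itertools import chain
from typing import Callable, Dict, List, Optional, Union

Examples = List[int]


def centralize_test_data(
    test_data: Union[Examples, Dict[str, Examples]],
    eval_preprocess_fn: Optional[Callable[[Examples], Examples]] = None,
) -> Examples:
    """Hypothetical helper mirroring the described behavior."""
    if isinstance(test_data, dict):
        # Federated stand-in: merge all client datasets into one first.
        examples = list(chain.from_iterable(test_data.values()))
    else:
        # Centralized stand-in: use the data as-is.
        examples = list(test_data)
    # A missing preprocessing function behaves like the identity map.
    if eval_preprocess_fn is None:
        return examples
    return eval_preprocess_fn(examples)


federated = {"client_a": [1, 2], "client_b": [3]}
merged = centralize_test_data(
    federated, eval_preprocess_fn=lambda xs: [2 * x for x in xs]
)
# merged == [2, 4, 6]
```

In both branches preprocessing is applied once, to the single combined dataset, rather than per client.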

sample_train_clients

Samples training clients uniformly at random.

Args
num_clients A positive integer representing the number of clients to sample.
replace Whether to sample with replacement. If set to False, then num_clients cannot exceed the number of training clients in the associated train data.
random_seed An optional integer used to set a random seed for sampling. If no random seed is passed or the random seed is set to None, this will attempt to set the random seed according to the current system time (see numpy.random.RandomState for details).

Returns
A list of tf.data.Dataset instances representing the client datasets.
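
The sampling semantics described above can be illustrated without TFF. The following is a minimal sketch, not the library's implementation: `sample_client_ids` is a hypothetical helper, it returns client IDs rather than tf.data.Dataset instances, it uses Python's standard `random` module in place of numpy.random.RandomState, and the error raised for a non-positive num_clients is an assumption.

```python
import random
from typing import List, Optional, Sequence


def sample_client_ids(
    client_ids: Sequence[str],
    num_clients: int,
    replace: bool = False,
    random_seed: Optional[int] = None,
) -> List[str]:
    """Hypothetical helper: samples client IDs uniformly at random."""
    if num_clients <= 0:
        # Assumed behavior; the source only says the value is positive.
        raise ValueError("num_clients must be a positive integer.")
    if not replace and num_clients > len(client_ids):
        raise ValueError(
            "Without replacement, num_clients cannot exceed the number "
            "of training clients."
        )
    # With a seed of None, random.Random seeds itself from OS entropy
    # or the current system time.
    rng = random.Random(random_seed)
    population = list(client_ids)
    if replace:
        return rng.choices(population, k=num_clients)
    return rng.sample(population, k=num_clients)


ids = [f"client_{i}" for i in range(10)]
sampled = sample_client_ids(ids, num_clients=3, random_seed=42)
```

Passing the same random_seed twice yields the same sample, which is useful for reproducible simulation rounds.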

summary

Prints a summary of the train, test, and validation data.

The summary is printed as a table showing the type of the train, test, and validation data (i.e. federated or centralized) and the number of clients in each (if it is federated). For example, if the train data has 10 clients, and both the test and validation data are centralized, then this will print the following table:

Split      |Dataset Type |Number of Clients |
=============================================
Train      |Federated    |10                |
Test       |Centralized  |N/A               |
Validation |Centralized  |N/A               |
_____________________________________________

In addition, this will print two lines after the table indicating whether train and eval preprocessing functions were passed in. In the example above, if we passed in a train preprocessing function but no eval preprocessing function, it would also print the lines:

Train Preprocess Function: True
Eval Preprocess Function: False

To capture the summary, you can use a custom print function. For example, setting print_fn = summary_list.append will cause each of the lines above to be appended to summary_list.

Args
print_fn An optional callable accepting string inputs. Used to print each row of the summary. Defaults to print if not specified.
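
The capture pattern described above can be tried without TFF. In this sketch, `fake_summary` is a hypothetical stand-in for the `summary` method that simply emits the example rows shown earlier, one string per call to print_fn; the real method's exact row contents may differ.

```python
from typing import Callable, List


def fake_summary(print_fn: Callable[[str], None] = print) -> None:
    """Hypothetical stand-in for `summary`: emits one string per row."""
    rows = [
        "Split      |Dataset Type |Number of Clients |",
        "=============================================",
        "Train      |Federated    |10                |",
        "Test       |Centralized  |N/A               |",
        "Validation |Centralized  |N/A               |",
        "_____________________________________________",
        "Train Preprocess Function: True",
        "Eval Preprocess Function: False",
    ]
    for row in rows:
        print_fn(row)


# Collect the rows in a list instead of writing them to stdout.
summary_list: List[str] = []
fake_summary(print_fn=summary_list.append)
```

Any callable that accepts a string works the same way, for example a logger method such as `logging.info`.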