tff.simulation.baselines.stackoverflow.create_tag_prediction_task

Creates a baseline task for tag prediction on Stack Overflow.

The goal of the task is to predict the tags associated with a post, based on a bag-of-words representation of the post.
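As a rough illustration of the input and label shapes involved, the sketch below encodes a post as a bag-of-words count vector and its tags as a multi-hot vector. The helper names `bag_of_words` and `multi_hot` and the tiny vocabularies are hypothetical, and the library's own preprocessing may differ (for example in normalization or out-of-vocabulary handling).

```python
from collections import Counter


def bag_of_words(tokens, word_vocab):
    """Counts how often each vocabulary word occurs in `tokens`.

    Hypothetical helper for illustration only. Tokens outside
    `word_vocab` are ignored.
    """
    counts = Counter(tokens)
    return [counts[word] for word in word_vocab]


def multi_hot(tags, tag_vocab):
    """Marks which vocabulary tags are attached to a post.

    Hypothetical helper for illustration only. Tags outside
    `tag_vocab` are ignored.
    """
    present = set(tags)
    return [1 if tag in present else 0 for tag in tag_vocab]


# Toy vocabularies (assumed); the real ones are built from the most
# frequent words and tags in the corpus.
word_vocab = ["python", "list", "sort", "java"]
tag_vocab = ["python", "java", "sorting"]

post_tokens = "how to sort a list of lists in python".split()
features = bag_of_words(post_tokens, word_vocab)
labels = multi_hot(["python", "sorting"], tag_vocab)
```

Because a post can carry several tags at once, the label is a vector with one entry per tag rather than a single class index.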

train_client_spec A tff.simulation.baselines.ClientSpec specifying how to preprocess train client data.
eval_client_spec An optional tff.simulation.baselines.ClientSpec specifying how to preprocess evaluation client data. If set to None, the evaluation datasets will use a batch size of 64 with no extra preprocessing.
word_vocab_size Integer dictating the number of most frequent words in the entire corpus to use for the task's vocabulary. By default, this is set to tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE.
tag_vocab_size Integer dictating the number of most frequent tags in the entire corpus to use for the task's labels. By default, this is set to tff.simulation.baselines.stackoverflow.DEFAULT_TAG_VOCAB_SIZE.
cache_dir An optional directory in which to cache the downloaded datasets. If None, they will be cached to ~/.tff/.
use_synthetic_data A boolean indicating whether to use synthetic Stack Overflow data. This option should only be used for testing purposes, in order to avoid downloading the entire Stack Overflow dataset. Synthetic word vocabularies and tag vocabularies will also be used (not necessarily of sizes word_vocab_size and tag_vocab_size).

A tff.simulation.baselines.BaselineTask.
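
A minimal usage sketch follows, assuming an installed TensorFlow Federated release that still provides tff.simulation.baselines. The wrapper name `build_synthetic_tag_prediction_task` and the chosen batch size and epoch count are illustrative assumptions. It sets use_synthetic_data=True so that the full Stack Overflow dataset is not downloaded.

```python
def build_synthetic_tag_prediction_task(batch_size=16, num_epochs=1):
    """Builds the tag prediction baseline task on synthetic data.

    Hypothetical wrapper for illustration. The import is deferred so
    that defining this function does not require TFF to be installed.
    """
    import tensorflow_federated as tff

    # How each training client's data is preprocessed.
    train_client_spec = tff.simulation.baselines.ClientSpec(
        num_epochs=num_epochs, batch_size=batch_size
    )

    # eval_client_spec, word_vocab_size, tag_vocab_size and cache_dir
    # are left at their defaults.
    return tff.simulation.baselines.stackoverflow.create_tag_prediction_task(
        train_client_spec,
        use_synthetic_data=True,
    )
```

Calling `build_synthetic_tag_prediction_task()` should return a tff.simulation.baselines.BaselineTask. With synthetic data, the vocabularies are also synthetic and need not match word_vocab_size or tag_vocab_size.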