tff.simulation.baselines.stackoverflow.create_tag_prediction_task
Creates a baseline task for tag prediction on Stack Overflow.
tff.simulation.baselines.stackoverflow.create_tag_prediction_task(
    train_client_spec: tff.simulation.baselines.ClientSpec,
    eval_client_spec: Optional[tff.simulation.baselines.ClientSpec] = None,
    word_vocab_size: int = constants.DEFAULT_WORD_VOCAB_SIZE,
    tag_vocab_size: int = constants.DEFAULT_TAG_VOCAB_SIZE,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False
) -> tff.simulation.baselines.BaselineTask
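The train_client_spec and eval_client_spec arguments describe how each client's examples are preprocessed before training or evaluation. The sketch below is a pure-Python illustration of that kind of per-client pipeline (repeat the data for several epochs, optionally cap the number of elements, then batch). It uses a hypothetical SimpleClientSpec dataclass, not the real tff.simulation.baselines.ClientSpec; its field names and the order of operations are assumptions, and shuffling is omitted.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class SimpleClientSpec:
    """Hypothetical stand-in for a per-client preprocessing spec."""
    num_epochs: int
    batch_size: int
    max_elements: Optional[int] = None  # cap on examples after repetition


def preprocess_client(examples: Sequence[T], spec: SimpleClientSpec) -> List[List[T]]:
    """Repeats, truncates, and batches one client's examples (illustrative only)."""
    # Repeat the client's local data once per epoch.
    repeated = list(examples) * spec.num_epochs
    # Optionally cap the total number of examples the client contributes.
    if spec.max_elements is not None:
        repeated = repeated[: spec.max_elements]
    # Group into consecutive batches; the last batch may be smaller.
    return [
        repeated[i : i + spec.batch_size]
        for i in range(0, len(repeated), spec.batch_size)
    ]


# Mirrors the documented default for evaluation data: batch size 64 and no
# other preprocessing (modeled here as a single pass over the data).
DEFAULT_EVAL_SPEC = SimpleClientSpec(num_epochs=1, batch_size=64)

if __name__ == "__main__":
    train_spec = SimpleClientSpec(num_epochs=2, batch_size=4, max_elements=10)
    batches = preprocess_client(list(range(7)), train_spec)
    print([len(b) for b in batches])  # [4, 4, 2]
```

Capping after repetition means max_elements bounds the total work a single client does per round, regardless of how many epochs are requested.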
The goal of the task is to predict the tags associated with a post from a
bag-of-words representation of the post.
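To make the input and target concrete, the following sketch converts a post's tokens into a normalized bag-of-words vector over a fixed word vocabulary, and its tags into a multi-hot label vector over a fixed tag vocabulary. It is an illustration under stated assumptions, not the library's preprocessing code: the helper names are hypothetical, out-of-vocabulary words and tags are simply dropped, and each word count is divided by the total number of tokens in the post (including out-of-vocabulary ones).

```python
from collections import Counter
from typing import Dict, List, Sequence


def build_index(vocab: Sequence[str]) -> Dict[str, int]:
    """Maps each vocabulary entry to a fixed position."""
    return {token: i for i, token in enumerate(vocab)}


def bag_of_words(tokens: Sequence[str], word_index: Dict[str, int]) -> List[float]:
    """Returns the fraction of the post's tokens equal to each vocabulary word.

    Out-of-vocabulary tokens get no feature but still count toward the
    denominator (an assumption made for this sketch).
    """
    features = [0.0] * len(word_index)
    if not tokens:
        return features
    for token, count in Counter(tokens).items():
        if token in word_index:
            features[word_index[token]] = count / len(tokens)
    return features


def multi_hot(tags: Sequence[str], tag_index: Dict[str, int]) -> List[int]:
    """Returns 1 at the position of every in-vocabulary tag, 0 elsewhere."""
    labels = [0] * len(tag_index)
    for tag in tags:
        if tag in tag_index:
            labels[tag_index[tag]] = 1
    return labels


if __name__ == "__main__":
    word_index = build_index(["how", "to", "sort", "list", "python"])
    tag_index = build_index(["python", "java", "sorting"])
    post = "how to sort a list in python".split()
    print(bag_of_words(post, word_index))
    print(multi_hot(["python", "sorting", "list-comprehension"], tag_index))  # [1, 0, 1]
```

Because a post can carry several tags at once, the target is a multi-hot vector rather than a single class index, which makes this a multi-label prediction problem.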
Args:
  train_client_spec: A tff.simulation.baselines.ClientSpec specifying how to
    preprocess training client data.
  eval_client_spec: An optional tff.simulation.baselines.ClientSpec specifying
    how to preprocess evaluation client data. If set to None, the evaluation
    datasets use a batch size of 64 with no extra preprocessing.
  word_vocab_size: An integer giving the number of most frequent words in the
    entire corpus to use as the task's vocabulary. Defaults to
    tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE.
  tag_vocab_size: An integer giving the number of most frequent tags in the
    entire corpus to use as the task's labels. Defaults to
    tff.simulation.baselines.stackoverflow.DEFAULT_TAG_VOCAB_SIZE.
  cache_dir: An optional directory in which to cache the downloaded datasets.
    If None, they are cached to ~/.tff/.
  use_synthetic_data: A boolean indicating whether to use synthetic Stack
    Overflow data. This option should only be used for testing, to avoid
    downloading the entire Stack Overflow dataset. Synthetic word and tag
    vocabularies are also used, and their sizes need not match
    word_vocab_size and tag_vocab_size.
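word_vocab_size and tag_vocab_size keep only the most frequent words and tags in the corpus. The sketch below illustrates that selection rule on a toy corpus using collections.Counter. It is not the library's vocabulary-building code; the helper name top_k_vocab is hypothetical, and the tie-breaking shown (first-encountered order, as Counter.most_common documents) is an assumption about how ties might be handled.

```python
from collections import Counter
from typing import Iterable, List, Sequence


def top_k_vocab(documents: Iterable[Sequence[str]], vocab_size: int) -> List[str]:
    """Returns the vocab_size most frequent tokens across all documents.

    Ties keep first-encountered order, which is how Counter.most_common
    breaks them; a real vocabulary may order ties differently.
    """
    if vocab_size < 0:
        raise ValueError("vocab_size must be non-negative")
    counts = Counter()
    for doc in documents:
        counts.update(doc)
    return [token for token, _ in counts.most_common(vocab_size)]


if __name__ == "__main__":
    posts = [
        "how to sort a list".split(),
        "sort a dict by value".split(),
        "how to reverse a list".split(),
    ]
    print(top_k_vocab(posts, vocab_size=3))  # ['a', 'how', 'to']
    post_tags = [["python"], ["python", "java"], ["java", "sql"], ["python"]]
    print(top_k_vocab(post_tags, vocab_size=2))  # ['python', 'java']
```

The same rule applies to both vocabularies: words feed the bag-of-words features and tags define the label positions, so anything outside the top entries is simply not represented.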
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License.
Last updated 2024-09-20 UTC.