tff.learning.programs.train_model_with_vizier
Trains and tunes a federated model using Vizier.
tff.learning.programs.train_model_with_vizier(
    *,
    study,
    total_trials,
    num_parallel_trials=1,
    update_hparams,
    train_model_program_logic,
    train_process_factory,
    train_data_source,
    total_rounds,
    num_clients,
    program_state_manager_factory,
    model_output_manager_factory,
    train_metrics_manager_factory=None,
    evaluation_manager_factory,
    evaluation_periodicity
)
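To illustrate how total_trials, num_parallel_trials, and update_hparams relate, here is a minimal, library-free sketch in plain Python with asyncio. It is not the TFF implementation and does not use Vizier or TFF at all: suggest_parameters, apply_trial_parameters, run_trial, and tune are hypothetical stand-ins for the Vizier study, update_hparams, and train_model_program_logic, and the sketch assumes only that at most num_parallel_trials trials are in flight at once.

```python
import asyncio
import random
from typing import Any


def suggest_parameters(trial_id: int) -> dict[str, float]:
    # Hypothetical stand-in for a Vizier suggestion: a learning rate
    # sampled log-uniformly in [1e-4, 1e-1], seeded per trial.
    rng = random.Random(trial_id)
    return {"learning_rate": 10 ** rng.uniform(-4, -1)}


def apply_trial_parameters(
    hparams: dict[str, Any], parameters: dict[str, float]
) -> dict[str, Any]:
    # Hypothetical analogue of update_hparams: copy the base hparams and
    # let the trial's parameters override matching keys.
    updated = dict(hparams)
    updated.update(parameters)
    return updated


async def run_trial(hparams: dict[str, Any], total_rounds: int) -> float:
    # Hypothetical stand-in for train_model_program_logic: "trains" for
    # total_rounds rounds and returns a fake objective (higher is better).
    for _ in range(total_rounds):
        await asyncio.sleep(0)  # Yield to other trials, like a real async round.
    return -abs(hparams["learning_rate"] - 0.01)


async def tune(
    total_trials: int, num_parallel_trials: int, total_rounds: int
) -> tuple[dict[int, float], int]:
    base_hparams = {"learning_rate": 0.1, "batch_size": 32}
    # The semaphore caps how many trials are in flight at the same time.
    limit = asyncio.Semaphore(num_parallel_trials)
    results: dict[int, float] = {}
    running = 0
    peak = 0  # Highest number of trials observed running concurrently.

    async def one_trial(trial_id: int) -> None:
        nonlocal running, peak
        async with limit:
            running += 1
            peak = max(peak, running)
            hparams = apply_trial_parameters(
                base_hparams, suggest_parameters(trial_id)
            )
            results[trial_id] = await run_trial(hparams, total_rounds)
            running -= 1

    await asyncio.gather(*(one_trial(i) for i in range(total_trials)))
    return results, peak


if __name__ == "__main__":
    results, peak = asyncio.run(
        tune(total_trials=5, num_parallel_trials=2, total_rounds=3)
    )
    best = max(results, key=results.get)
    print(f"best trial: {best}, peak concurrency: {peak}")
```

Capping concurrency with a semaphore keeps every trial independent (each gets its own copy of the hparams), which mirrors the idea of per-trial state and release managers in the argument list.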
Args
study
    The Vizier study to use to tune train_model_program_logic.
total_trials
    The number of Vizier trials.
num_parallel_trials
    The number of Vizier trials to evaluate in parallel. Defaults to 1.
update_hparams
    A tff.Computation used to update the model's hyperparameters (hparams) with a trial's parameters.
train_model_program_logic
    The program logic used to train and evaluate the model.
train_process_factory
    A factory for creating the tff.learning.templates.LearningProcess to run for training.
train_data_source
    A tff.program.FederatedDataSource that returns the client data used during training.
total_rounds
    The number of rounds of training.
num_clients
    The number of clients per round of training.
program_state_manager_factory
    A factory for creating tff.program.ProgramStateManager instances for each trial.
model_output_manager_factory
    A factory for creating the tff.program.ReleaseManager instances used to release the model.
train_metrics_manager_factory
    A factory for creating the tff.program.ReleaseManager instances used to release training metrics for each trial. Defaults to None.
evaluation_manager_factory
    A factory for creating tff.learning.programs.EvaluationManager instances for each trial.
evaluation_periodicity
    Either an integer number of rounds or a datetime.timedelta to wait before sending a new training checkpoint to evaluation_manager.start_evaluation.
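evaluation_periodicity accepts two kinds of value. As a rough illustration of how such a setting can be interpreted, here is a hypothetical helper, checkpoint_due (not part of TFF), that decides whether a new checkpoint should be sent. It assumes an integer means "at least N rounds since the last checkpoint was sent" and a datetime.timedelta means "at least this much wall-clock time since the last checkpoint was sent".

```python
import datetime
from typing import Union

# Hypothetical alias for the two accepted forms of evaluation_periodicity.
Periodicity = Union[int, datetime.timedelta]


def checkpoint_due(
    periodicity: Periodicity,
    round_num: int,
    last_sent_round: int,
    now: datetime.datetime,
    last_sent_time: datetime.datetime,
) -> bool:
    """Hypothetical helper: is a new training checkpoint due for evaluation?"""
    if isinstance(periodicity, datetime.timedelta):
        # Time-based: wait at least `periodicity` since the last checkpoint.
        return now - last_sent_time >= periodicity
    if isinstance(periodicity, int) and periodicity > 0:
        # Round-based: wait at least `periodicity` rounds since the last checkpoint.
        return round_num - last_sent_round >= periodicity
    raise ValueError(f"Unsupported evaluation_periodicity: {periodicity!r}")
```

Round-based periodicity produces the same evaluation schedule on every run, while time-based periodicity adapts to how quickly rounds complete.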
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-09-20 UTC.