Travailler avec ClientData de tff.

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

La notion d'ensemble de données saisie par les clients (par exemple les utilisateurs) est essentielle au calcul fédéré tel que modélisé dans TFF. TFF fournit l'interface tff.simulation.datasets.ClientData abstraire sur ce concept, et les ensembles de données qui accueille TFF ( stackoverflow , shakespeare , emnist , cifar100 et gldv2 ) mettre en œuvre toute cette interface.

Si vous travaillez sur l' apprentissage fédéré avec votre propre jeu de données, TFF vous encourage fortement à mettre en œuvre soit la ClientData interface ou utilisez l' une des fonctions d'aide de TFF pour générer un ClientData qui représente vos données sur le disque, par exemple tff.simulation.datasets.ClientData.from_clients_and_fn .

Comme la plupart des exemples de bout en bout de TFF commencer par ClientData objets, la mise en œuvre de la ClientData interface avec votre jeu de données personnalisé sera plus facile à Spelunk par le code existant écrit avec TFF. En outre, les tf.data.Datasets qui ClientData constructions peuvent être répétées sur directement pour donner des structures numpy matrices, de sorte que ClientData objets peuvent être utilisés avec tout cadre ML à base de Python avant de passer à la FFT.

Il existe plusieurs modèles avec lesquels vous pouvez vous faciliter la vie si vous avez l'intention d'étendre vos simulations à de nombreuses machines ou de les déployer. Ci - dessous , nous marcherons à travers quelques - unes des façons dont nous pouvons utiliser ClientData et TFF pour rendre notre petite itération à grande échelle d' expérimentation à l' expérience de déploiement production aussi lisse que possible.

Quel modèle dois-je utiliser pour transmettre ClientData dans TFF ?

Nous allons discuter de deux usages de la TFF de ClientData en profondeur; si vous vous situez dans l'une des deux catégories ci-dessous, vous préférerez clairement l'une à l'autre. Sinon, vous aurez peut-être besoin d'une compréhension plus détaillée des avantages et des inconvénients de chacun pour faire un choix plus nuancé.

  • Je veux itérer le plus rapidement possible sur une machine locale ; Je n'ai pas besoin de pouvoir profiter facilement du runtime distribué de TFF.

    • Vous voulez passer tf.data.Datasets pour TFF directement.
    • Cela vous permet de programmer avec impérieusement tf.data.Dataset objets, et de les traiter de façon arbitraire.
    • Elle offre plus de flexibilité que l'option ci-dessous ; pousser la logique vers les clients nécessite que cette logique soit sérialisable.
  • Je souhaite exécuter mon calcul fédéré dans l'environnement d'exécution distant de TFF, ou je prévois de le faire bientôt.

    • Dans ce cas, vous souhaitez mapper la construction et le prétraitement du jeu de données aux clients.
    • Il en résulte que vous en passant simplement une liste de client_ids directement à votre calcul fédérée.
    • Pousser la construction et le prétraitement des ensembles de données vers les clients évite les goulots d'étranglement dans la sérialisation et augmente considérablement les performances avec des centaines à des milliers de clients.

Mettre en place un environnement open source

Importer des packages

Manipulation d'un objet ClientData

Commençons par le chargement et l' exploration EMNIST de TFF ClientData :

client_data, _ = tff.simulation.datasets.emnist.load_data()
Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:19<00:00, 8831921.67it/s]
2021-10-01 11:17:58.718735: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Inspectant le premier ensemble de données peut nous dire quel type d'exemples sont dans le ClientData .

first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

Notez que les rendements des jeux de données collections.OrderedDict objets qui ont des pixels et label clés, où les pixels est un tenseur de forme [28, 28] . Supposons que nous voulons aplatir nos entrées vers la forme [784] . Une façon possible , nous pouvons le faire serait d'appliquer une fonction de pré-traitement à notre ClientData objet.

def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)


preprocessed_client_data = client_data.preprocess(preprocess_dataset)

# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

Nous pouvons également souhaiter effectuer un prétraitement plus complexe (et éventuellement avec état), par exemple un brassage.

def preprocess_and_shuffle(dataset):
  """Applies `preprocess_dataset` above and shuffles the result."""
  preprocessed = preprocess_dataset(dataset)
  return preprocessed.shuffle(buffer_size=5)

preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)

# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

Interfacer avec un tff.Computation

Maintenant que nous pouvons effectuer quelques manipulations de base avec ClientData objets, nous sommes prêts à des données d'alimentation à un tff.Computation . Nous définissons une tff.templates.IterativeProcess qui implémente la moyenne fédérée , et explorer les différentes méthodes de la transmettre des données.

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
  ])
  return tff.learning.from_keras_model(
      model,
      # Note: input spec is the _batched_ shape, and includes the 
      # label tensor which will be passed to the loss function. This model is
      # therefore configured to accept data _after_ it has been preprocessed.
      input_spec=collections.OrderedDict(
          x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
          y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

trainer = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))

Avant de commencer à travailler avec ce IterativeProcess , un commentaire sur la sémantique de ClientData est en ordre. Un ClientData objet représente la totalité de la population pour la formation fédérée, qui est en général pas à la disposition de l'environnement d'exécution d'un système de production FL et est spécifique à la simulation. ClientData donne en effet à l'utilisateur la capacité de contourner l' informatique fédérée former entièrement et simplement un modèle côté serveur comme d' habitude via ClientData.create_tf_dataset_from_all_clients .

L'environnement de simulation de TFF donne au chercheur le contrôle total de la boucle externe. En particulier, cela implique que les considérations de disponibilité du client, d'abandon du client, etc., doivent être traitées par l'utilisateur ou le script du pilote Python. On pourrait par exemple l' abandon du client modèle en ajustant la distribution d'échantillonnage sur vos ClientData's client_ids tels que les utilisateurs avec plus de données (et corrélativement plus long en cours d' exécution des calculs locaux) seront sélectionnés avec une probabilité plus faible.

Dans un système fédéré réel, cependant, les clients ne peuvent pas être sélectionnés explicitement par le formateur modèle ; la sélection des clients est déléguée au système qui exécute le calcul fédéré.

En passant tf.data.Datasets directement à TFF

Une option que nous avons pour l' interface entre un ClientData et un IterativeProcess est celui de la construction tf.data.Datasets en Python, et en passant ces ensembles de données de TFF.

Notez que si nous utilisons nos prétraitées ClientData les ensembles de données que nous cédons sont du type approprié prévu par notre modèle défini ci - dessus.

selected_client_ids = preprocessed_and_shuffled.client_ids[:10]

preprocessed_data_for_clients = [
    preprocessed_and_shuffled.create_tf_dataset_for_client(
        selected_client_ids[i]) for i in range(10)
]

state = trainer.initialize()
for _ in range(5):
  t1 = time.time()
  state, metrics = trainer.next(state, preprocessed_data_for_clients)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
loss 2.9005744457244873, round time 4.576513767242432
loss 3.113278388977051, round time 0.49641919136047363
loss 2.7581865787506104, round time 0.4904160499572754
loss 2.87259578704834, round time 0.48976993560791016
loss 3.1202380657196045, round time 0.6724586486816406

Si nous prenons cette voie, cependant, nous ne pourrons pas passer trivialement à la simulation multimachine. Les ensembles de données que nous construisons dans le runtime tensorflow local peut capturer l' état de l'environnement python environnant, et ne parviennent pas à sérialisation ou désérialisation lorsqu'ils tentent d'état de référence qui ne sont plus à leur disposition . Cela peut se manifester par exemple dans l'erreur insondable de tensorflow de tensor_util.cc :

Check failed: DT_VARIANT == input.dtype() (21 vs. 20)

Cartographier la construction et le prétraitement sur les clients

Pour éviter ce problème, TFF recommande à ses utilisateurs d'examiner ensemble de données instanciation et prétraiter comme quelque chose qui se passe sur chaque client local, et d'utiliser les aides de TFF ou federated_map pour exécuter explicitement ce code prétraiter à chaque client.

Conceptuellement, la raison de préférer cela est claire : dans l'environnement d'exécution local de TFF, les clients n'ont accès qu'"accidentellement" à l'environnement Python global en raison du fait que l'ensemble de l'orchestration fédérée se déroule sur une seule machine. Il convient de noter à ce stade qu'une réflexion similaire donne naissance à la philosophie fonctionnelle multiplateforme et toujours sérialisable de TFF.

TFF fait un simple changement par l' intermédiaire ClientData's attribut dataset_computation , un tff.Computation qui prend client_id et renvoie l'associé tf.data.Dataset .

Notez que preprocess fonctionne simplement avec dataset_computation ; l' dataset_computation attribut de la prétraité ClientData intègre le pipeline ensemble de pré - traitement , nous venons de définir:

print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing:
(string -> <label=int32,pixels=float32[28,28]>*)


dataset computation with preprocessing:
(string -> <x=float32[?,784],y=int64[?,1]>*)

On pourrait invoquer dataset_computation et recevoir un ensemble de données désireux dans le runtime Python, mais la puissance réelle de cette approche est exercée lorsque nous composons avec un processus itératif ou un autre calcul pour éviter la matérialisation de ces ensembles de données dans le moteur d' exécution global désireux du tout. TFF fournit une fonction aide tff.simulation.compose_dataset_computation_with_iterative_process qui peut être utilisé pour faire exactement cela.

trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
    preprocessed_and_shuffled.dataset_computation, trainer)

Les deux ce tff.templates.IterativeProcesses et celui ci - dessus courir de la même manière; mais l' ancien accepte ensembles de données client prétraités, et celui - ci accepte les chaînes représentant ids client, la manipulation à la fois la construction de jeu de données et de pré - traitement dans son corps - en fait state peut être passé entre les deux.

for _ in range(5):
  t1 = time.time()
  state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
loss 2.8417396545410156, round time 1.6707067489624023
loss 2.7670371532440186, round time 0.5207102298736572
loss 2.665048122406006, round time 0.5302855968475342
loss 2.7213189601898193, round time 0.5313887596130371
loss 2.580148935317993, round time 0.5283482074737549

S'adapter à un grand nombre de clients

trainer_accepting_ids peuvent être immédiatement utilisés dans l' exécution de multimachine de TFF, et évite la matérialisation tf.data.Datasets et le contrôleur (et donc les sérialisation et de les envoyer aux travailleurs).

Cela accélère considérablement les simulations distribuées, en particulier avec un grand nombre de clients, et permet une agrégation intermédiaire pour éviter une surcharge de sérialisation/désérialisation similaire.

Deepdive en option : composition manuelle de la logique de prétraitement dans TFF

TFF est conçu pour la compositionalité à partir de zéro ; le type de composition que vient d'interpréter l'assistant de TFF est entièrement sous notre contrôle en tant qu'utilisateurs. Nous aurions pu composer manuellement le calcul de pré - traitement , nous venons de définir avec son propre formateur next tout simplement:

selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)

@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)
def new_next(server_state, selected_clients):
  preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)
  return trainer.next(server_state, preprocessed_data)

manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)

En fait, c'est effectivement ce que l'assistant que nous avons utilisé fait sous le capot (en plus d'effectuer une vérification et une manipulation de type appropriées). Nous aurions même pu exprimer un peu différemment la même logique, par sérialisation preprocess_and_shuffle dans un tff.Computation et décomposer le federated_map en une seule étape qui construit des ensembles de données de l' ONU-prétraité et une autre qui fonctionne preprocess_and_shuffle à chaque client.

Nous pouvons vérifier que ce chemin plus manuel aboutit à des calculs avec la même signature de type que l'assistant de TFF (noms de paramètres modulo) :

print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,federated_dataset={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,selected_clients={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)