Cette page a été traduite par l'API Cloud Translation.
Switch to English

Apprentissage fédéré pour la classification d'images

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher la source sur GitHub

Dans ce didacticiel, nous utilisons l'exemple de formation MNIST classique pour présenter la couche API Federated Learning (FL) de TFF, tff.learning - un ensemble d'interfaces de niveau supérieur pouvant être utilisées pour effectuer des types courants de tâches d'apprentissage fédéré, telles que formation fédérée, par rapport aux modèles fournis par l'utilisateur implémentés dans TensorFlow.

Ce didacticiel et l'API Federated Learning sont principalement destinés aux utilisateurs qui souhaitent brancher leurs propres modèles TensorFlow dans TFF, en traitant ce dernier principalement comme une boîte noire. Pour une compréhension plus approfondie de TFF et comment implémenter vos propres algorithmes d'apprentissage fédéré, consultez les didacticiels sur l'API FC Core - Algorithmes fédérés personnalisés parties 1 et partie 2 .

Pour en savoir plus sur tff.learning , continuez avec le didacticiel Federated Learning for Text Generation , qui, en plus de couvrir les modèles récurrents, montre également le chargement d'un modèle Keras sérialisé pré-entraîné pour le raffinement avec un apprentissage fédéré combiné à une évaluation à l'aide de Keras.

Avant de commencer

Avant de commencer, veuillez exécuter ce qui suit pour vous assurer que votre environnement est correctement configuré. Si vous ne voyez pas de message d'accueil, veuillez vous reporter au guide d' installation pour obtenir des instructions.


!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio

import nest_asyncio
nest_asyncio.apply()

%load_ext tensorboard
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

Préparation des données d'entrée

Commençons par les données. L'apprentissage fédéré nécessite un ensemble de données fédérées, c'est-à-dire une collection de données provenant de plusieurs utilisateurs. Les données fédérées sont généralement non iid , ce qui pose un ensemble unique de défis.

Afin de faciliter l'expérimentation, nous avons ensemencé le référentiel TFF avec quelques ensembles de données, y compris une version fédérée de MNIST qui contient une version de l' ensemble de données NIST original qui a été retraité à l'aide de Leaf afin que les données soient saisies par l'auteur original de les chiffres. Étant donné que chaque rédacteur a un style unique, cet ensemble de données présente le type de comportement non iid attendu des ensembles de données fédérés.

Voici comment nous pouvons le charger.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Les ensembles de données retournés par load_data() sont des instances de tff.simulation.ClientData , une interface qui vous permet d'énumérer l'ensemble d'utilisateurs, de construire un tf.data.Dataset qui représente les données d'un utilisateur particulier et d'interroger le structure des éléments individuels. Voici comment vous pouvez utiliser cette interface pour explorer le contenu de l'ensemble de données. Gardez à l'esprit que bien que cette interface vous permette d'itérer sur les identifiants clients, il ne s'agit que d'une fonctionnalité des données de simulation. Comme vous le verrez sous peu, les identités des clients ne sont pas utilisées par le framework d'apprentissage fédéré - leur seul objectif est de vous permettre de sélectionner des sous-ensembles de données pour les simulations.

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None)), ('label', TensorSpec(shape=(), dtype=tf.int32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

Explorer l'hétérogénéité des données fédérées

Les données fédérées sont généralement non iid , les utilisateurs ont généralement différentes distributions de données en fonction des modèles d'utilisation. Certains clients peuvent avoir moins d'exemples de formation sur l'appareil, souffrant d'un manque de données localement, tandis que certains clients auront plus que suffisamment d'exemples de formation. Explorons ce concept d'hétérogénéité des données typique d'un système fédéré avec les données EMNIST dont nous disposons. Il est important de noter que cette analyse approfondie des données d'un client n'est disponible que pour nous car il s'agit d'un environnement de simulation où toutes les données nous sont disponibles localement. Dans un environnement fédéré de production réel, vous ne pourrez pas inspecter les données d'un seul client.

Tout d'abord, prenons un échantillon des données d'un client pour avoir une idée des exemples sur un appareil simulé. Étant donné que l'ensemble de données que nous utilisons a été saisi par un rédacteur unique, les données d'un client représentent l'écriture manuscrite d'une personne pour un échantillon des chiffres 0 à 9, simulant le «modèle d'utilisation» unique d'un utilisateur.

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

Visualisons maintenant le nombre d'exemples sur chaque client pour chaque étiquette numérique MNIST. Dans l'environnement fédéré, le nombre d'exemples sur chaque client peut varier considérablement en fonction du comportement de l'utilisateur.

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

Visualisons maintenant l'image moyenne par client pour chaque étiquette MNIST. Ce code produira la moyenne de chaque valeur de pixel pour tous les exemples de l'utilisateur pour une étiquette. Nous verrons que l'image moyenne d'un client pour un chiffre sera différente de l'image moyenne d'un autre client pour le même chiffre, en raison du style d'écriture unique de chaque personne. Nous pouvons réfléchir à la manière dont chaque cycle de formation local poussera le modèle dans une direction différente pour chaque client, car nous apprenons à partir des données uniques de cet utilisateur dans ce cycle local. Plus loin dans le didacticiel, nous verrons comment nous pouvons prendre chaque mise à jour du modèle de tous les clients et les regrouper dans notre nouveau modèle global, qui a appris de chacune des données uniques de nos clients.

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

Les données utilisateur peuvent être bruyantes et étiquetées de manière non fiable. Par exemple, en regardant les données du client n ° 2 ci-dessus, nous pouvons voir que pour l'étiquette 2, il est possible qu'il y ait eu des exemples mal étiquetés créant une image moyenne plus bruyante.

Prétraitement des données d'entrée

Étant donné que les données sont déjà un tf.data.Dataset , le prétraitement peut être effectué à l'aide de transformations de jeu de données. Ici, nous aplatit les 28x28 images dans 784 tableaux -Element, mélanger les exemples individuels, les organiser en lots, et renomme les caractéristiques de pixels et l' label à x et y pour une utilisation avec Keras. Nous lançons également une repeat sur l'ensemble de données pour exécuter plusieurs époques.

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER= 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

Vérifions que cela a fonctionné.

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [2],
       [3],
       [6],
       [0],
       [1],
       [4],
       [1],
       [0],
       [6],
       [9],
       [9],
       [3],
       [6],
       [1],
       [4],
       [8],
       [0],
       [2]], dtype=int32))])

Nous avons presque tous les éléments de base en place pour construire des ensembles de données fédérés.

L'un des moyens de fournir des données fédérées à TFF dans une simulation est simplement sous forme de liste Python, chaque élément de la liste contenant les données d'un utilisateur individuel, que ce soit sous forme de liste ou de tf.data.Dataset . Puisque nous avons déjà une interface qui fournit ce dernier, utilisons-la.

Voici une fonction d'assistance simple qui construira une liste d'ensembles de données à partir de l'ensemble d'utilisateurs donné en tant qu'entrée pour un cycle de formation ou d'évaluation.

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

Maintenant, comment choisissons-nous les clients?

Dans un scénario de formation fédérée typique, nous avons affaire à une population potentiellement très importante de dispositifs utilisateurs, dont seule une fraction peut être disponible pour la formation à un moment donné. C'est le cas, par exemple, lorsque les appareils clients sont des téléphones mobiles qui ne participent à la formation que lorsqu'ils sont branchés à une source d'alimentation, hors d'un réseau mesuré et autrement inactifs.

Bien sûr, nous sommes dans un environnement de simulation, et toutes les données sont disponibles localement. En règle générale, lors de l'exécution de simulations, nous échantillonnerions simplement un sous-ensemble aléatoire de clients à impliquer dans chaque cycle de formation, généralement différents à chaque cycle.

Cela dit, comme vous pouvez le découvrir en étudiant l'article sur l'algorithme de moyenne fédérée , parvenir à une convergence dans un système avec des sous-ensembles de clients échantillonnés au hasard à chaque tour peut prendre un certain temps, et il serait peu pratique de devoir exécuter des centaines de tours dans ce didacticiel interactif.

Ce que nous allons faire à la place, c'est échantillonner l'ensemble de clients une fois, et réutiliser le même ensemble à travers les tours pour accélérer la convergence (intentionnellement sur-ajustement aux données de ces quelques utilisateurs). Nous laissons au lecteur le soin de modifier ce tutoriel pour simuler un échantillonnage aléatoire - c'est assez facile à faire (une fois que vous le faites, gardez à l'esprit que la convergence du modèle peut prendre un certain temps).

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

Créer un modèle avec Keras

Si vous utilisez Keras, vous avez probablement déjà du code qui construit un modèle Keras. Voici un exemple de modèle simple qui suffira à nos besoins.

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

Afin d'utiliser n'importe quel modèle avec TFF, il doit être enveloppé dans une instance de l'interface tff.learning.Model , qui expose des méthodes pour tamponner le passage avant du modèle, les propriétés des métadonnées, etc., de la même manière que Keras, mais introduit également des éléments, tels que les moyens de contrôler le processus de calcul des métriques fédérées. Ne nous inquiétons pas de cela pour le moment; si vous avez un modèle Keras comme celui que nous venons de définir ci-dessus, vous pouvez demander à TFF de l'envelopper pour vous en invoquant tff.learning.from_keras_model , en passant le modèle et un exemple de lot de données comme arguments, comme indiqué ci-dessous.

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Entraîner le modèle sur des données fédérées

Maintenant que nous avons un modèle tff.learning.Model comme tff.learning.Model à utiliser avec TFF, nous pouvons laisser TFF construire un algorithme de moyenne fédérée en tff.learning.build_federated_averaging_process fonction d'assistance tff.learning.build_federated_averaging_process , comme suit.

Gardez à l'esprit que l'argument doit être un constructeur (tel que model_fn ci-dessus), pas une instance déjà construite, afin que la construction de votre modèle puisse avoir lieu dans un contexte contrôlé par TFF (si vous êtes curieux de connaître les raisons de ceci, nous vous encourageons à lire le tutoriel de suivi sur les algorithmes personnalisés ).

Une note critique sur l'algorithme de moyenne fédérée ci-dessous, il existe 2 optimiseurs: un optimiseur _client et un optimiseur _server. L' optimiseur _client est uniquement utilisé pour calculer les mises à jour du modèle local sur chaque client. L' optimiseur _server applique la mise à jour moyenne au modèle global sur le serveur. En particulier, cela signifie que le choix de l'optimiseur et du taux d'apprentissage utilisé peut devoir être différent de ceux que vous avez utilisés pour entraîner le modèle sur un ensemble de données iid standard. Nous vous recommandons de commencer avec SGD standard, éventuellement avec un taux d'apprentissage plus faible que d'habitude. Le taux d'apprentissage que nous utilisons n'a pas été soigneusement réglé, n'hésitez pas à expérimenter.

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Qu'est-ce qui vient juste de se passer? TFF a construit une paire de calculs fédérés et les a regroupés dans un tff.templates.IterativeProcess dans lequel ces calculs sont disponibles sous la forme d'une paire de propriétés initialize et next .

En un mot, les calculs fédérés sont des programmes dans le langage interne de TFF qui peuvent exprimer divers algorithmes fédérés (vous pouvez en savoir plus à ce sujet dans le didacticiel sur les algorithmes personnalisés ). Dans ce cas, les deux calculs générés et regroupés dans iterative_process implémentent la moyenne fédérée .

C'est un objectif de TFF de définir les calculs de manière à ce qu'ils puissent être exécutés dans des paramètres d'apprentissage fédérés réels, mais actuellement, seul le runtime de simulation d'exécution locale est implémenté. Pour exécuter un calcul dans un simulateur, vous l'appelez simplement comme une fonction Python. Cet environnement interprété par défaut n'est pas conçu pour des performances élevées, mais il suffira pour ce didacticiel; nous prévoyons de fournir des temps d'exécution de simulation plus performants pour faciliter la recherche à plus grande échelle dans les versions futures.

Commençons par le calcul d' initialize . Comme c'est le cas pour tous les calculs fédérés, vous pouvez le considérer comme une fonction. Le calcul ne prend aucun argument et renvoie un résultat - la représentation de l'état du processus de calcul de la moyenne fédérée sur le serveur. Bien que nous ne souhaitons pas entrer dans les détails de TFF, il peut être instructif de voir à quoi ressemble cet état. Vous pouvez le visualiser comme suit.

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<>,model_broadcast_state=<>>@SERVER)'

Bien que la signature de type ci-dessus puisse au premier abord sembler un peu cryptique, vous pouvez reconnaître que l'état du serveur se compose d'un model (les paramètres de modèle initiaux pour MNIST qui seront distribués à tous les périphériques) et optimizer_state (informations supplémentaires maintenues par le serveur, comme le nombre de tours à utiliser pour les horaires d'hyperparamètres, etc.).

Appelons le calcul d' initialize pour construire l'état du serveur.

state = iterative_process.initialize()

Le deuxième de la paire de calculs fédérés, next , représente un cycle unique de moyenne fédérée, qui consiste à pousser l'état du serveur (y compris les paramètres du modèle) aux clients, à s'entraîner sur l'appareil sur leurs données locales, à collecter et à calculer la moyenne des mises à jour du modèle. , et la production d'un nouveau modèle mis à jour sur le serveur.

Conceptuellement, vous pouvez penser à next comme ayant une signature de type fonctionnel qui ressemble à ceci.

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

En particulier, il ne faut pas penser à next() comme étant une fonction qui s'exécute sur un serveur, mais plutôt comme une représentation fonctionnelle déclarative de l'ensemble du calcul décentralisé - certaines des entrées sont fournies par le serveur ( SERVER_STATE ), mais chacune l'appareil fournit son propre jeu de données local.

Lançons un seul cycle de formation et visualisons les résultats. Nous pouvons utiliser les données fédérées que nous avons déjà générées ci-dessus pour un échantillon d'utilisateurs.

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.12037037312984467,loss=3.0108425617218018>>

Faisons quelques tours de plus. Comme indiqué précédemment, généralement à ce stade, vous choisirez un sous-ensemble de vos données de simulation à partir d'un nouvel échantillon d'utilisateurs sélectionné au hasard pour chaque tour afin de simuler un déploiement réaliste dans lequel les utilisateurs vont et viennent continuellement, mais dans ce cahier interactif, pour dans un souci de démonstration, nous réutiliserons simplement les mêmes utilisateurs, de sorte que le système converge rapidement.

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.14814814925193787,loss=2.8865506649017334>>
round  3, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.148765429854393,loss=2.9079062938690186>>
round  4, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.17633745074272156,loss=2.724686622619629>>
round  5, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.20226337015628815,loss=2.6334855556488037>>
round  6, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.22427983582019806,loss=2.5482592582702637>>
round  7, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.24094650149345398,loss=2.4472343921661377>>
round  8, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.259876549243927,loss=2.3809611797332764>>
round  9, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.29814815521240234,loss=2.156442403793335>>
round 10, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.31687241792678833,loss=2.122845411300659>>

La perte de formation diminue après chaque cycle de formation fédérée, ce qui indique que le modèle converge. Il y a quelques mises en garde importantes avec ces métriques de formation, cependant, consultez la section sur l' évaluation plus loin dans ce didacticiel.

Affichage des métriques de modèle dans TensorBoard

Ensuite, visualisons les métriques de ces calculs fédérés à l'aide de Tensorboard.

Commençons par créer le répertoire et le rédacteur de résumé correspondant dans lequel écrire les métriques.


logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

Tracez les métriques scalaires pertinentes avec le même rédacteur de résumé.


with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics.train._asdict().items():
      tf.summary.scalar(name, value, step=round_num)

Démarrez TensorBoard avec le répertoire de journal racine spécifié ci-dessus. Le chargement des données peut prendre quelques secondes.


%tensorboard --logdir /tmp/logs/scalars/ --port=0

# Run this this cell to clean your directory of old output for future graphs from this directory.
rm -R /tmp/logs/scalars/*

Afin d'afficher les métriques d'évaluation de la même manière, vous pouvez créer un dossier eval distinct, comme «logs / scalars / eval», pour écrire dans TensorBoard.

Personnalisation de l'implémentation du modèle

Keras est l' API de modèle de haut niveau recommandée pour TensorFlow , et nous vous encourageons à utiliser les modèles Keras (via tff.learning.from_keras_model ) dans TFF chaque fois que possible.

Cependant, tff.learning fournit une interface de modèle de niveau inférieur, tff.learning.Model , qui expose les fonctionnalités minimales nécessaires à l'utilisation d'un modèle pour l'apprentissage fédéré. L'implémentation directe de cette interface (peut-être toujours en utilisant des blocs de construction comme tf.keras.layers ) permet une personnalisation maximale sans modifier les éléments internes des algorithmes d'apprentissage fédérés.

Alors recommençons à partir de zéro.

Définition des variables de modèle, du passage en avant et des métriques

La première étape consiste à identifier les variables TensorFlow avec lesquelles nous allons travailler. Afin de rendre le code suivant plus lisible, définissons une structure de données pour représenter l'ensemble complet. Cela comprendra des variables telles que les weights et les bias que nous allons entraîner, ainsi que des variables qui contiendront diverses statistiques cumulatives et des compteurs que nous mettrons à jour pendant la formation, tels que loss_sum , accuracy_sum et num_examples .

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

Voici une méthode qui crée les variables. Par souci de simplicité, nous représentons toutes les statistiques sous la forme tf.float32 , car cela éliminera le besoin de conversions de type à un stade ultérieur. L'emballage des initialiseurs de variables en tant que lambdas est une exigence imposée par les variables de ressources .

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

Avec les variables pour les paramètres du modèle et les statistiques cumulatives en place, nous pouvons maintenant définir la méthode de transfert avant qui calcule la perte, émet des prédictions et met à jour les statistiques cumulatives pour un seul lot de données d'entrée, comme suit.

def mnist_forward_pass(variables, batch):
  y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

Ensuite, nous définissons une fonction qui renvoie un ensemble de métriques locales, à nouveau en utilisant TensorFlow. Ce sont les valeurs (en plus des mises à jour de modèle, qui sont gérées automatiquement) qui sont éligibles pour être agrégées sur le serveur dans un processus d'apprentissage ou d'évaluation fédéré.

Ici, nous num_examples simplement la loss et la accuracy moyennes, ainsi que les num_examples , dont nous aurons besoin pour pondérer correctement les contributions des différents utilisateurs lors du calcul des agrégats fédérés.

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

Enfin, nous devons déterminer comment agréger les métriques locales émises par chaque appareil via get_local_mnist_metrics . C'est la seule partie du code qui n'est pas écrite dans TensorFlow - c'est un calcul fédéré exprimé en TFF. Si vous souhaitez approfondir, parcourez le didacticiel sur les algorithmes personnalisés , mais dans la plupart des applications, vous n'en aurez pas vraiment besoin; des variantes du modèle ci-dessous devraient suffire. Voici à quoi cela ressemble:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))
  

L'argument de metrics entrée correspond au OrderedDict renvoyé par get_local_mnist_metrics ci-dessus, mais de manière critique, les valeurs ne sont plus tf.Tensors - elles sont "encadrées" en tant que tff.Value s, pour qu'il soit clair que vous ne pouvez plus les manipuler à l'aide de TensorFlow, mais seulement en utilisant les opérateurs fédérés de TFF tels que tff.federated_mean et tff.federated_sum . Le dictionnaire retourné des agrégats globaux définit l'ensemble de métriques qui seront disponibles sur le serveur.

Construire une instance de tff.learning.Model

Avec tout ce qui précède en place, nous sommes prêts à construire une représentation de modèle à utiliser avec TFF similaire à celle qui est générée pour vous lorsque vous laissez TFF ingérer un modèle Keras.

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

Comme vous pouvez le voir, les méthodes abstraites et les propriétés définies par tff.learning.Model correspondent aux extraits de code de la section précédente qui ont introduit les variables et défini la perte et les statistiques.

Voici quelques points à souligner:

  • Tous les états que votre modèle utilisera doivent être capturés en tant que variables TensorFlow, car TFF n'utilise pas Python au moment de l'exécution (rappelez-vous que votre code doit être écrit de manière à pouvoir être déployé sur des appareils mobiles; voir le didacticiel sur les algorithmes personnalisés pour plus de détails commentaire sur les raisons).
  • Votre modèle doit décrire la forme de données qu'il accepte ( input_spec ), car en général, TFF est un environnement fortement typé et souhaite déterminer les signatures de type pour tous les composants. La déclaration du format de l'entrée de votre modèle en est une partie essentielle.
  • Bien que techniquement non requis, nous vous recommandons d'encapsuler toute la logique TensorFlow (passe en avant, calculs de métriques, etc.) en tant que tf.function s, car cela permet de garantir que TensorFlow peut être sérialisé et supprime le besoin de dépendances de contrôle explicites.

Ce qui précède est suffisant pour l'évaluation et les algorithmes tels que Federated SGD. Cependant, pour la moyenne fédérée, nous devons spécifier comment le modèle doit s'entraîner localement sur chaque lot. Nous spécifierons un optimiseur local lors de la construction de l'algorithme de moyenne fédérée.

Simuler la formation fédérée avec le nouveau modèle

Avec tout ce qui précède en place, le reste du processus ressemble à ce que nous avons déjà vu - remplacez simplement le constructeur de modèle par le constructeur de notre nouvelle classe de modèle et utilisez les deux calculs fédérés dans le processus itératif que vous avez créé pour parcourir tours de formation.

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.9713594913482666,accuracy=0.13518518209457397>>

for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.975412607192993,accuracy=0.14032921195030212>>
round  3, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.9395227432250977,accuracy=0.1594650149345398>>
round  4, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.710164785385132,accuracy=0.17139917612075806>>
round  5, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.5891618728637695,accuracy=0.20267489552497864>>
round  6, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.5148487091064453,accuracy=0.21666666865348816>>
round  7, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.2816808223724365,accuracy=0.2580246925354004>>
round  8, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.3656885623931885,accuracy=0.25884774327278137>>
round  9, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.23549222946167,accuracy=0.28477364778518677>>
round 10, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=1.974222183227539,accuracy=0.35329216718673706>>

Pour afficher ces métriques dans TensorBoard, reportez-vous aux étapes répertoriées ci-dessus dans "Affichage des métriques de modèle dans TensorBoard".

Évaluation

Toutes nos expériences jusqu'à présent ne présentaient que des métriques d'entraînement fédérées - les métriques moyennes sur tous les lots de données entraînées sur tous les clients du cycle. Cela introduit les préoccupations normales concernant le surajustement, d'autant plus que nous avons utilisé le même ensemble de clients à chaque tour pour plus de simplicité, mais il existe une notion supplémentaire de surajustement dans les métriques d'entraînement spécifiques à l'algorithme de moyenne fédérée. Il est plus facile de voir si nous imaginons que chaque client avait un seul lot de données, et nous nous entraînons sur ce lot pendant de nombreuses itérations (époques). Dans ce cas, le modèle local s'adaptera rapidement exactement à ce lot, et donc la métrique de précision locale que nous moyenne approcherons 1.0. Ainsi, ces mesures de formation peuvent être considérées comme un signe que la formation progresse, mais pas beaucoup plus.

Pour effectuer une évaluation sur des données fédérées, vous pouvez construire un autre calcul fédéré conçu à cet effet, à l'aide de la fonction tff.learning.build_federated_evaluation et en passant votre constructeur de modèle en tant qu'argument. Notez que contrairement à Federated Averaging, où nous avons utilisé MnistTrainableModel , il suffit de passer le MnistModel . L'évaluation n'effectue pas de descente de gradient et il n'est pas nécessaire de construire des optimiseurs.

Pour l'expérimentation et la recherche, lorsqu'un ensemble de données de test centralisé est disponible, Federated Learning for Text Generation démontre une autre option d'évaluation: prendre les pondérations entraînées de l'apprentissage fédéré, les appliquer à un modèle Keras standard, puis appeler simplement tf.keras.models.Model.evaluate() sur un ensemble de données centralisé.

evaluation = tff.learning.build_federated_evaluation(MnistModel)

Vous pouvez inspecter la signature de type abstrait de la fonction d'évaluation comme suit.

str(evaluation.type_signature)
'(<<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,{<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <num_examples=float32@SERVER,loss=float32@SERVER,accuracy=float32@SERVER>)'

Inutile de vous soucier des détails à ce stade, sachez simplement qu'il prend la forme générale suivante, similaire à tff.templates.IterativeProcess.next mais avec deux différences importantes. Premièrement, nous ne retournons pas l'état du serveur, car l'évaluation ne modifie pas le modèle ou tout autre aspect de l'état - vous pouvez le considérer comme sans état. Deuxièmement, l'évaluation n'a besoin que du modèle et ne nécessite aucune autre partie de l'état du serveur qui pourrait être associée à la formation, comme les variables d'optimisation.

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

Invoquons l'évaluation sur le dernier état auquel nous sommes arrivés pendant l'entraînement. Pour extraire le dernier modèle entraîné de l'état du serveur, vous accédez simplement au membre .model , comme suit.

train_metrics = evaluation(state.model, federated_train_data)

Voici ce que nous obtenons. Notez que les chiffres semblent légèrement meilleurs que ce qui a été rapporté par le dernier cycle de formation ci-dessus. Par convention, les métriques de formation rapportées par le processus de formation itératif reflètent généralement les performances du modèle au début du cycle de formation, de sorte que les métriques d'évaluation auront toujours une longueur d'avance.

str(train_metrics)
'<num_examples=4860.0,loss=1.7142657041549683,accuracy=0.38683128356933594>'

Maintenant, compilons un échantillon de test de données fédérées et réexécutons l'évaluation sur les données de test. Les données proviendront du même échantillon d'utilisateurs réels, mais d'un ensemble de données distinct.

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
'<num_examples=580.0,loss=1.861915111541748,accuracy=0.3362068831920624>'

Ceci conclut le didacticiel. Nous vous encourageons à jouer avec les paramètres (par exemple, la taille des lots, le nombre d'utilisateurs, les époques, les taux d'apprentissage, etc.), à modifier le code ci-dessus pour simuler l'entraînement sur des échantillons aléatoires d'utilisateurs à chaque tour, et à explorer les autres tutoriels nous avons développé.