Merken Sie den Termin vor! Google I / O kehrt vom 18. bis 20. Mai zurück Registrieren Sie sich jetzt
Diese Seite wurde von der Cloud Translation API übersetzt.
Switch to English

Föderiertes Lernen zur Bildklassifizierung

Ansicht auf TensorFlow.org In Google Colab ausführen Quelle auf GitHub anzeigen Notizbuch herunterladen

In diesem Lernprogramm verwenden wir das klassische MNIST-Schulungsbeispiel, um die FL-API-Schicht (Federated Learning) von TFF, tff.learning , tff.learning - eine Reihe übergeordneter Schnittstellen, mit denen gängige Arten von Verbundlernaufgaben ausgeführt werden können, z Verbundschulung gegen vom Benutzer bereitgestellte Modelle, die in TensorFlow implementiert sind.

Dieses Tutorial und die Federated Learning API sind in erster Linie für Benutzer gedacht, die ihre eigenen TensorFlow-Modelle in TFF einbinden möchten, wobei letztere meist als Black Box behandelt werden. Weitere Informationen zu TFF und zur Implementierung Ihrer eigenen Verbundlernalgorithmen finden Sie in den Tutorials zur FC Core API - Benutzerdefinierte Verbundalgorithmen Teil 1 und Teil 2 .

Weitere tff.learning zu tff.learning Tutorial Federated Learning for Text Generation. tff.learning Tutorial behandelt nicht nur wiederkehrende Modelle, tff.learning zeigt auch das Laden eines vorab trainierten serialisierten Keras-Modells zur Verfeinerung mit föderiertem Lernen in Kombination mit der Bewertung mit Keras.

Bevor wir anfangen

Bevor wir beginnen, führen Sie bitte die folgenden Schritte aus, um sicherzustellen, dass Ihre Umgebung korrekt eingerichtet ist. Wenn Sie keine Begrüßung sehen, finden Sie Anweisungen in der Installationsanleitung .

# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
!pip install --quiet --upgrade tb-nightly  # or tensorboard, but not both

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
Fetching TensorBoard MPM... done.
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

Eingabedaten vorbereiten

Beginnen wir mit den Daten. Für das Verbundlernen ist ein Verbunddatensatz erforderlich, dh eine Sammlung von Daten von mehreren Benutzern. Verbunddaten sind in der Regel nicht flüssig , was eine einzigartige Reihe von Herausforderungen darstellt.

Um das Experimentieren zu erleichtern, haben wir das TFF-Repository mit einigen Datensätzen versehen, einschließlich einer Verbundversion von MNIST, die eine Version des ursprünglichen NIST-Datensatzes enthält , die mit Leaf erneut verarbeitet wurde, sodass die Daten vom ursprünglichen Verfasser von verschlüsselt werden die Ziffern. Da jeder Writer einen eindeutigen Stil hat, weist dieser Datensatz das Verhalten von Nicht-ID auf, das von Verbunddatensätzen erwartet wird.

So können wir es laden.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Die von load_data() zurückgegebenen load_data() sind Instanzen von tff.simulation.ClientData , einer Schnittstelle, mit der Sie die Benutzergruppetf.data.Dataset , einetf.data.Dataset , die die Daten eines bestimmten Benutzers darstellt, und die abfragen können Struktur einzelner Elemente. Hier erfahren Sie, wie Sie diese Schnittstelle verwenden können, um den Inhalt des Datensatzes zu untersuchen. Beachten Sie, dass Sie mit dieser Schnittstelle zwar über Client-IDs iterieren können, dies jedoch nur eine Funktion der Simulationsdaten ist. Wie Sie in Kürze sehen werden, werden Kundenidentitäten vom Verbundlern-Framework nicht verwendet. Sie dienen lediglich dazu, Teilmengen der Daten für Simulationen auszuwählen.

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None)), ('label', TensorSpec(shape=(), dtype=tf.int32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

Untersuchung der Heterogenität in Verbunddaten

Verbunddaten sind in der Regel nicht flüssig . Benutzer haben in der Regel unterschiedliche Datenverteilungen, abhängig von den Verwendungsmustern. Einige Clients verfügen möglicherweise über weniger Schulungsbeispiele auf dem Gerät, da lokal Datenmangel herrscht, während einige Clients über mehr als genügend Schulungsbeispiele verfügen. Lassen Sie uns dieses für ein Verbundsystem typische Konzept der Datenheterogenität mit den verfügbaren EMNIST-Daten untersuchen. Es ist wichtig zu beachten, dass diese gründliche Analyse der Daten eines Kunden nur für uns verfügbar ist, da dies eine Simulationsumgebung ist, in der alle Daten lokal für uns verfügbar sind. In einer realen Produktionsverbundumgebung können Sie die Daten eines einzelnen Kunden nicht überprüfen.

Lassen Sie uns zunächst eine Stichprobe der Daten eines Kunden erstellen, um ein Gefühl für die Beispiele auf einem simulierten Gerät zu bekommen. Da der von uns verwendete Datensatz von einem eindeutigen Schreiber eingegeben wurde, stellen die Daten eines Clients die Handschrift einer Person für eine Stichprobe der Ziffern 0 bis 9 dar und simulieren das eindeutige "Verwendungsmuster" eines Benutzers.

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

Lassen Sie uns nun die Anzahl der Beispiele auf jedem Client für jedes MNIST-Ziffernetikett visualisieren. In der Verbundumgebung kann die Anzahl der Beispiele auf jedem Client je nach Benutzerverhalten erheblich variieren.

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

Lassen Sie uns nun das mittlere Bild pro Client für jedes MNIST-Label visualisieren. Dieser Code erzeugt den Mittelwert jedes Pixelwerts für alle Beispiele des Benutzers für ein Etikett. Wir werden sehen, dass das Durchschnittsbild eines Kunden für eine Ziffer aufgrund des einzigartigen Handschriftstils jeder Person anders aussieht als das Durchschnittsbild eines anderen Kunden für dieselbe Ziffer. Wir können darüber nachdenken, wie jede lokale Trainingsrunde das Modell auf jedem Kunden in eine andere Richtung bewegt, da wir aus den eigenen eindeutigen Daten dieses Benutzers in dieser lokalen Runde lernen. Später im Tutorial werden wir sehen, wie wir jedes Update des Modells von allen Clients übernehmen und zu unserem neuen globalen Modell zusammenfassen können, das aus den eindeutigen Daten jedes Kunden gelernt hat.

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

Benutzerdaten können verrauscht und unzuverlässig gekennzeichnet sein. Wenn wir uns beispielsweise die Daten von Client Nr. 2 oben ansehen, können wir feststellen, dass es bei Etikett 2 möglicherweise einige falsch beschriftete Beispiele gegeben hat, die ein lauteres mittleres Bild erzeugen.

Vorverarbeitung der Eingabedaten

Da es sich bei den Daten bereits um eintf.data.Dataset , kann die Vorverarbeitung mithilfe von Dataset-Transformationen durchgeführt werden. Hier 28x28 wir die 28x28 Bilder in 784 Element-Arrays, mischen die einzelnen Beispiele, organisieren sie in Stapeln und benennen die Features für die Verwendung mit Keras von pixels und label in x und y um. Wir repeat den Datensatz auch, um mehrere Epochen auszuführen.

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

Lassen Sie uns überprüfen, ob dies funktioniert hat.

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[0],
       [5],
       [0],
       [1],
       [3],
       [0],
       [5],
       [4],
       [1],
       [7],
       [0],
       [4],
       [0],
       [1],
       [7],
       [2],
       [2],
       [0],
       [7],
       [1]], dtype=int32))])

Wir haben fast alle Bausteine, um Verbunddatensätze zu erstellen.

Eine Möglichkeit, Verbunddaten in einer Simulation an TFF weiterzuleiten, besteht einfach in einer Python-Liste, wobei jedes Element der Liste die Daten eines einzelnen Benutzers enthält, sei es als Liste oder alstf.data.Dataset . Da wir bereits eine Schnittstelle haben, die letztere bereitstellt, verwenden wir sie.

Hier ist eine einfache Hilfsfunktion, mit der eine Liste von Datensätzen aus der angegebenen Benutzergruppe als Eingabe für eine Schulungs- oder Evaluierungsrunde erstellt wird.

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

Wie wählen wir nun Kunden aus?

In einem typischen Verbundschulungsszenario handelt es sich möglicherweise um eine sehr große Anzahl von Benutzergeräten, von denen möglicherweise nur ein Bruchteil zu einem bestimmten Zeitpunkt für Schulungen verfügbar ist. Dies ist beispielsweise der Fall, wenn es sich bei den Clientgeräten um Mobiltelefone handelt, die nur dann an Schulungen teilnehmen, wenn sie an eine Stromquelle angeschlossen sind, sich nicht in einem gemessenen Netzwerk befinden und ansonsten im Leerlauf sind.

Natürlich befinden wir uns in einer Simulationsumgebung und alle Daten sind lokal verfügbar. Normalerweise führen wir dann beim Ausführen von Simulationen einfach eine zufällige Teilmenge der Kunden aus, die an jeder Trainingsrunde beteiligt sind, die in jeder Runde im Allgemeinen unterschiedlich ist.

Wie Sie anhand des Artikels über den Federated Averaging- Algorithmus herausfinden können, kann es jedoch eine Weile dauern, bis in einem System Konvergenz in einem System mit zufällig ausgewählten Teilmengen von Clients erreicht ist, und es wäre unpraktisch, Hunderte von Runden ausführen zu müssen dieses interaktive Tutorial.

Stattdessen werden wir die Gruppe von Clients einmal testen und dieselbe Gruppe über Runden hinweg wiederverwenden, um die Konvergenz zu beschleunigen (absichtlich übermäßig an die Daten dieser wenigen Benutzer angepasst). Wir überlassen es dem Leser als Übung, dieses Tutorial zu ändern, um Zufallsstichproben zu simulieren - dies ist ziemlich einfach (wenn Sie dies tun, denken Sie daran, dass es eine Weile dauern kann, bis das Modell konvergiert).

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

Erstellen eines Modells mit Keras

Wenn Sie Keras verwenden, verfügen Sie wahrscheinlich bereits über Code, der ein Keras-Modell erstellt. Hier ist ein Beispiel für ein einfaches Modell, das für unsere Anforderungen ausreicht.

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

Um ein Modell mit TFF verwenden zu können, muss es in eine Instanz der Schnittstelle tff.learning.Model werden, die Methoden zum Stempeln des Vorwärtsdurchlaufs, der Metadateneigenschaften usw. des Modells ähnlich wie Keras verfügbar macht, aber auch zusätzliche einführt Elemente, z. B. Möglichkeiten zur Steuerung des Prozesses zur Berechnung von Verbundmetriken. Machen wir uns vorerst keine Sorgen. Wenn Sie ein Keras-Modell wie das oben definierte haben, können Sie es von TFF tff.learning.from_keras_model , indem Sie tff.learning.from_keras_model und das Modell und einen Beispieldatenstapel als Argumente übergeben, wie unten gezeigt.

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Training des Modells auf Verbunddaten

tff.learning.Model wir ein Modell als tff.learning.Model zur Verwendung mit TFF verpackt haben, können wir TFF einen Federated Averaging-Algorithmus tff.learning.build_federated_averaging_process , indem wir die tff.learning.build_federated_averaging_process wie folgt tff.learning.build_federated_averaging_process .

model_fn Sie, dass das Argument ein Konstruktor (wie z. B. model_fn oben) und keine bereits erstellte Instanz sein muss, damit die Erstellung Ihres Modells in einem von TFF kontrollierten Kontext erfolgen kann (wenn Sie neugierig auf die Gründe dafür sind Wir empfehlen Ihnen daher, das nachfolgende Tutorial zu benutzerdefinierten Algorithmen zu lesen.

Ein kritischer Hinweis zum Federated Averaging-Algorithmus unten sind zwei Optimierer: ein _client- Optimierer und ein _server- Optimierer . Das Optimierungsprogramm _client wird nur zum Berechnen lokaler Modellaktualisierungen auf jedem Client verwendet. Der _server- Optimierer wendet die gemittelte Aktualisierung auf das globale Modell auf dem Server an. Dies bedeutet insbesondere, dass die Auswahl des verwendeten Optimierers und der verwendeten Lernrate möglicherweise anders sein muss als die, die Sie zum Trainieren des Modells auf einem Standard-ID-Datensatz verwendet haben. Wir empfehlen, mit regulären SGD zu beginnen, möglicherweise mit einer geringeren Lernrate als gewöhnlich. Die Lernrate, die wir verwenden, wurde nicht sorgfältig abgestimmt. Sie können gerne experimentieren.

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Was ist gerade passiert? TFF hat ein Paar von föderierten Berechnungen konstruiert und in ein verpackten tff.templates.IterativeProcess in dem diese Berechnungen sind als ein Paar von Eigenschaften initialize und next .

Kurz gesagt, Verbundberechnungen sind Programme in der internen Sprache von TFF, die verschiedene Verbundalgorithmen ausdrücken können (mehr dazu finden Sie im Tutorial für benutzerdefinierte Algorithmen ). In diesem Fall implementieren die beiden Berechnungen, die generiert und in iterative_process gepackt wurden, Federated Averaging .

Es ist ein Ziel von TFF, Berechnungen so zu definieren, dass sie in realen Verbundlerneinstellungen ausgeführt werden können. Derzeit ist jedoch nur die Laufzeit der lokalen Ausführungssimulation implementiert. Um eine Berechnung in einem Simulator auszuführen, rufen Sie sie einfach wie eine Python-Funktion auf. Diese standardmäßig interpretierte Umgebung ist nicht für hohe Leistung ausgelegt, reicht jedoch für dieses Lernprogramm aus. Wir gehen davon aus, dass Simulationslaufzeiten mit höherer Leistung bereitgestellt werden, um in zukünftigen Versionen umfangreichere Forschungsarbeiten zu ermöglichen.

Beginnen wir mit der initialize . Wie bei allen Verbundberechnungen können Sie sich das als Funktion vorstellen. Die Berechnung akzeptiert keine Argumente und gibt ein Ergebnis zurück - die Darstellung des Status des Federated Averaging-Prozesses auf dem Server. Wir möchten zwar nicht auf die Details von TFF eingehen, aber es kann lehrreich sein, zu sehen, wie dieser Zustand aussieht. Sie können es wie folgt visualisieren.

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER)'

Während die obige Typensignatur zunächst etwas kryptisch erscheint, können Sie erkennen, dass der Serverstatus aus einem model (den anfänglichen Modellparametern für MNIST, die an alle Geräte verteilt werden) und optimizer_state (zusätzliche Informationen, die vom Server verwaltet werden) besteht. wie die Anzahl der Runden, die für Hyperparameter-Zeitpläne verwendet werden sollen usw.).

Rufen wir die initialize auf, um den Serverstatus zu erstellen.

state = iterative_process.initialize()

Die zweite der beiden Verbundberechnungen stellt als next eine einzelne Runde der Verbundmittelung dar, die darin besteht, den Serverstatus (einschließlich der Modellparameter) an die Clients weiterzuleiten, die lokalen Daten auf dem Gerät zu schulen, Modellaktualisierungen zu sammeln und zu mitteln und Erstellen eines neuen aktualisierten Modells auf dem Server.

Konzeptionell können Sie sich als next eine funktionale Typensignatur vorstellen, die wie folgt aussieht.

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

Insbesondere sollte man sich next() nicht als eine Funktion SERVER_STATE , die auf einem Server ausgeführt wird, sondern als deklarative funktionale Darstellung der gesamten dezentralen Berechnung - einige der Eingaben werden vom Server ( SERVER_STATE ) SERVER_STATE , aber jeder nimmt teil Gerät trägt seinen eigenen lokalen Datensatz bei.

Lassen Sie uns eine einzelne Trainingsrunde durchführen und die Ergebnisse visualisieren. Wir können die oben bereits generierten Verbunddaten für eine Stichprobe von Benutzern verwenden.

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11502057), ('loss', 3.244929)]))])

Lassen Sie uns noch ein paar Runden laufen. Wie bereits erwähnt, wählen Sie zu diesem Zeitpunkt normalerweise eine Teilmenge Ihrer Simulationsdaten aus einer neuen zufällig ausgewählten Stichprobe von Benutzern für jede Runde aus, um eine realistische Bereitstellung zu simulieren, bei der Benutzer kontinuierlich kommen und gehen, jedoch in diesem interaktiven Notizbuch Zur Demonstration werden wir nur dieselben Benutzer wiederverwenden, damit das System schnell konvergiert.

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14609054), ('loss', 2.9141645)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.15205762), ('loss', 2.9237952)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.18600823), ('loss', 2.7629454)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.20884773), ('loss', 2.622908)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.21872428), ('loss', 2.543587)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2372428), ('loss', 2.4210362)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.28209877), ('loss', 2.2297976)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2685185), ('loss', 2.195803)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.33868313), ('loss', 2.0523348)]))])

Der Trainingsverlust nimmt nach jeder Runde des Verbundtrainings ab, was darauf hinweist, dass das Modell konvergiert. Es gibt einige wichtige Einschränkungen bei diesen Trainingsmetriken. Weitere Informationen finden Sie im Abschnitt zur Evaluierung weiter unten in diesem Lernprogramm.

Anzeigen von Modellmetriken in TensorBoard

Als nächstes visualisieren wir die Metriken aus diesen Verbundberechnungen mit Tensorboard.

Beginnen wir mit der Erstellung des Verzeichnisses und des entsprechenden Zusammenfassungsschreibers, in den die Metriken geschrieben werden sollen.

logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

Zeichnen Sie die relevanten skalaren Metriken mit demselben Zusammenfassungsschreiber.

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics['train'].items():
      tf.summary.scalar(name, value, step=round_num)

Starten Sie TensorBoard mit dem oben angegebenen Stammprotokollverzeichnis. Das Laden der Daten kann einige Sekunden dauern.

!ls {logdir}
%tensorboard --logdir {logdir} --port=0
events.out.tfevents.1604020204.isim77-20020ad609500000b02900f40f27a5f6.prod.google.com.686098.10633.v2
events.out.tfevents.1604020602.isim77-20020ad609500000b02900f40f27a5f6.prod.google.com.794554.10607.v2
Launching TensorBoard...
<IPython.core.display.Javascript at 0x7fc5e8d3c128>
# Uncomment and run this this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if 
# you do a "Runtime > Run all" you don't lose your results.

# !rm -R /tmp/logs/scalars/*

Um die Bewertungsmetriken auf dieselbe Weise anzuzeigen, können Sie einen separaten Bewertungsordner wie "logs / scalars / eval" erstellen, um in TensorBoard zu schreiben.

Anpassen der Modellimplementierung

Keras ist die empfohlene High-Level-Modell-API für TensorFlow . Wir empfehlen , Keras-Modelle (über tff.learning.from_keras_model ) nach Möglichkeit in TFF zu verwenden.

tff.learning bietet jedoch eine untergeordnete Modellschnittstelle, tff.learning.Model , die die minimale Funktionalität tff.learning.Model , die für die Verwendung eines Modells für das tff.learning.Model erforderlich ist. Dietf.keras.layers Implementierung dieser Schnittstelle (möglicherweise immer noch unter Verwendung von Bausteinen wietf.keras.layers ) ermöglicht eine maximale Anpassung, ohne die Interna der föderierten Lernalgorithmen zutf.keras.layers .

Also machen wir es noch einmal von Grund auf neu.

Definieren von Modellvariablen, Vorwärtsdurchlauf und Metriken

Der erste Schritt besteht darin, die TensorFlow-Variablen zu identifizieren, mit denen wir arbeiten werden. Um den folgenden Code besser lesbar zu machen, definieren wir eine Datenstruktur, die den gesamten Satz darstellt. Dies umfasst Variablen wie weights und bias , die wir trainieren, sowie Variablen, die verschiedene kumulative Statistiken und Zähler enthalten, die wir während des Trainings aktualisieren, wie z. B. loss_sum , accuracy_sum und num_examples .

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

Hier ist eine Methode, mit der die Variablen erstellt werden. Der Einfachheit halber stellen wir alle Statistiken als tf.float32 , da dadurch zu einem späteren Zeitpunkt keine Typkonvertierungen mehr erforderlich sind. Das Umschließen von Variableninitialisierern als Lambdas ist eine Anforderung, die von Ressourcenvariablen auferlegt wird.

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

Mit den vorhandenen Variablen für Modellparameter und kumulative Statistiken können wir nun die Forward-Pass-Methode definieren, die Verluste berechnet, Vorhersagen ausgibt und die kumulativen Statistiken für einen einzelnen Stapel von Eingabedaten wie folgt aktualisiert.

def mnist_forward_pass(variables, batch):
  y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

Als Nächstes definieren wir eine Funktion, die eine Reihe lokaler Metriken zurückgibt, wiederum unter Verwendung von TensorFlow. Dies sind die Werte (zusätzlich zu Modellaktualisierungen, die automatisch verarbeitet werden), die in einem Verbundlern- oder Bewertungsprozess auf dem Server aggregiert werden können.

Hier geben wir einfach den durchschnittlichen loss und die accuracy sowie die num_examples , die wir benötigen, um die Beiträge verschiedener Benutzer bei der Berechnung von num_examples korrekt zu gewichten.

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

Schließlich müssen wir bestimmen, wie die von jedem Gerät über get_local_mnist_metrics lokalen Metriken aggregiert werden get_local_mnist_metrics . Dies ist der einzige Teil des Codes, der nicht in TensorFlow geschrieben ist - es handelt sich um eine Verbundberechnung, die in TFF ausgedrückt wird. Wenn Sie tiefer graben möchten, überfliegen Sie das Tutorial für benutzerdefinierte Algorithmen , aber in den meisten Anwendungen müssen Sie dies nicht wirklich tun. Varianten des unten gezeigten Musters sollten ausreichen. So sieht es aus:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))

Die metrics Argument entspricht dem OrderedDict durch zurück get_local_mnist_metrics oben, aber kritisch die Werte sind nicht mehr tf.Tensors - sie „boxed“ als tff.Value s, um es Sie können nicht sie deutlich zu machen mehr manipulieren TensorFlow verwenden, aber nur Verwenden der tff.federated_mean von TFF wie tff.federated_mean und tff.federated_sum . Das zurückgegebene Wörterbuch der globalen Aggregate definiert den Satz von Metriken, die auf dem Server verfügbar sein werden.

tff.learning.Model einer Instanz von tff.learning.Model

Mit all dem oben Genannten sind wir bereit, eine Modelldarstellung zur Verwendung mit TFF zu erstellen, die derjenigen ähnelt, die für Sie generiert wurde, wenn Sie TFF ein Keras-Modell aufnehmen lassen.

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

Wie Sie sehen können, entsprechen die von tff.learning.Model definierten abstrakten Methoden und Eigenschaften den Codefragmenten im vorhergehenden Abschnitt, in denen die Variablen eingeführt und die Verluste und Statistiken definiert wurden.

Hier sind einige Punkte hervorzuheben:

  • Alle Zustände, die Ihr Modell verwenden wird, müssen als TensorFlow-Variablen erfasst werden, da TFF zur Laufzeit kein Python verwendet (denken Sie daran, dass Ihr Code so geschrieben sein sollte, dass er auf Mobilgeräten bereitgestellt werden kann; eine ausführlichere Beschreibung finden Sie im Tutorial für benutzerdefinierte Algorithmen Kommentar zu den Gründen).
  • Ihr Modell sollte beschreiben, welche Form von Daten es akzeptiert ( input_spec ), da TFF im Allgemeinen eine stark typisierte Umgebung ist und input_spec für alle Komponenten ermitteln möchte. Das Festlegen des Formats der Eingabe Ihres Modells ist ein wesentlicher Bestandteil davon.
  • Obwohl dies technisch nicht erforderlich ist, empfehlen wir, die gesamte TensorFlow-Logik (Vorwärtsdurchlauf, tf.function usw.) als tf.function s zu tf.function , da dies dazu beiträgt, dass der TensorFlow serialisiert werden kann und keine expliziten Steuerungsabhängigkeiten erforderlich sind.

Das Obige reicht für Auswertungen und Algorithmen wie Federated SGD aus. Für Federated Averaging müssen wir jedoch angeben, wie das Modell für jeden Stapel lokal trainiert werden soll. Wir werden einen lokalen Optimierer angeben, wenn wir den Federated Averaging-Algorithmus erstellen.

Simulation des Verbundtrainings mit dem neuen Modell

Mit all dem oben Gesagten sieht der Rest des Prozesses so aus, wie wir es bereits gesehen haben - ersetzen Sie einfach den Modellkonstruktor durch den Konstruktor unserer neuen Modellklasse und verwenden Sie die beiden Verbundberechnungen in dem iterativen Prozess, den Sie zum Durchlaufen erstellt haben Trainingsrunden.

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.1527398), ('accuracy', 0.12469136)]))])
for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.941014), ('accuracy', 0.14218107)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.9052832), ('accuracy', 0.14444445)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.7491086), ('accuracy', 0.17962962)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.5129666), ('accuracy', 0.19526748)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4175923), ('accuracy', 0.23600823)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4273515), ('accuracy', 0.24176955)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.2426176), ('accuracy', 0.2802469)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1567981), ('accuracy', 0.295679)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1092515), ('accuracy', 0.30843621)]))])

Informationen zum Anzeigen dieser Metriken in TensorBoard finden Sie in den oben unter "Anzeigen von Modellmetriken in TensorBoard" aufgeführten Schritten.

Auswertung

Alle unsere bisherigen Experimente zeigten nur Verbundtrainingsmetriken - die durchschnittlichen Metriken über alle Datenstapel, die für alle Kunden in der Runde trainiert wurden. Dies führt zu den normalen Bedenken hinsichtlich einer Überanpassung, insbesondere da wir der Einfachheit halber in jeder Runde dieselbe Gruppe von Clients verwendet haben. Es gibt jedoch einen zusätzlichen Begriff der Überanpassung in Trainingsmetriken, die für den Federated Averaging-Algorithmus spezifisch sind. Dies ist am einfachsten zu erkennen, wenn wir uns vorstellen, dass jeder Client einen einzelnen Datenstapel hatte, und wir trainieren diesen Stapel für viele Iterationen (Epochen). In diesem Fall passt das lokale Modell schnell genau zu dieser einen Charge, sodass sich die von uns gemittelte lokale Genauigkeitsmetrik 1,0 nähert. Somit können diese Trainingsmetriken als Zeichen dafür angesehen werden, dass das Training Fortschritte macht, aber nicht viel mehr.

Zur Auswertung auf föderierte Daten durchführen, können Sie eine weitere föderierten Berechnung konstruieren nur für diesen Zweck entwickelt, mit der tff.learning.build_federated_evaluation Funktion und vorbei in Ihrem Modell Konstruktor als Argument. Beachten Sie, dass es im Gegensatz zu Federated Averaging, wo wir MnistTrainableModel verwendet MnistTrainableModel , ausreicht, das MnistModel . Bei der Auswertung wird kein Gradientenabstieg durchgeführt, und es müssen keine Optimierer erstellt werden.

Wenn für Experimente und Forschung ein zentraler Testdatensatz verfügbar ist, zeigt Federated Learning for Text Generation eine weitere Bewertungsoption: Die trainierten Gewichte aus dem Verbundlernen übernehmen, auf ein Standard-Keras-Modell anwenden und dann einfach tf.keras.models.Model.evaluate() aufrufen tf.keras.models.Model.evaluate() für ein zentrales Dataset.

evaluation = tff.learning.build_federated_evaluation(MnistModel)

Sie können die abstrakte Typensignatur der Bewertungsfunktion wie folgt überprüfen.

str(evaluation.type_signature)
'(<server_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <num_examples=float32@SERVER,loss=float32@SERVER,accuracy=float32@SERVER>)'

An dieser Stelle müssen Sie sich keine Gedanken über die Details machen. tff.templates.IterativeProcess.next jedoch, dass diese die folgende allgemeine Form haben, ähnlich wie tff.templates.IterativeProcess.next jedoch mit zwei wichtigen Unterschieden. Erstens geben wir den Serverstatus nicht zurück, da die Auswertung das Modell oder einen anderen Aspekt des Status nicht ändert - Sie können sich das als zustandslos vorstellen. Zweitens benötigt die Evaluierung nur das Modell und keinen anderen Teil des Serverstatus, der mit dem Training verbunden sein könnte, wie z. B. Optimierungsvariablen.

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

Rufen wir die Bewertung des letzten Zustands auf, zu dem wir während des Trainings gekommen sind. Um das neueste trainierte Modell aus dem Serverstatus zu extrahieren, greifen Sie einfach wie folgt auf das .model Mitglied zu.

train_metrics = evaluation(state.model, federated_train_data)

Folgendes bekommen wir. Beachten Sie, dass die Zahlen geringfügig besser aussehen als in der letzten Trainingsrunde oben. Konventionell spiegeln die vom iterativen Trainingsprozess gemeldeten Trainingsmetriken im Allgemeinen die Leistung des Modells zu Beginn der Trainingsrunde wider, sodass die Bewertungsmetriken immer einen Schritt voraus sind.

str(train_metrics)
'<num_examples=4860.0,loss=1.7142657041549683,accuracy=0.38683128356933594>'

Lassen Sie uns nun ein Testbeispiel mit Verbunddaten zusammenstellen und die Auswertung der Testdaten erneut ausführen. Die Daten stammen aus derselben Stichprobe realer Benutzer, jedoch aus einem bestimmten Datensatz.

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
'<num_examples=580.0,loss=1.861915111541748,accuracy=0.3362068831920624>'

Damit ist das Tutorial abgeschlossen. Wir empfehlen Ihnen, mit den Parametern (z. B. Stapelgrößen, Anzahl der Benutzer, Epochen, Lernraten usw.) zu spielen, den obigen Code zu ändern, um das Training mit zufälligen Stichproben von Benutzern in jeder Runde zu simulieren und die anderen Tutorials zu erkunden wir haben entwickelt.