Questa pagina è stata tradotta dall'API Cloud Translation.
Switch to English

Apprendimento federato per la classificazione delle immagini

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza sorgente su GitHub

In questo tutorial, utilizziamo il classico esempio di formazione MNIST per introdurre il livello API Federated Learning (FL) di TFF, tff.learning , un set di interfacce di livello superiore che possono essere utilizzate per eseguire tipi comuni di attività di apprendimento federate, come formazione federata, rispetto ai modelli forniti dagli utenti implementati in TensorFlow.

Questo tutorial e l'API Federated Learning sono destinati principalmente agli utenti che desiderano collegare i propri modelli TensorFlow a TFF, trattando quest'ultimo principalmente come una scatola nera. Per una comprensione più approfondita del TFF e di come implementare i propri algoritmi di apprendimento federato, vedere i tutorial sull'API FC Core - Algoritmi federati personalizzati Parte 1 e Parte 2 .

Per ulteriori informazioni su tff.learning , continua con il Federated Learning for Text Generation , tutorial che oltre a coprire i modelli ricorrenti, dimostra anche il caricamento di un modello Keras serializzato pre-addestrato per il perfezionamento con l'apprendimento federato combinato con la valutazione utilizzando Keras.

Prima di iniziare

Prima di iniziare, eseguire quanto segue per assicurarsi che l'ambiente sia configurato correttamente. Se non vedi un saluto, consulta la Guida all'installazione per le istruzioni.


!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio

import nest_asyncio
nest_asyncio.apply()

%load_ext tensorboard
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

Preparazione dei dati di input

Cominciamo con i dati. L'apprendimento federato richiede un set di dati federato, ovvero una raccolta di dati da più utenti. I dati federati sono in genere non iid , il che pone una serie di sfide uniche.

Per facilitare la sperimentazione, abbiamo seminato il repository TFF con alcuni set di dati, inclusa una versione federata di MNIST che contiene una versione del set di dati NIST originale che è stata rielaborata utilizzando Leaf in modo che i dati siano codificati dall'autore originale di le cifre. Poiché ogni writer ha uno stile unico, questo set di dati mostra il tipo di comportamento non iid previsto per i set di dati federati.

Ecco come possiamo caricarlo.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

I set di dati restituiti da load_data() sono istanze di tff.simulation.ClientData , un'interfaccia che consente di enumerare il set di utenti, di costruire un tf.data.Dataset che rappresenta i dati di un particolare utente e di interrogare il struttura dei singoli elementi. Ecco come puoi utilizzare questa interfaccia per esplorare il contenuto del set di dati. Tieni presente che mentre questa interfaccia ti consente di iterare sugli ID client, questa è solo una caratteristica dei dati di simulazione. Come vedrai tra poco, le identità dei client non vengono utilizzate dal framework di apprendimento federato: il loro unico scopo è consentire di selezionare sottoinsiemi di dati per le simulazioni.

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None)), ('label', TensorSpec(shape=(), dtype=tf.int32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

Esplorazione dell'eterogeneità nei dati federati

I dati federati sono in genere non iid , gli utenti hanno generalmente distribuzioni di dati diverse a seconda dei modelli di utilizzo. Alcuni client potrebbero avere un numero inferiore di esempi di addestramento sul dispositivo, a causa della scarsità di dati a livello locale, mentre alcuni client avranno esempi di addestramento più che sufficienti. Esploriamo questo concetto di eterogeneità dei dati tipico di un sistema federato con i dati EMNIST che abbiamo a disposizione. È importante notare che questa analisi approfondita dei dati di un cliente è disponibile solo per noi perché si tratta di un ambiente di simulazione in cui tutti i dati sono disponibili localmente. In un ambiente federato di produzione reale non sarebbe possibile ispezionare i dati di un singolo cliente.

Per prima cosa, prendiamo un campione dei dati di un cliente per avere un'idea degli esempi su un dispositivo simulato. Poiché il set di dati che stiamo utilizzando è stato codificato da un autore unico, i dati di un client rappresentano la scrittura a mano di una persona per un campione delle cifre da 0 a 9, simulando il "modello di utilizzo" univoco di un utente.

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

Ora visualizziamo il numero di esempi su ogni client per ciascuna etichetta di cifre MNIST. Nell'ambiente federato, il numero di esempi su ciascun client può variare notevolmente, a seconda del comportamento dell'utente.

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

Ora visualizziamo l'immagine media per cliente per ciascuna etichetta MNIST. Questo codice produrrà la media di ogni valore di pixel per tutti gli esempi dell'utente per un'etichetta. Vedremo che l'immagine media di un cliente per una cifra avrà un aspetto diverso dall'immagine media di un altro cliente per la stessa cifra, a causa dello stile di scrittura unico di ogni persona. Possiamo riflettere su come ogni round di formazione locale spingerà il modello in una direzione diversa su ciascun cliente, poiché stiamo imparando dai dati univoci dell'utente in quel round locale. Più avanti nel tutorial vedremo come possiamo prendere ogni aggiornamento del modello da tutti i client e aggregarli insieme nel nostro nuovo modello globale, che ha imparato da ciascuno dei dati univoci del nostro cliente.

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

I dati utente possono essere rumorosi e etichettati in modo inaffidabile. Ad esempio, guardando i dati del Cliente n. 2 sopra, possiamo vedere che per l'etichetta 2, è possibile che ci siano stati alcuni esempi con etichette errate che creano un'immagine media più rumorosa.

Pre-elaborazione dei dati di input

Poiché i dati sono già un tf.data.Dataset , la preelaborazione può essere eseguita utilizzando le trasformazioni del tf.data.Dataset dati. Qui, si appiattire le 28x28 immagini in 784 array -element, mescolate le singoli esempi, organizzarli in lotti, e rinomina le caratteristiche dei pixels e label di x ed y per l'uso con Keras. Inseriamo anche una repeat sul set di dati per eseguire diverse epoche.

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER= 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

Verifichiamo che abbia funzionato.

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [2],
       [3],
       [6],
       [0],
       [1],
       [4],
       [1],
       [0],
       [6],
       [9],
       [9],
       [3],
       [6],
       [1],
       [4],
       [8],
       [0],
       [2]], dtype=int32))])

Abbiamo quasi tutti gli elementi costitutivi per costruire set di dati federati.

Uno dei modi per fornire dati federati a TFF in una simulazione è semplicemente come un elenco Python, con ogni elemento dell'elenco che tf.data.Dataset i dati di un singolo utente, sia come elenco che come tf.data.Dataset . Poiché abbiamo già un'interfaccia che fornisce quest'ultima, usiamola.

Ecco una semplice funzione di supporto che costruirà un elenco di set di dati dal set di utenti specificato come input per un ciclo di formazione o valutazione.

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

Ora, come scegliamo i clienti?

In un tipico scenario di formazione federata, abbiamo a che fare con una popolazione potenzialmente molto ampia di dispositivi utente, solo una parte dei quali può essere disponibile per la formazione in un dato momento. Questo è il caso, ad esempio, quando i dispositivi client sono telefoni cellulari che partecipano alla formazione solo quando sono collegati a una fonte di alimentazione, fuori da una rete a consumo e altrimenti inattivi.

Ovviamente siamo in un ambiente di simulazione e tutti i dati sono disponibili localmente. In genere, quindi, durante l'esecuzione di simulazioni, campioneremmo semplicemente un sottoinsieme casuale di clienti da coinvolgere in ogni round di formazione, generalmente diverso in ogni round.

Detto questo, come puoi scoprire studiando il documento sull'algoritmo di media federata , raggiungere la convergenza in un sistema con sottoinsiemi di client campionati casualmente in ogni round può richiedere del tempo e sarebbe poco pratico dover eseguire centinaia di round in questo tutorial interattivo.

Quello che faremo invece è campionare il set di client una volta e riutilizzare lo stesso set in più round per accelerare la convergenza (adattandosi intenzionalmente eccessivamente ai dati di questi pochi utenti). Lasciamo come esercizio al lettore la modifica di questo tutorial per simulare il campionamento casuale: è abbastanza facile da fare (una volta fatto, tieni presente che la convergenza del modello potrebbe richiedere del tempo).

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

Creare un modello con Keras

Se stai usando Keras, probabilmente hai già del codice che costruisce un modello Keras. Ecco un esempio di un modello semplice che sarà sufficiente per le nostre esigenze.

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

Per poter utilizzare qualsiasi modello con TFF, deve essere racchiuso in un'istanza dell'interfaccia tff.learning.Model , che espone metodi per timbrare il forward pass del modello, proprietà dei metadati, ecc., In modo simile a Keras, ma introduce anche metodi aggiuntivi elementi, come i modi per controllare il processo di elaborazione delle metriche federate. Non preoccupiamoci di questo per ora; se hai un modello Keras come quello che abbiamo appena definito sopra, puoi fare in modo che TFF lo tff.learning.from_keras_model invocando tff.learning.from_keras_model , passando il modello e un batch di dati di esempio come argomenti, come mostrato di seguito.

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Addestramento del modello su dati federati

Ora che abbiamo un modello impacchettato come tff.learning.Model da utilizzare con TFF, possiamo lasciare che TFF costruisca un algoritmo di media federata invocando la funzione di supporto tff.learning.build_federated_averaging_process , come segue.

Tieni presente che l'argomento deve essere un costruttore (come model_fn sopra), non un'istanza già costruita, in modo che la costruzione del tuo modello possa avvenire in un contesto controllato da TFF (se sei curioso dei motivi per questo, ti invitiamo a leggere il tutorial di follow-up sugli algoritmi personalizzati ).

Una nota fondamentale sull'algoritmo della media federata di seguito, ci sono 2 ottimizzatori: un _client optimizer e un _server optimizer . L' ottimizzatore _client viene utilizzato solo per calcolare gli aggiornamenti del modello locale su ogni client. L' ottimizzatore _server applica l'aggiornamento medio al modello globale sul server. In particolare, ciò significa che la scelta dell'ottimizzatore e della velocità di apprendimento utilizzati potrebbe dover essere diversi da quelli utilizzati per addestrare il modello su un set di dati iid standard. Ti consigliamo di iniziare con SGD regolare, possibilmente con un tasso di apprendimento inferiore al solito. Il tasso di apprendimento che utilizziamo non è stato attentamente regolato, sentiti libero di sperimentare.

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Cosa è appena successo? TFF ha costruito una coppia di calcoli federati e li ha impacchettati in un tff.templates.IterativeProcess in cui questi calcoli sono disponibili come coppia di proprietà initialize e next .

In poche parole, i calcoli federati sono programmi nel linguaggio interno di TFF che possono esprimere vari algoritmi federati (puoi trovare ulteriori informazioni su questo nel tutorial sugli algoritmi personalizzati ). In questo caso, i due calcoli generati e compressi in iterative_process implementano Federated Averaging .

Uno degli obiettivi di TFF è definire i calcoli in modo che possano essere eseguiti in contesti di apprendimento federati reali, ma attualmente è implementato solo il runtime di simulazione di esecuzione locale. Per eseguire un calcolo in un simulatore, lo invocate semplicemente come una funzione Python. Questo ambiente interpretato di default non è progettato per prestazioni elevate, ma sarà sufficiente per questo tutorial; prevediamo di fornire runtime di simulazione a prestazioni più elevate per facilitare la ricerca su scala più ampia nelle versioni future.

Cominciamo con l' initialize calcolo. Come nel caso di tutti i calcoli federati, puoi pensarlo come una funzione. Il calcolo non accetta argomenti e restituisce un risultato: la rappresentazione dello stato del processo di media federata sul server. Anche se non vogliamo immergerci nei dettagli del TFF, potrebbe essere istruttivo vedere come appare questo stato. Puoi visualizzarlo come segue.

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<>,model_broadcast_state=<>>@SERVER)'

Sebbene la firma del tipo di cui sopra possa sembrare a prima vista un po 'criptica, puoi riconoscere che lo stato del server è costituito da un model (i parametri del modello iniziale per MNIST che verranno distribuiti a tutti i dispositivi) e optimizer_state (informazioni aggiuntive mantenute dal server, come il numero di cicli da utilizzare per le pianificazioni degli iperparametri, ecc.).

Richiamiamo il calcolo initialize per costruire lo stato del server.

state = iterative_process.initialize()

Il secondo della coppia di calcoli federati, il next , rappresenta un singolo round di Federated Averaging, che consiste nel forzare lo stato del server (inclusi i parametri del modello) ai client, l'addestramento sul dispositivo sui dati locali, la raccolta e la media degli aggiornamenti del modello e producendo un nuovo modello aggiornato sul server.

Concettualmente, puoi pensare che il next abbia una firma di tipo funzionale che appare come segue.

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

In particolare, si dovrebbe pensare a next() non come a una funzione che viene eseguita su un server, ma piuttosto come una rappresentazione funzionale dichiarativa dell'intero calcolo decentralizzato - alcuni degli input sono forniti dal server ( SERVER_STATE ), ma ciascuno partecipante il dispositivo fornisce il proprio set di dati locale.

Facciamo un singolo round di allenamento e visualizziamo i risultati. Possiamo utilizzare i dati federati che abbiamo già generato sopra per un campione di utenti.

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.12037037312984467,loss=3.0108425617218018>>

Facciamo ancora qualche round. Come notato in precedenza, in genere a questo punto sceglieresti un sottoinsieme dei tuoi dati di simulazione da un nuovo campione di utenti selezionato casualmente per ogni round al fine di simulare una distribuzione realistica in cui gli utenti vanno e vengono continuamente, ma in questo taccuino interattivo, per per motivi di dimostrazione, riutilizzeremo semplicemente gli stessi utenti, in modo che il sistema converga rapidamente.

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.14814814925193787,loss=2.8865506649017334>>
round  3, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.148765429854393,loss=2.9079062938690186>>
round  4, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.17633745074272156,loss=2.724686622619629>>
round  5, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.20226337015628815,loss=2.6334855556488037>>
round  6, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.22427983582019806,loss=2.5482592582702637>>
round  7, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.24094650149345398,loss=2.4472343921661377>>
round  8, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.259876549243927,loss=2.3809611797332764>>
round  9, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.29814815521240234,loss=2.156442403793335>>
round 10, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.31687241792678833,loss=2.122845411300659>>

La perdita di formazione sta diminuendo dopo ogni round di formazione federata, indicando che il modello sta convergendo. Ci sono alcuni importanti avvertimenti con queste metriche di allenamento, tuttavia, vedere la sezione sulla valutazione più avanti in questo tutorial.

Visualizzazione delle metriche del modello in TensorBoard

Successivamente, visualizziamo le metriche di questi calcoli federati utilizzando Tensorboard.

Iniziamo creando la directory e il corrispondente scrittore di riepilogo in cui scrivere le metriche.


logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

Traccia le metriche scalari pertinenti con lo stesso scrittore di riepilogo.


with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics.train._asdict().items():
      tf.summary.scalar(name, value, step=round_num)

Avvia TensorBoard con la directory di log principale specificata sopra. Il caricamento dei dati può richiedere alcuni secondi.


%tensorboard --logdir /tmp/logs/scalars/ --port=0

# Run this this cell to clean your directory of old output for future graphs from this directory.
!rm -R /tmp/logs/scalars/*

Per visualizzare le metriche di valutazione allo stesso modo, è possibile creare una cartella eval separata, come "logs / scalars / eval", da scrivere su TensorBoard.

Personalizzazione dell'implementazione del modello

Keras è l' API del modello di alto livello consigliata per TensorFlow e incoraggiamo l'utilizzo di modelli Keras (tramite tff.learning.from_keras_model ) in TFF quando possibile.

Tuttavia, tff.learning fornisce un'interfaccia del modello di livello inferiore, tff.learning.Model , che espone le funzionalità minime necessarie per l'utilizzo di un modello per l'apprendimento federato. L'implementazione diretta di questa interfaccia (possibilmente ancora utilizzando blocchi tf.keras.layers come tf.keras.layers ) consente la massima personalizzazione senza modificare gli interni degli algoritmi di apprendimento federati.

Quindi rifacciamo tutto da capo.

Definizione di variabili del modello, forward pass e metriche

Il primo passo è identificare le variabili TensorFlow con cui lavoreremo. Per rendere più leggibile il codice seguente, definiamo una struttura dati per rappresentare l'intero set. Ciò includerà variabili come weights e bias che addestreremo, nonché variabili che loss_sum varie statistiche cumulative e contatori che aggiorneremo durante l'allenamento, come loss_sum , accuracy_sum e num_examples .

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

Ecco un metodo che crea le variabili. Per semplicità, rappresentiamo tutte le statistiche come tf.float32 , poiché ciò eliminerà la necessità di conversioni di tipo in una fase successiva. Il wrapping degli inizializzatori di variabili come lambda è un requisito imposto dalle variabili di risorsa .

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

Con le variabili per i parametri del modello e le statistiche cumulative in atto, ora possiamo definire il metodo del passaggio in avanti che calcola la perdita, emette previsioni e aggiorna le statistiche cumulative per un singolo batch di dati di input, come segue.

def mnist_forward_pass(variables, batch):
  y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

Successivamente, definiamo una funzione che restituisce un set di metriche locali, sempre utilizzando TensorFlow. Questi sono i valori (oltre agli aggiornamenti del modello, che vengono gestiti automaticamente) che possono essere aggregati al server in un processo di apprendimento o valutazione federato.

Qui, restituiamo semplicemente la loss media e l' accuracy , nonché i num_examples , di cui avremo bisogno per pesare correttamente i contributi di diversi utenti durante il calcolo degli aggregati federati.

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

Infine, dobbiamo determinare come aggregare le metriche locali emesse da ciascun dispositivo tramite get_local_mnist_metrics . Questa è l'unica parte del codice che non è scritta in TensorFlow: è un calcolo federato espresso in TFF. Se desideri approfondire, scorri il tutorial sugli algoritmi personalizzati , ma nella maggior parte delle applicazioni, non ne avrai davvero bisogno; le varianti del modello mostrato di seguito dovrebbero essere sufficienti. Ecco come appare:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))
  

L'argomento delle metrics input corrisponde OrderedDict restituito da get_local_mnist_metrics sopra, ma in modo critico i valori non sono più tf.Tensors - sono "inscatolati" come tff.Value s, per chiarire che non puoi più manipolarli utilizzando TensorFlow, ma solo utilizzando gli operatori federati di TFF come tff.federated_mean e tff.federated_sum . Il dizionario restituito degli aggregati globali definisce l'insieme di metriche che saranno disponibili sul server.

Costruire un'istanza di tff.learning.Model

Con tutto quanto sopra in atto, siamo pronti per costruire una rappresentazione del modello da utilizzare con TFF simile a quella generata per te quando lasci che TFF ingerisca un modello Keras.

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

Come puoi vedere, i metodi e le proprietà astratti definiti da tff.learning.Model corrispondono agli snippet di codice nella sezione precedente che ha introdotto le variabili e definito la perdita e le statistiche.

Ecco alcuni punti che vale la pena sottolineare:

  • Tutti gli stati che il tuo modello utilizzerà devono essere catturati come variabili TensorFlow, poiché TFF non usa Python in fase di runtime (ricorda che il tuo codice deve essere scritto in modo tale da poter essere distribuito su dispositivi mobili; guarda il tutorial sugli algoritmi personalizzati per un approfondimento più approfondito commento sui motivi).
  • Il tuo modello dovrebbe descrivere quale forma di dati accetta ( input_spec ), poiché in generale, TFF è un ambiente fortemente tipizzato e vuole determinare le firme del tipo per tutti i componenti. La dichiarazione del formato dell'input del tuo modello è una parte essenziale di esso.
  • Sebbene non sia tecnicamente richiesto, consigliamo di racchiudere tutta la logica TensorFlow (passaggio in avanti, calcoli metrici, ecc.) Come tf.function s, poiché ciò aiuta a garantire che TensorFlow possa essere serializzato ed elimina la necessità di dipendenze di controllo esplicite.

Quanto sopra è sufficiente per la valutazione e gli algoritmi come Federated SGD. Tuttavia, per la media federata, è necessario specificare in che modo il modello deve essere addestrato localmente su ogni batch. Specificheremo un ottimizzatore locale durante la creazione dell'algoritmo di media federata.

Simulazione della formazione federata con il nuovo modello

Con tutto quanto sopra in atto, il resto del processo assomiglia a quello che abbiamo già visto: basta sostituire il costruttore del modello con il costruttore della nostra nuova classe del modello e utilizzare i due calcoli federati nel processo iterativo che hai creato per scorrere round di formazione.

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.9713594913482666,accuracy=0.13518518209457397>>

for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.975412607192993,accuracy=0.14032921195030212>>
round  3, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.9395227432250977,accuracy=0.1594650149345398>>
round  4, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.710164785385132,accuracy=0.17139917612075806>>
round  5, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.5891618728637695,accuracy=0.20267489552497864>>
round  6, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.5148487091064453,accuracy=0.21666666865348816>>
round  7, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.2816808223724365,accuracy=0.2580246925354004>>
round  8, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.3656885623931885,accuracy=0.25884774327278137>>
round  9, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.23549222946167,accuracy=0.28477364778518677>>
round 10, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=1.974222183227539,accuracy=0.35329216718673706>>

Per visualizzare queste metriche in TensorBoard, fare riferimento ai passaggi elencati sopra in "Visualizzazione delle metriche del modello in TensorBoard".

Valutazione

Tutti i nostri esperimenti finora hanno presentato solo metriche di formazione federate, le metriche medie su tutti i batch di dati addestrati su tutti i clienti nel round. Ciò introduce le normali preoccupazioni sull'overfitting, soprattutto perché abbiamo utilizzato lo stesso set di client in ogni round per semplicità, ma esiste un'ulteriore nozione di overfitting nelle metriche di allenamento specifiche dell'algoritmo di media federata. È più facile vedere se immaginiamo che ogni cliente avesse un singolo batch di dati e ci addestriamo su quel batch per molte iterazioni (epoche). In questo caso, il modello locale si adatterà rapidamente esattamente a quel batch, quindi la metrica di accuratezza locale mediamente si avvicinerà a 1.0. Pertanto, queste metriche di allenamento possono essere prese come un segno che la formazione sta progredendo, ma non molto di più.

Per eseguire la valutazione sui dati federati, è possibile costruire un altro calcolo federato progettato proprio a questo scopo, utilizzando la funzione tff.learning.build_federated_evaluation e passando il costruttore del modello come argomento. Nota che a differenza di Federated Averaging, dove abbiamo usato MnistTrainableModel , è sufficiente passare MnistModel . La valutazione non esegue la discesa del gradiente e non è necessario creare ottimizzatori.

Per la sperimentazione e la ricerca, quando è disponibile un set di dati di test centralizzato, Federated Learning for Text Generation dimostra un'altra opzione di valutazione: prendere i pesi addestrati dall'apprendimento federato, applicarli a un modello Keras standard e quindi chiamare semplicemente tf.keras.models.Model.evaluate() su un set di dati centralizzato.

evaluation = tff.learning.build_federated_evaluation(MnistModel)

È possibile esaminare la firma del tipo astratto della funzione di valutazione come segue.

str(evaluation.type_signature)
'(<<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,{<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <num_examples=float32@SERVER,loss=float32@SERVER,accuracy=float32@SERVER>)'

Non c'è bisogno di preoccuparsi dei dettagli a questo punto, basta essere consapevoli che assume la seguente forma generale, simile a tff.templates.IterativeProcess.next ma con due importanti differenze. Primo, non restituiamo lo stato del server, poiché la valutazione non modifica il modello o qualsiasi altro aspetto dello stato: puoi considerarlo senza stato. In secondo luogo, la valutazione necessita solo del modello e non richiede altre parti dello stato del server che potrebbero essere associate all'addestramento, come le variabili di ottimizzazione.

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

Richiamiamo la valutazione sull'ultimo stato a cui siamo arrivati ​​durante l'addestramento. Per estrarre l'ultimo modello addestrato dallo stato del server, è sufficiente accedere al membro .model , come segue.

train_metrics = evaluation(state.model, federated_train_data)

Ecco cosa otteniamo. Nota che i numeri sembrano leggermente migliori di quanto riportato dall'ultimo round di allenamento sopra. Per convenzione, le metriche di formazione riportate dal processo di formazione iterativo riflettono generalmente le prestazioni del modello all'inizio del ciclo di formazione, quindi le metriche di valutazione saranno sempre un passo avanti.

str(train_metrics)
'<num_examples=4860.0,loss=1.7142657041549683,accuracy=0.38683128356933594>'

Compiliamo ora un campione di prova di dati federati ed eseguiamo nuovamente la valutazione sui dati di prova. I dati proverranno dallo stesso campione di utenti reali, ma da un set di dati distinto.

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
'<num_examples=580.0,loss=1.861915111541748,accuracy=0.3362068831920624>'

Questo conclude il tutorial. Ti invitiamo a giocare con i parametri (ad esempio, dimensioni del batch, numero di utenti, epoche, tassi di apprendimento, ecc.), A modificare il codice sopra per simulare l'addestramento su campioni casuali di utenti in ogni round e ad esplorare gli altri tutorial abbiamo sviluppato.