Salva la data! Google I / O ritorna dal 18 al 20 maggio Registrati ora
Questa pagina è stata tradotta dall'API Cloud Translation.
Switch to English

Costruire il proprio algoritmo di apprendimento federato

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza sorgente su GitHub Scarica taccuino

Prima di iniziare

Prima di iniziare, eseguire quanto segue per assicurarsi che l'ambiente sia configurato correttamente. Se non vedi un saluto, fai riferimento alla Guida all'installazione per le istruzioni.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

Nelle esercitazioni sulla classificazione delle immagini e sulla generazione di testo , abbiamo appreso come configurare il modello e le pipeline di dati per Federated Learning (FL) e abbiamo eseguito la formazione federata tramite il tff.learning API tff.learning di TFF.

Questa è solo la punta dell'iceberg quando si tratta di ricerca FL. In questo tutorial, discuteremo come implementare algoritmi di apprendimento federato senza tff.learning all'API tff.learning . Miriamo a realizzare quanto segue:

Obiettivi:

  • Comprendere la struttura generale degli algoritmi di apprendimento federato.
  • Esplora il Federated Core di TFF.
  • Utilizza Federated Core per implementare direttamente la media federata.

Sebbene questo tutorial sia autonomo, consigliamo di leggere prima i tutorial sulla classificazione delle immagini e sulla generazione del testo .

Preparazione dei dati di input

Per prima cosa carichiamo e preelaboriamo il set di dati EMNIST incluso in TFF. Per ulteriori dettagli, vedere il tutorial sulla classificazione delle immagini .

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Per inserire il set di dati nel nostro modello, appiattiamo i dati e convertiamo ogni esempio in una tupla del modulo (flattened_image_vector, label) .

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

Ora campioniamo un piccolo numero di client e applichiamo la preelaborazione sopra ai loro set di dati.

client_ids = np.random.choice(emnist_train.client_ids, size=NUM_CLIENTS, replace=False)

federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]

Preparazione del modello

Usiamo lo stesso modello del tutorial sulla classificazione delle immagini . Questo modello (implementato tramite tf.keras ) ha un singolo livello nascosto, seguito da un livello softmax.

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

Per utilizzare questo modello in TFF, avvolgiamo il modello Keras come tff.learning.Model . Questo ci consente di eseguire il passaggio in avanti del modello all'interno di TFF ed estrarre gli output del modello . Per ulteriori dettagli, vedere anche il tutorial sulla classificazione delle immagini .

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Mentre abbiamo usato tf.keras per creare un tff.learning.Model , TFF supporta modelli molto più generali. Questi modelli hanno i seguenti attributi rilevanti che catturano i pesi del modello:

  • trainable_variables : un iterabile dei tensori corrispondenti ai livelli addestrabili.
  • non_trainable_variables : un iterabile dei tensori corrispondenti a strati non addestrabili.

Per i nostri scopi, utilizzeremo solo trainable_variables . (poiché il nostro modello ha solo quelli!).

Costruire il proprio algoritmo di apprendimento federato

Sebbene l'API tff.learning consenta di creare molte varianti di Federated Averaging, esistono altri algoritmi federati che non si adattano perfettamente a questo framework. Ad esempio, potresti voler aggiungere regolarizzazione, clipping o algoritmi più complicati come l' addestramento GAN federato . Potresti anche essere interessato all'analisi federata .

Per questi algoritmi più avanzati, dovremo scrivere il nostro algoritmo personalizzato utilizzando TFF. In molti casi, gli algoritmi federati hanno 4 componenti principali:

  1. Un passaggio di trasmissione da server a client.
  2. Un passaggio di aggiornamento del client locale.
  3. Un passaggio di caricamento da client a server.
  4. Un passaggio di aggiornamento del server.

In TFF, generalmente rappresentiamo gli algoritmi federati come tff.templates.IterativeProcess (a cui ci riferiamo solo come IterativeProcess tutto). Questa è una classe che contiene le funzioni initialize e next . Qui, initialize viene utilizzato per inizializzare il server e next eseguirà un round di comunicazione dell'algoritmo federato. Scriviamo uno scheletro di come dovrebbe essere il nostro processo iterativo per FedAvg.

Innanzitutto, abbiamo una funzione di inizializzazione che crea semplicemente un tff.learning.Model e restituisce i suoi pesi addestrabili.

def initialize_fn():
  model = model_fn()
  return model.trainable_variables

Questa funzione sembra buona, ma come vedremo in seguito, avremo bisogno di fare una piccola modifica per renderla un "calcolo TFF".

Vogliamo anche disegnare il next_fn .

def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights

Ci concentreremo sull'implementazione di questi quattro componenti separatamente. In primo luogo ci concentriamo sulle parti che possono essere implementate in TensorFlow puro, vale a dire i passaggi di aggiornamento del client e del server.

Blocchi TensorFlow

Aggiornamento client

Useremo il nostro tff.learning.Model per la formazione dei clienti essenzialmente nello stesso modo in cui addestreresti un modello TensorFlow. In particolare, useremotf.GradientTape per calcolare il gradiente su batch di dati, quindi applicheremo questo gradiente utilizzando un client_optimizer . Ci concentriamo solo sui pesi allenabili.

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights

Aggiornamento del server

L'aggiornamento del server per FedAvg è più semplice dell'aggiornamento del client. Implementeremo la media federata "vanilla", in cui sostituiremo semplicemente i pesi del modello server con la media dei pesi del modello client. Ancora una volta, ci concentriamo solo sui pesi allenabili.

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.trainable_variables
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights

Lo snippet potrebbe essere semplificato restituendo semplicemente mean_client_weights . Tuttavia, le implementazioni più avanzate di Federated Averaging utilizzano mean_client_weights con tecniche più sofisticate, come lo slancio o l'adattività.

Sfida : implementare una versione di server_update che aggiorni i pesi del server in modo che siano il punto medio di model_weights e mean_client_weights. (Nota: questo tipo di approccio "punto medio" è analogo al lavoro recente sull'ottimizzatore Lookahead !).

Finora abbiamo scritto solo codice TensorFlow puro. Questo è di progettazione, poiché TFF ti consente di utilizzare gran parte del codice TensorFlow con cui hai già familiarità. Tuttavia, ora dobbiamo specificare la logica di orchestrazione , ovvero la logica che determina ciò che il server trasmette al client e ciò che il client carica sul server.

Ciò richiederà il Federated Core di TFF.

Introduzione al Federated Core

Federated Core (FC) è un insieme di interfacce di livello inferiore che fungono da base per l'API tff.learning . Tuttavia, queste interfacce non si limitano all'apprendimento. In effetti, possono essere utilizzati per analisi e molti altri calcoli su dati distribuiti.

Ad alto livello, il core federato è un ambiente di sviluppo che consente alla logica del programma espresso in modo compatto di combinare il codice TensorFlow con operatori di comunicazione distribuiti (come somme distribuite e trasmissioni). L'obiettivo è fornire a ricercatori e professionisti il ​​controllo esplicito sulla comunicazione distribuita nei loro sistemi, senza richiedere dettagli di implementazione del sistema (come la specifica di scambi di messaggi di rete punto-punto).

Un punto chiave è che TFF è progettato per la tutela della privacy. Pertanto, consente un controllo esplicito su dove risiedono i dati, per prevenire l'accumulo indesiderato di dati nella posizione del server centralizzato.

Dati federati

Un concetto chiave in TFF è "dati federati", che si riferisce a una raccolta di elementi di dati ospitati su un gruppo di dispositivi in ​​un sistema distribuito (ad es. Set di dati client o pesi del modello server). Modelliamo l'intera raccolta di elementi di dati su tutti i dispositivi come un singolo valore federato .

Ad esempio, supponiamo di avere dispositivi client che hanno ciascuno un galleggiante che rappresenta la temperatura di un sensore. Potremmo rappresentarlo come un carro federato di

federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)

I tipi federati sono specificati da un tipo T dei suoi componenti membri (es. tf.float32 ) e da un gruppo G di dispositivi. Ci concentreremo sui casi in cui G è tff.CLIENTS o tff.SERVER . Tale tipo federato è rappresentato come {T}@G , come mostrato di seguito.

str(federated_float_on_clients)
'{float32}@CLIENTS'

Perché ci preoccupiamo così tanto dei posizionamenti? Un obiettivo chiave di TFF è abilitare la scrittura di codice che potrebbe essere distribuito su un vero sistema distribuito. Ciò significa che è fondamentale ragionare su quali sottoinsiemi di dispositivi eseguono quale codice e dove risiedono i diversi pezzi di dati.

TFF si concentra su tre cose: dati , dove vengono inseriti i dati e come vengono trasformati i dati. I primi due sono incapsulati in tipi federati, mentre l'ultimo è incapsulato in calcoli federati .

Calcoli federati

TFF è un ambiente di programmazione funzionale fortemente tipizzato le cui unità di base sono calcoli federati . Si tratta di elementi logici che accettano valori federati come input e restituiscono valori federati come output.

Ad esempio, supponiamo di voler fare la media delle temperature sui sensori dei nostri clienti. Potremmo definire quanto segue (usando il nostro float federato):

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

Potresti chiedere, in che modo è diverso dal decoratore tf.function in TensorFlow? La risposta chiave è che il codice generato da tff.federated_computation non è né codice TensorFlow né codice Python; È una specifica di un sistema distribuito in un linguaggio collante interno indipendente dalla piattaforma.

Anche se questo può sembrare complicato, puoi pensare ai calcoli TFF come a funzioni con firme di tipo ben definite. Queste firme di tipo possono essere interrogate direttamente.

str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'

Questo tff.federated_computation accetta argomenti di tipo federato {float32}@CLIENTS e restituisce valori di tipo federato {float32}@SERVER . I calcoli federati possono anche passare da server a client, da client a client o da server a server. I calcoli federati possono anche essere composti come normali funzioni, purché le loro firme di tipo corrispondano.

Per supportare lo sviluppo, TFF ti consente di invocare un tff.federated_computation come funzione Python. Ad esempio, possiamo chiamare

get_average_temperature([68.5, 70.3, 69.8])
69.53334

Calcoli non desiderosi e TensorFlow

Ci sono due limitazioni chiave di cui essere consapevoli. Innanzitutto, quando l'interprete Python incontra un decoratore tff.federated_computation , la funzione viene tracciata una volta e serializzata per un uso futuro. A causa della natura decentralizzata del Federated Learning, questo utilizzo futuro potrebbe verificarsi altrove, come un ambiente di esecuzione remota. Pertanto, i calcoli TFF sono fondamentalmente non desiderosi . Questo comportamento è in qualche modo analogo a quello del decoratore tf.function in TensorFlow.

In secondo luogo, un calcolo federato può essere costituito solo da operatori federati (come tff.federated_mean ), non possono contenere operazioni TensorFlow. Il codice TensorFlow deve essere limitato a blocchi decorati con tff.tf_computation . La maggior parte del codice TensorFlow ordinario può essere decorato direttamente, come la seguente funzione che prende un numero e aggiunge 0.5 ad esso.

@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

Anche questi hanno firme di tipo, ma senza posizionamenti . Ad esempio, possiamo chiamare

str(add_half.type_signature)
'(float32 -> float32)'

Qui vediamo un'importante differenza tra tff.federated_computation e tff.tf_computation . Il primo ha posizionamenti espliciti, mentre il secondo no.

Possiamo utilizzare i blocchi tff.tf_computation nei calcoli federati specificando i posizionamenti. Creiamo una funzione che aggiunge metà, ma solo ai float federati nei client. Possiamo farlo usando tff.federated_map , che applica un dato tff.tf_computation , preservando il posizionamento.

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

Questa funzione è quasi identica a add_half , tranne per il fatto che accetta solo valori con posizionamento in tff.CLIENTS e restituisce valori con lo stesso posizionamento. Possiamo vederlo nella sua firma del tipo:

str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'

In sintesi:

  • TFF opera su valori federati.
  • Ogni valore federato ha un tipo federato , con un tipo (ad es. tf.float32 ) e un posizionamento (ad es. tff.CLIENTS ).
  • I valori federati possono essere trasformati utilizzando calcoli federati , che devono essere decorati con tff.federated_computation e una firma di tipo federato.
  • Il codice TensorFlow deve essere contenuto in blocchi con decoratori tff.tf_computation .
  • Questi blocchi possono quindi essere incorporati in calcoli federati.

Creazione del proprio algoritmo di Federated Learning, rivisitato

Ora che abbiamo intravisto il Federated Core, possiamo costruire il nostro algoritmo di apprendimento federato. Ricorda che sopra, abbiamo definito initialize_fn e next_fn per il nostro algoritmo. next_fn utilizzerà client_update e server_update abbiamo definito utilizzando il codice TensorFlow puro.

Tuttavia, per rendere il nostro algoritmo un calcolo federato, avremo bisogno che sia next_fn che initialize_fn siano ciascuno un tff.federated_computation .

Blocchi TensorFlow Federated

Creazione del calcolo di inizializzazione

La funzione di inizializzazione sarà abbastanza semplice: creeremo un modello usando model_fn . Tuttavia, ricorda che dobbiamo separare il nostro codice TensorFlow utilizzando tff.tf_computation .

@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables

Possiamo quindi passare questo direttamente a un calcolo federato usando tff.federated_value .

@tff.federated_computation
def initialize_fn():
  return tff.federated_value(server_init(), tff.SERVER)

Creazione del file next_fn

Ora usiamo il nostro codice di aggiornamento client e server per scrivere l'algoritmo effettivo. Per prima cosa trasformeremo il nostro client_update in un tff.tf_computation che accetta i set di dati del client e i pesi del server e genera un tensore dei pesi del client aggiornato.

Avremo bisogno dei tipi corrispondenti per decorare adeguatamente la nostra funzione. Fortunatamente, il tipo di pesi del server può essere estratto direttamente dal nostro modello.

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

Diamo un'occhiata alla firma del tipo di set di dati. Ricorda che abbiamo preso 28 per 28 immagini (con etichette intere) e le abbiamo appiattite.

str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'

Possiamo anche estrarre il tipo di pesi del modello usando la nostra funzione server_init sopra.

model_weights_type = server_init.type_signature.result

Esaminando la firma del tipo, saremo in grado di vedere l'architettura del nostro modello!

str(model_weights_type)
'<float32[784,10],float32[10]>'

Ora possiamo creare il nostro tff.tf_computation per l'aggiornamento del client.

@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

La versione tff.tf_computation dell'aggiornamento del server può essere definita in modo simile, utilizzando i tipi che abbiamo già estratto.

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

Ultimo, ma non meno importante, dobbiamo creare il tff.federated_computation che riunisce tutto questo. Questa funzione accetterà due valori federati , uno corrispondente ai pesi del server (con posizionamento tff.SERVER ) e l'altro corrispondente ai dataset del client (con posizionamento tff.CLIENTS ).

Nota che entrambi questi tipi sono stati definiti sopra! Dobbiamo semplicemente dare loro il posizionamento corretto usando tff.FederatedType .

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

Ricordi i 4 elementi di un algoritmo FL?

  1. Un passaggio di trasmissione da server a client.
  2. Un passaggio di aggiornamento del client locale.
  3. Un passaggio di caricamento da client a server.
  4. Un passaggio di aggiornamento del server.

Ora che abbiamo creato quanto sopra, ogni parte può essere rappresentata in modo compatto come una singola riga di codice TFF. Questa semplicità è il motivo per cui abbiamo dovuto prestare particolare attenzione a specificare cose come i tipi federati!

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

Ora abbiamo un tff.federated_computation sia per l'inizializzazione dell'algoritmo, sia per l'esecuzione di un passaggio dell'algoritmo. Per completare il nostro algoritmo, li passiamo a tff.templates.IterativeProcess .

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

Diamo un'occhiata alla firma del tipo delle funzioni di initialize e next del nostro processo iterativo.

str(federated_algorithm.initialize.type_signature)
'( -> <float32[784,10],float32[10]>@SERVER)'

Ciò riflette il fatto che federated_algorithm.initialize è una funzione no-arg che restituisce un modello a strato singolo (con una matrice di peso 784 x 10 e 10 unità di polarizzazione).

str(federated_algorithm.next.type_signature)
'(<<float32[784,10],float32[10]>@SERVER,{<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'

Qui, vediamo che federated_algorithm.next accetta un modello di server e dati client e restituisce un modello di server aggiornato.

Valutare l'algoritmo

Facciamo alcuni round e vediamo come cambia la perdita. In primo luogo, definiremo una funzione di valutazione utilizzando l'approccio centralizzato discusso nel secondo tutorial.

Creiamo prima un set di dati di valutazione centralizzato, quindi applichiamo la stessa preelaborazione che abbiamo utilizzato per i dati di addestramento.

Nota che abbiamo solo take i primi 1000 elementi per ragioni di efficienza computazionale, ma in genere ci piacerebbe utilizzare l'intero set di dati di test.

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients().take(1000)
central_emnist_test = preprocess(central_emnist_test)

Successivamente, scriviamo una funzione che accetta uno stato del server e utilizza Keras per valutare sul set di dati di test. Se hai familiarità con tf.Keras , ti sembrerà tutto familiare, anche se nota l'uso di set_weights !

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)

Ora inizializziamo il nostro algoritmo e valutiamo sul set di test.

server_state = federated_algorithm.initialize()
evaluate(server_state)
50/50 [==============================] - 0s 2ms/step - loss: 2.3026 - sparse_categorical_accuracy: 0.0910

Alleniamoci per qualche round e vediamo se cambia qualcosa.

for round in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
50/50 [==============================] - 0s 1ms/step - loss: 2.1706 - sparse_categorical_accuracy: 0.2440

Vediamo una leggera diminuzione nella funzione di perdita. Sebbene il salto sia piccolo, abbiamo eseguito solo 10 round di allenamento e su un piccolo sottogruppo di clienti. Per ottenere risultati migliori, potremmo dover fare centinaia se non migliaia di round.

Modifica il nostro algoritmo

A questo punto, fermiamoci e pensiamo a ciò che abbiamo realizzato. Abbiamo implementato la media federata direttamente combinando il codice TensorFlow puro (per gli aggiornamenti del client e del server) con i calcoli federati dal Federated Core di TFF.

Per eseguire un apprendimento più sofisticato, possiamo semplicemente modificare ciò che abbiamo sopra. In particolare, modificando il codice TF puro sopra, possiamo cambiare il modo in cui il client esegue l'addestramento o il modo in cui il server aggiorna il suo modello.

Sfida: aggiungi il ritaglio del gradiente alla funzione client_update .

Se volessimo apportare modifiche maggiori, potremmo anche fare in modo che il server memorizzi e trasmetta più dati. Ad esempio, il server potrebbe anche memorizzare il tasso di apprendimento del client e farlo decadere nel tempo! Si noti che ciò richiederà modifiche alle firme del tipo utilizzate nelle chiamate tff.tf_computation sopra.

Sfida più difficile: implementare la media federata con decadimento del tasso di apprendimento sui clienti.

A questo punto, potresti iniziare a capire quanta flessibilità c'è in ciò che puoi implementare in questo framework. Per idee (inclusa la risposta alla sfida più difficile sopra) puoi vedere il codice sorgente per tff.learning.build_federated_averaging_process , o controllare vari progetti di ricerca usando TFF.