Algoritmi federati personalizzati, parte 2: implementazione della media federata

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza sorgente su GitHub Scarica taccuino

Questo tutorial è la seconda parte di una serie in due parti che dimostra come implementare tipi personalizzati di algoritmi federati in TFF utilizzando Federated Core (FC) , che funge da base per il livello Federated Learning (FL) ( tff.learning ) .

Ti invitiamo a leggere prima la prima parte di questa serie , che introduce alcuni dei concetti chiave e delle astrazioni di programmazione qui utilizzate.

Questa seconda parte della serie utilizza i meccanismi introdotti nella prima parte per implementare una versione semplice di algoritmi di addestramento e valutazione federati.

Ti invitiamo a rivedere la classificazione delle immagini e le esercitazioni sulla generazione di testo per un'introduzione più delicata e di livello superiore alle API Federated Learning di TFF, in quanto ti aiuteranno a contestualizzare i concetti che descriviamo qui.

Prima di iniziare

Prima di iniziare, prova a eseguire il seguente esempio "Hello World" per assicurarti che il tuo ambiente sia configurato correttamente. Se non funziona, fare riferimento alla Guida all'installazione per le istruzioni.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# TODO(b/148678573,b/148685415): must use the reference context because it
# supports unbounded references and tff.sequence_* intrinsics.
tff.backends.reference.set_reference_context()
@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
'Hello, World!'

Implementazione della media federata

Come in Federated Learning for Image Classification , useremo l'esempio MNIST, ma poiché questo è inteso come un tutorial di basso livello, bypasseremo l'API di Keras e tff.simulation . tff.simulation , scriveremo codice modello grezzo e costruiremo un set di dati federati da zero.

Preparazione di set di dati federati

Per motivi di dimostrazione, simuleremo uno scenario in cui abbiamo dati da 10 utenti e ciascuno di essi contribuisce alla conoscenza di come riconoscere una cifra diversa. Questo è quanto di meno non iid possa sembrare.

Innanzitutto, carichiamo i dati MNIST standard:

mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
[(x.dtype, x.shape) for x in mnist_train]
[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]

I dati vengono forniti come array Numpy, uno con immagini e un altro con etichette di cifre, entrambi con la prima dimensione che ricopre i singoli esempi. Scriviamo una funzione di supporto che la formatta in un modo compatibile con il modo in cui inseriamo sequenze federate nei calcoli TFF, cioè come un elenco di elenchi: l'elenco esterno che va dagli utenti (cifre), quello interno che va dai batch di dati in sequenza di ogni cliente. Come è consuetudine, struttureremo ogni batch come una coppia di tensori denominati x e y , ciascuno con la dimensione del batch iniziale. Nel frattempo, appiattiremo ogni immagine in un vettore a 784 elementi e ridimensioneremo i pixel in essa nell'intervallo 0..1 , in modo da non dover ingombrare la logica del modello con conversioni di dati.

NUM_EXAMPLES_PER_USER = 1000
BATCH_SIZE = 100


def get_data_for_digit(source, digit):
  output_sequence = []
  all_samples = [i for i, d in enumerate(source[1]) if d == digit]
  for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_USER), BATCH_SIZE):
    batch_samples = all_samples[i:i + BATCH_SIZE]
    output_sequence.append({
        'x':
            np.array([source[0][i].flatten() / 255.0 for i in batch_samples],
                     dtype=np.float32),
        'y':
            np.array([source[1][i] for i in batch_samples], dtype=np.int32)
    })
  return output_sequence


federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]

federated_test_data = [get_data_for_digit(mnist_test, d) for d in range(10)]

Per un rapido controllo di integrità, diamo un'occhiata al tensore Y nell'ultimo batch di dati fornito dal quinto client (quello corrispondente alla cifra 5 ).

federated_train_data[5][-1]['y']
array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)

Per essere sicuri, guardiamo anche l'immagine corrispondente all'ultimo elemento di quel lotto.

from matplotlib import pyplot as plt

plt.imshow(federated_train_data[5][-1]['x'][-1].reshape(28, 28), cmap='gray')
plt.grid(False)
plt.show()

png

Sulla combinazione di TensorFlow e TFF

In questo tutorial, per compattezza, decoriamo immediatamente le funzioni che introducono la logica TensorFlow con tff.tf_computation . Tuttavia, per una logica più complessa, questo non è il modello che consigliamo. Il debug di TensorFlow può già essere una sfida e il debug di TensorFlow dopo che è stato completamente serializzato e quindi reimportato perde necessariamente alcuni metadati e limita l'interattività, rendendo il debug ancora più difficile.

Pertanto, consigliamo vivamente di scrivere logica TF complessa come funzioni Python autonome (cioè senza decorazione tff.tf_computation ). In questo modo la logica di TensorFlow può essere sviluppata e testata utilizzando le migliori pratiche e strumenti di TF (come la modalità eager), prima di serializzare il calcolo per TFF (ad esempio, invocando tff.tf_computation con una funzione Python come argomento).

Definizione di una funzione di perdita

Ora che abbiamo i dati, definiamo una funzione di perdita che possiamo utilizzare per l'addestramento. Innanzitutto, definiamo il tipo di input come una tupla denominata TFF. Poiché la dimensione dei batch di dati può variare, impostiamo la dimensione batch su None per indicare che la dimensione di questa dimensione è sconosciuta.

BATCH_SPEC = collections.OrderedDict(
    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    y=tf.TensorSpec(shape=[None], dtype=tf.int32))
BATCH_TYPE = tff.to_type(BATCH_SPEC)

str(BATCH_TYPE)
'<x=float32[?,784],y=int32[?]>'

Forse ti starai chiedendo perché non possiamo semplicemente definire un normale tipo Python. Ricordiamo la discussione nella parte 1 , dove abbiamo spiegato che mentre possiamo esprimere la logica dei calcoli TFF usando Python, sotto il cofano i calcoli TFF non sono Python. Il simbolo BATCH_TYPE definito sopra rappresenta una specifica del tipo TFF astratta. È importante distinguere questo tipo TFF astratto dai tipi concreti di rappresentazione Python, ad esempio contenitori come dict o collections.namedtuple che possono essere usati per rappresentare il tipo TFF nel corpo di una funzione Python. A differenza di Python, TFF ha un unico costruttore di tipo astratto tff.StructType per contenitori simili a tuple, con elementi che possono essere nominati individualmente o lasciati senza nome. Questo tipo viene utilizzato anche per modellare i parametri formali dei calcoli, poiché i calcoli TFF possono dichiarare formalmente solo un parametro e un risultato: vedrai esempi di questo a breve.

Definiamo ora il tipo TFF dei parametri del modello, ancora una volta come una tupla di pesi e bias denominata TFF.

MODEL_SPEC = collections.OrderedDict(
    weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),
    bias=tf.TensorSpec(shape=[10], dtype=tf.float32))
MODEL_TYPE = tff.to_type(MODEL_SPEC)

print(MODEL_TYPE)
<weights=float32[784,10],bias=float32[10]>

Con queste definizioni in atto, ora possiamo definire la perdita per un dato modello, su un singolo batch. Notare l'utilizzo di @tf.function decorator all'interno del decorator @tff.tf_computation . Questo ci permette di scrivere TF usando Python come la semantica anche se erano all'interno di un contesto tf.Graph creato dal decoratore tff.tf_computation .

# NOTE: `forward_pass` is defined separately from `batch_loss` so that it can 
# be later called from within another tf.function. Necessary because a
# @tf.function  decorated method cannot invoke a @tff.tf_computation.

@tf.function
def forward_pass(model, batch):
  predicted_y = tf.nn.softmax(
      tf.matmul(batch['x'], model['weights']) + model['bias'])
  return -tf.reduce_mean(
      tf.reduce_sum(
          tf.one_hot(batch['y'], 10) * tf.math.log(predicted_y), axis=[1]))

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)
def batch_loss(model, batch):
  return forward_pass(model, batch)

Come previsto, il calcolo batch_loss restituisce la perdita float32 dato il modello e un singolo batch di dati. Nota come MODEL_TYPE e BATCH_TYPE sono stati raggruppati insieme in una BATCH_TYPE tupla di parametri formali; puoi riconoscere il tipo di batch_loss come (<MODEL_TYPE,BATCH_TYPE> -> float32) .

str(batch_loss.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>> -> float32)'

Come controllo di integrità, costruiamo un modello iniziale pieno di zeri e calcoliamo la perdita sul batch di dati che abbiamo visualizzato sopra.

initial_model = collections.OrderedDict(
    weights=np.zeros([784, 10], dtype=np.float32),
    bias=np.zeros([10], dtype=np.float32))

sample_batch = federated_train_data[5][-1]

batch_loss(initial_model, sample_batch)
2.3025854

Si noti che alimentiamo il calcolo del TFF con il modello iniziale definito come un dict , anche se il corpo della funzione Python che la definisce utilizza i parametri del model['weight'] come model['weight'] e model['bias'] . Gli argomenti della chiamata a batch_loss non vengono semplicemente passati al corpo di quella funzione.

Cosa succede quando invochiamo batch_loss ? Il corpo Python di batch_loss è già stato tracciato e serializzato nella cella sopra dove è stato definito. TFF funge da chiamante per batch_loss al momento della definizione del calcolo e come destinazione della batch_loss momento in cui batch_loss viene richiamato. In entrambi i ruoli, TFF funge da ponte tra il sistema di tipi astratti di TFF e i tipi di rappresentazione Python. Al momento della chiamata, TFF accetterà la maggior parte dei tipi di contenitori Python standard ( dict , list , tuple , collections.namedtuple , ecc.) Come rappresentazioni concrete di tuple TFF astratte. Inoltre, sebbene come notato sopra, i calcoli TFF accettino formalmente solo un singolo parametro, puoi usare la familiare sintassi delle chiamate Python con argomenti posizionali e / o parole chiave nel caso in cui il tipo del parametro sia una tupla - funziona come previsto.

Discesa in pendenza su un unico lotto

Ora, definiamo un calcolo che utilizzi questa funzione di perdita per eseguire un singolo passaggio di discesa del gradiente. Nota come nella definizione di questa funzione, usiamo batch_loss come sottocomponente. È possibile richiamare un calcolo costruito con tff.tf_computation all'interno del corpo di un altro calcolo, sebbene in genere non sia necessario - come notato sopra, poiché la serializzazione perde alcune informazioni di debug, è spesso preferibile per calcoli più complessi scrivere e testare tutto il TensorFlow senza il decoratore tff.tf_computation .

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
def batch_train(initial_model, batch, learning_rate):
  # Define a group of model variables and set them to `initial_model`. Must
  # be defined outside the @tf.function.
  model_vars = collections.OrderedDict([
      (name, tf.Variable(name=name, initial_value=value))
      for name, value in initial_model.items()
  ])
  optimizer = tf.keras.optimizers.SGD(learning_rate)

  @tf.function
  def _train_on_batch(model_vars, batch):
    # Perform one step of gradient descent using loss from `batch_loss`.
    with tf.GradientTape() as tape:
      loss = forward_pass(model_vars, batch)
    grads = tape.gradient(loss, model_vars)
    optimizer.apply_gradients(
        zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
    return model_vars

  return _train_on_batch(model_vars, batch)
str(batch_train.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>,float32> -> <weights=float32[784,10],bias=float32[10]>)'

Quando si richiama una funzione Python decorata con tff.tf_computation all'interno del corpo di un'altra funzione simile, la logica del calcolo TFF interno è incorporata (essenzialmente, inlined) nella logica di quella esterna. Come notato sopra, se stai scrivendo entrambi i calcoli, è probabilmente preferibile rendere la funzione interna ( batch_loss in questo caso) un normale Python o tf.function piuttosto che un tff.tf_computation . Tuttavia, qui illustriamo che chiamare un tff.tf_computation all'interno di un altro funziona fondamentalmente come previsto. Ciò potrebbe essere necessario se, ad esempio, non hai il codice Python che definisce batch_loss , ma solo la sua rappresentazione TFF serializzata.

Ora, applichiamo questa funzione alcune volte al modello iniziale per vedere se la perdita diminuisce.

model = initial_model
losses = []
for _ in range(5):
  model = batch_train(model, sample_batch, 0.1)
  losses.append(batch_loss(model, sample_batch))
losses
[0.19690022, 0.13176313, 0.10113226, 0.082738124, 0.0703014]

Discesa del gradiente su una sequenza di dati locali

Ora, poiché batch_train sembra funzionare, scriviamo una funzione di addestramento simile local_train che consuma l'intera sequenza di tutti i batch da un utente invece di un singolo batch. Il nuovo calcolo dovrà ora consumare tff.SequenceType(BATCH_TYPE) invece di BATCH_TYPE .

LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)

@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)
def local_train(initial_model, learning_rate, all_batches):

  # Mapping function to apply to each batch.
  @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
  def batch_fn(model, batch):
    return batch_train(model, batch, learning_rate)

  return tff.sequence_reduce(all_batches, initial_model, batch_fn)
str(local_train.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,float32,<x=float32[?,784],y=int32[?]>*> -> <weights=float32[784,10],bias=float32[10]>)'

Ci sono alcuni dettagli nascosti in questa breve sezione di codice, esaminiamoli uno per uno.

Innanzitutto, mentre avremmo potuto implementare questa logica interamente in TensorFlow, affidandoci a tf.data.Dataset.reduce per elaborare la sequenza in modo simile a come l'abbiamo fatto in precedenza, questa volta abbiamo scelto di esprimere la logica nel linguaggio glue , come tff.federated_computation . Abbiamo utilizzato l'operatore federato tff.sequence_reduce per eseguire la riduzione.

L'operatore tff.sequence_reduce viene utilizzato in modo simile a tf.data.Dataset.reduce . Puoi pensarlo essenzialmente come tf.data.Dataset.reduce , ma per l'uso all'interno di calcoli federati che, come forse ricorderai, non possono contenere codice TensorFlow. È un operatore modello con un parametro formale 3-tupla che consiste in una sequenza di elementi di tipo T , lo stato iniziale della riduzione (lo chiameremo astrattamente zero ) di qualche tipo U e l' operatore di riduzione di tipo (<U,T> -> U) che altera lo stato della riduzione elaborando un singolo elemento. Il risultato è lo stato finale della riduzione, dopo aver elaborato tutti gli elementi in ordine sequenziale. Nel nostro esempio, lo stato della riduzione è il modello addestrato su un prefisso dei dati e gli elementi sono batch di dati.

In secondo luogo, nota che abbiamo nuovamente utilizzato un calcolo ( batch_train ) come componente all'interno di un altro ( local_train ), ma non direttamente. Non possiamo usarlo come operatore di riduzione perché richiede un parametro aggiuntivo: il tasso di apprendimento. Per risolvere questo problema, definiamo un batch_fn calcolo federato incorporato che si lega al parametro learning_rate local_train nel suo corpo. È consentito che un calcolo figlio definito in questo modo catturi un parametro formale del suo genitore fintanto che il calcolo figlio non viene invocato al di fuori del corpo del suo genitore. Puoi pensare a questo modello come un equivalente di functools.partial in Python.

L'implicazione pratica dell'acquisizione di learning_rate questo modo è, ovviamente, che lo stesso valore della velocità di apprendimento viene utilizzato in tutti i batch.

Ora, proviamo la funzione di addestramento locale appena definita sull'intera sequenza di dati dallo stesso utente che ha fornito il batch campione (cifra 5 ).

locally_trained_model = local_train(initial_model, 0.1, federated_train_data[5])

Ha funzionato? Per rispondere a questa domanda, dobbiamo implementare la valutazione.

Valutazione locale

Ecco un modo per implementare la valutazione locale sommando le perdite in tutti i batch di dati (avremmo potuto calcolare altrettanto bene la media; lo lasceremo come esercizio per il lettore).

@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def local_eval(model, all_batches):
  # TODO(b/120157713): Replace with `tff.sequence_average()` once implemented.
  return tff.sequence_sum(
      tff.sequence_map(
          tff.federated_computation(lambda b: batch_loss(model, b), BATCH_TYPE),
          all_batches))
str(local_eval.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>*> -> float32)'

Di nuovo, ci sono alcuni nuovi elementi illustrati da questo codice, esaminiamoli uno per uno.

Innanzitutto, abbiamo utilizzato due nuovi operatori federati per l'elaborazione delle sequenze: tff.sequence_map che accetta una funzione di mappatura T->U e una sequenza di T , ed emette una sequenza di U ottenuta applicando la funzione di mappatura pointwise, e tff.sequence_sum che aggiunge solo tutti gli elementi. Qui, mappiamo ogni batch di dati a un valore di perdita, quindi aggiungiamo i valori di perdita risultanti per calcolare la perdita totale.

Nota che avremmo potuto usare di nuovo tff.sequence_reduce , ma questa non sarebbe la scelta migliore: il processo di riduzione è, per definizione, sequenziale, mentre la mappatura e la somma possono essere calcolate in parallelo. Quando viene data una scelta, è meglio attenersi a operatori che non limitano le scelte di implementazione, in modo che quando il nostro calcolo TFF verrà compilato in futuro per essere distribuito in un ambiente specifico, si possa sfruttare appieno tutte le potenziali opportunità per un più veloce , esecuzione più scalabile e più efficiente in termini di risorse.

In secondo luogo, nota che proprio come in local_train , la funzione del componente di cui abbiamo bisogno ( batch_loss ) accetta più parametri di quanto si aspetta l'operatore federato ( tff.sequence_map ), quindi definiamo di nuovo un parziale, questa volta inline avvolgendo direttamente un lambda come tff.federated_computation . L'utilizzo di wrapper inline con una funzione come argomento è il modo consigliato per utilizzare tff.tf_computation per incorporare la logica TensorFlow in TFF.

Ora vediamo se la nostra formazione ha funzionato.

print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[5]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[5]))
initial_model loss = 23.025854
locally_trained_model loss = 0.4348469

In effetti, la perdita è diminuita. Ma cosa succede se lo valutiamo sui dati di un altro utente?

print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[0]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[0]))
initial_model loss = 23.025854
locally_trained_model loss = 74.50075

Come previsto, le cose sono peggiorate. Il modello è stato addestrato per riconoscere 5 e non ha mai visto uno 0 . Questo porta alla domanda: in che modo la formazione locale ha influito sulla qualità del modello dal punto di vista globale?

Valutazione federata

Questo è il punto del nostro viaggio in cui torniamo finalmente ai tipi federati e ai calcoli federati, l'argomento da cui siamo partiti. Ecco una coppia di definizioni dei tipi di TFF per il modello che ha origine sul server e i dati che rimangono sui client.

SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)
CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)

Con tutte le definizioni introdotte finora, esprimere la valutazione federata in TFF è una riga: distribuiamo il modello ai client, permettiamo a ogni client di invocare la valutazione locale sulla sua porzione locale di dati e quindi calcolare la media della perdita. Ecco un modo per scrivere questo.

@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)
def federated_eval(model, data):
  return tff.federated_mean(
      tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))

Abbiamo già visto esempi di tff.federated_mean e tff.federated_map in scenari più semplici e, a livello intuitivo, funzionano come previsto, ma in questa sezione di codice c'è di più di quanto sembri, quindi esaminiamolo attentamente.

Per prima cosa, analizziamo il lascia che ogni client invoca la valutazione locale sulla sua parte locale di dati . Come puoi ricordare dalle sezioni precedenti, local_eval ha una firma del tipo del modulo (<MODEL_TYPE, LOCAL_DATA_TYPE> -> float32) .

L'operatore federato tff.federated_map è un modello che accetta come parametro una tupla a 2 che consiste nella funzione di mappatura di qualche tipo T->U e un valore federato di tipo {T}@CLIENTS (cioè, con componenti membri del stesso tipo del parametro della funzione di mappatura) e restituisce un risultato di tipo {U}@CLIENTS .

Poiché forniamo local_eval come funzione di mappatura da applicare in base al client, il secondo argomento deve essere di tipo federato {<MODEL_TYPE, LOCAL_DATA_TYPE>}@CLIENTS , ovvero, nella nomenclatura delle sezioni precedenti, dovrebbe essere una tupla federata. Ogni client dovrebbe contenere un set completo di argomenti per local_eval come membro costituente. Invece, gli stiamo alimentando un list Python a 2 elementi. Cosa sta succedendo qui?

In effetti, questo è un esempio di un cast di tipo implicito in TFF, simile ai cast di tipo implicito che potresti aver incontrato altrove, ad esempio, quando fornisci un int a una funzione che accetta un float . Il casting implicito è usato scarsamente a questo punto, ma abbiamo intenzione di renderlo più pervasivo in TFF come un modo per ridurre al minimo il boilerplate.

Il cast implicito applicato in questo caso è l'equivalenza tra tuple federate della forma {<X,Y>}@Z e tuple di valori federati <{X}@Z,{Y}@Z> . Anche se formalmente, queste due sono firme di tipo diverso, guardandole dal punto di vista dei programmatori, ogni dispositivo in Z contiene due unità di dati X e Y Quello che succede qui non è diverso da zip in Python, e in effetti, offriamo un operatore tff.federated_zip che ti consente di eseguire tali conversioni in modo esplicito. Quando tff.federated_map incontra una tupla come secondo argomento, invoca semplicemente tff.federated_zip per te.

Considerato quanto sopra, ora dovresti essere in grado di riconoscere l'espressione tff.federated_broadcast(model) come rappresentante un valore di tipo TFF {MODEL_TYPE}@CLIENTS e i data come valore di tipo TFF {LOCAL_DATA_TYPE}@CLIENTS (o semplicemente CLIENT_DATA_TYPE ) , i due vengono filtrati insieme attraverso un tff.federated_zip implicito per formare il secondo argomento di tff.federated_map .

L'operatore tff.federated_broadcast , come ci si aspetterebbe, trasferisce semplicemente i dati dal server ai client.

Vediamo ora come la nostra formazione locale ha influenzato la perdita media nel sistema.

print('initial_model loss =', federated_eval(initial_model,
                                             federated_train_data))
print('locally_trained_model loss =',
      federated_eval(locally_trained_model, federated_train_data))
initial_model loss = 23.025852
locally_trained_model loss = 54.432625

In effetti, come previsto, la perdita è aumentata. Per migliorare il modello per tutti gli utenti, dovremo allenarci sui dati di tutti.

Formazione federata

Il modo più semplice per implementare la formazione federata è addestrare localmente e quindi calcolare la media dei modelli. Questo utilizza gli stessi elementi costitutivi e schemi che abbiamo già discusso, come puoi vedere di seguito.

SERVER_FLOAT_TYPE = tff.type_at_server(tf.float32)


@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,
                           CLIENT_DATA_TYPE)
def federated_train(model, learning_rate, data):
  return tff.federated_mean(
      tff.federated_map(local_train, [
          tff.federated_broadcast(model),
          tff.federated_broadcast(learning_rate), data
      ]))

Nota che nell'implementazione completa di Federated Averaging fornita da tff.learning , piuttosto che fare la media dei modelli, preferiamo fare la media dei delta del modello, per una serie di ragioni, ad esempio, la possibilità di ritagliare le norme di aggiornamento, per la compressione, ecc. .

Vediamo se l'allenamento funziona eseguendo alcuni round di allenamento e confrontando la perdita media prima e dopo.

model = initial_model
learning_rate = 0.1
for round_num in range(5):
  model = federated_train(model, learning_rate, federated_train_data)
  learning_rate = learning_rate * 0.9
  loss = federated_eval(model, federated_train_data)
  print('round {}, loss={}'.format(round_num, loss))
round 0, loss=21.60552406311035
round 1, loss=20.365678787231445
round 2, loss=19.27480125427246
round 3, loss=18.31110954284668
round 4, loss=17.45725440979004

Per completezza, ora eseguiamo anche i dati del test per confermare che il nostro modello si generalizza bene.

print('initial_model test loss =',
      federated_eval(initial_model, federated_test_data))
print('trained_model test loss =', federated_eval(model, federated_test_data))
initial_model test loss = 22.795593
trained_model test loss = 17.278767

Questo conclude il nostro tutorial.

Ovviamente, il nostro esempio semplificato non riflette una serie di cose che dovresti fare in uno scenario più realistico, ad esempio, non abbiamo calcolato metriche diverse dalla perdita. Ti invitiamo a studiare l'implementazione della media federata in tff.learning come esempio più completo e come modo per dimostrare alcune delle pratiche di codifica che vorremmo incoraggiare.