Apprendimento federato per la classificazione delle immagini

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica il taccuino

In questo tutorial, utilizziamo l'esempio di formazione MNIST classico per introdurre l'apprendimento federati (FL) livello API di TFF, tff.learning - un insieme di interfacce di livello superiore che può essere utilizzato per eseguire i comuni tipi di compiti di apprendimento federati, come ad esempio formazione federata, rispetto ai modelli forniti dall'utente implementati in TensorFlow.

Questo tutorial e l'API di apprendimento federato sono destinati principalmente agli utenti che desiderano collegare i propri modelli TensorFlow a TFF, trattando quest'ultimo principalmente come una scatola nera. Per una comprensione più approfondita di TFF e come implementare i propri algoritmi di apprendimento federati, vedere le esercitazioni su API FC Nucleo - personalizzato Federati Algoritmi Parte 1 e Parte 2 .

Per ulteriori informazioni su tff.learning , continuare con la Federated Learning for Text Generation , tutorial che oltre a coprire modelli ricorrenti, dimostra anche il caricamento di un modello di Keras serializzato pre-addestrato per la raffinatezza con l'apprendimento federata combinato con la valutazione utilizzando Keras.

Prima di iniziare

Prima di iniziare, esegui quanto segue per assicurarti che il tuo ambiente sia configurato correttamente. Se non vedi un saluto, si prega di fare riferimento alla installazione guida per le istruzioni.

# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
!pip install --quiet --upgrade tb-nightly  # or tensorboard, but not both

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

Preparazione dei dati di input

Cominciamo con i dati. L'apprendimento federato richiede un set di dati federato, ovvero una raccolta di dati da più utenti. Dati federati è tipicamente non- IID , che pone una serie di sfide.

Al fine di facilitare la sperimentazione, abbiamo seminato il repository TFF con alcuni set di dati, tra cui una versione federata di MNIST che contiene una versione del set di dati NIST originale che è stato nuovamente trattata con foglia in modo che i dati è calettato dallo scrittore originale le cifre. Poiché ogni writer ha uno stile univoco, questo set di dati mostra il tipo di comportamento non iid previsto per i set di dati federati.

Ecco come possiamo caricarlo.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

I set di dati restituiti da load_data() sono casi di tff.simulation.ClientData , un'interfaccia che consente di enumerare l'insieme di utenti, per costruire un tf.data.Dataset che rappresenta i dati di un determinato utente, e di interrogare il struttura dei singoli elementi. Ecco come è possibile utilizzare questa interfaccia per esplorare il contenuto del set di dati. Tieni presente che, sebbene questa interfaccia ti permetta di scorrere gli ID dei client, questa è solo una caratteristica dei dati di simulazione. Come vedrai a breve, le identità dei client non vengono utilizzate dal framework di apprendimento federato: il loro unico scopo è consentire di selezionare sottoinsiemi di dati per le simulazioni.

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

Esplorare l'eterogeneità nei dati federati

Dati federati è tipicamente non- IID , gli utenti in genere hanno diverse distribuzioni di dati a seconda della modalità di utilizzo. Alcuni client potrebbero avere meno esempi di addestramento sul dispositivo, soffrendo per la scarsità di dati a livello locale, mentre alcuni client avranno esempi di addestramento più che sufficienti. Esploriamo questo concetto di eterogeneità dei dati tipico di un sistema federato con i dati EMNIST che abbiamo a disposizione. È importante notare che questa analisi approfondita dei dati di un cliente è disponibile solo per noi perché si tratta di un ambiente di simulazione in cui tutti i dati sono disponibili per noi localmente. In un vero ambiente di produzione federata non saresti in grado di ispezionare i dati di un singolo cliente.

Innanzitutto, prendiamo un campione dei dati di un cliente per avere un'idea degli esempi su un dispositivo simulato. Poiché il set di dati che stiamo utilizzando è stato codificato da un writer univoco, i dati di un client rappresentano la scrittura a mano di una persona per un campione delle cifre da 0 a 9, simulando il "modello di utilizzo" univoco di un utente.

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

Ora visualizziamo il numero di esempi su ogni client per ogni etichetta cifra MNIST. Nell'ambiente federato, il numero di esempi su ciascun client può variare notevolmente, a seconda del comportamento dell'utente.

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

Ora visualizziamo l'immagine media per client per ogni etichetta MNIST. Questo codice produrrà la media di ogni valore di pixel per tutti gli esempi dell'utente per un'etichetta. Vedremo che l'immagine media di un cliente per una cifra avrà un aspetto diverso dall'immagine media di un altro cliente per la stessa cifra, a causa dello stile di scrittura unico di ogni persona. Possiamo riflettere su come ogni ciclo di formazione locale sposterà il modello in una direzione diversa su ciascun cliente, poiché stiamo imparando dai dati unici di quell'utente in quel ciclo locale. Più avanti nel tutorial vedremo come possiamo prendere ogni aggiornamento al modello da tutti i clienti e aggregarli insieme nel nostro nuovo modello globale, che ha imparato dai dati unici di ciascuno dei nostri clienti.

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

I dati dell'utente possono essere rumorosi e etichettati in modo inaffidabile. Ad esempio, guardando i dati del Cliente n. 2 sopra, possiamo vedere che per l'etichetta 2, è possibile che ci siano stati alcuni esempi con etichetta errata che creano un'immagine media più rumorosa.

Pre-elaborazione dei dati di input

Poiché i dati sono già un tf.data.Dataset , pre-elaborazione può essere realizzata utilizzando le trasformazioni set di dati. Qui, si appiattire le 28x28 immagini in 784 array -element, mescolate le singoli esempi, organizzarli in lotti, e rinominiamo le caratteristiche dei pixels e label di x ed y per l'uso con Keras. Gettiamo anche in una repeat sopra il set di dati per eseguire diverse epoche.

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

Verifichiamo che ha funzionato.

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [5],
       [7],
       [1],
       [7],
       [7],
       [1],
       [4],
       [7],
       [4],
       [2],
       [2],
       [5],
       [4],
       [1],
       [1],
       [0],
       [0],
       [9]], dtype=int32))])

Abbiamo quasi tutti gli elementi costitutivi in ​​atto per costruire set di dati federati.

Uno dei modi per alimentare dati federati al TFF in una simulazione è semplicemente come un elenco pitone, ad ogni elemento della lista che contiene i dati di un singolo utente, sia come un elenco o come tf.data.Dataset . Poiché abbiamo già un'interfaccia che fornisce quest'ultimo, usiamolo.

Ecco una semplice funzione di supporto che costruirà un elenco di set di dati dal dato insieme di utenti come input per un ciclo di formazione o valutazione.

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

Ora, come scegliamo i clienti?

In un tipico scenario di formazione federata, abbiamo a che fare con una popolazione potenzialmente molto ampia di dispositivi utente, solo una parte dei quali potrebbe essere disponibile per la formazione in un determinato momento. Questo è il caso, ad esempio, quando i dispositivi client sono telefoni cellulari che partecipano alla formazione solo quando sono collegati a una fonte di alimentazione, fuori da una rete a consumo e altrimenti inattivi.

Ovviamente siamo in un ambiente di simulazione e tutti i dati sono disponibili localmente. In genere, quindi, durante l'esecuzione di simulazioni, campioniamo semplicemente un sottoinsieme casuale dei clienti da coinvolgere in ogni ciclo di allenamento, generalmente diverso in ogni turno.

Detto questo, come si può scoprire studiando la carta sul calcolo della media federati algoritmo, realizzare la convergenza in un sistema con sottoinsiemi campionatura casuale di clienti in ogni turno può prendere un po ', e sarebbe poco pratico dover eseguire centinaia di colpi in questo tutorial interattivo.

Quello che faremo invece è campionare il set di client una volta e riutilizzare lo stesso set tra i round per accelerare la convergenza (si adatta intenzionalmente a questi pochi dati utente). Lasciamo al lettore come esercizio la modifica di questo tutorial per simulare il campionamento casuale: è abbastanza facile da fare (una volta fatto, tieni presente che far convergere il modello potrebbe richiedere del tempo).

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

Creare un modello con Keras

Se stai usando Keras, probabilmente hai già del codice che costruisce un modello Keras. Ecco un esempio di un modello semplice che sarà sufficiente per le nostre esigenze.

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

Per poter utilizzare qualsiasi modello con TFF, deve essere avvolto in un'istanza tff.learning.Model interfaccia, che espone i metodi per stampare in avanti del modello, proprietà dei metadati, ecc, in modo simile a Keras, ma introduce anche ulteriori elementi, come i modi per controllare il processo di calcolo delle metriche federate. Non preoccupiamoci di questo per ora; se si dispone di un modello Keras come quello che abbiamo appena sopra definito, si può avere TFF avvolgere per voi invocando tff.learning.from_keras_model , passando il modello e un lotto di dati di esempio come argomenti, come illustrato di seguito.

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Addestrare il modello su dati federati

Ora che abbiamo un modello avvolto come tff.learning.Model per l'uso con TFF, possiamo lasciare TFF costruire un algoritmo di calcolo della media federati invocando la funzione di supporto tff.learning.build_federated_averaging_process , come segue.

Tenete a mente che l'argomento deve essere un costruttore (come model_fn sopra), non è un caso già costruito, in modo che la costruzione del modello può avvenire in un contesto controllato da TFF (se siete curiosi di sapere i motivi per questo, vi invitiamo a leggere il tutorial di follow-up su algoritmi personalizzati ).

Una nota critica sulla algoritmo di calcolo della media federati sotto, ci sono 2 ottimizzatori: un ottimizzatore _client e un ottimizzatore _SERVER. L'ottimizzatore _client viene utilizzato solo per calcolare gli aggiornamenti modello locale su ciascun client. L'ottimizzatore _SERVER si applica l'aggiornamento media al modello globale sul server. In particolare, ciò significa che la scelta dell'ottimizzatore e del tasso di apprendimento utilizzati potrebbe dover essere diversi da quelli utilizzati per addestrare il modello su un set di dati iid standard. Ti consigliamo di iniziare con un normale SGD, possibilmente con un tasso di apprendimento inferiore al solito. Il tasso di apprendimento che utilizziamo non è stato messo a punto con attenzione, sentiti libero di sperimentare.

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Cosa è appena successo? TFF ha costruito un paio di calcoli federati e confezionato in un tff.templates.IterativeProcess in cui sono disponibili questi calcoli come una coppia di proprietà initialize e next .

In poche parole, i calcoli federati sono programmi in linguaggio interno del TFF che possono esprimere diversi algoritmi federati (si può trovare più su questo nella personalizzato algoritmi tutorial). In questo caso, i due calcoli generati e confezionati in iterative_process attrezzo federata mediazione .

L'obiettivo di TFF è definire i calcoli in modo che possano essere eseguiti in ambienti di apprendimento federati reali, ma attualmente è implementato solo il runtime di simulazione di esecuzione locale. Per eseguire un calcolo in un simulatore, lo invochi semplicemente come una funzione Python. Questo ambiente interpretato predefinito non è progettato per prestazioni elevate, ma sarà sufficiente per questo tutorial; prevediamo di fornire runtime di simulazione ad alte prestazioni per facilitare la ricerca su larga scala nelle versioni future.

Cominciamo con l' initialize di calcolo. Come nel caso di tutti i calcoli federati, puoi considerarlo come una funzione. Il calcolo non accetta argomenti e restituisce un risultato: la rappresentazione dello stato del processo di media federata sul server. Anche se non vogliamo immergerci nei dettagli di TFF, potrebbe essere istruttivo vedere come si presenta questo stato. Puoi visualizzarlo come segue.

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER)'

Mentre la firma tipo di cui sopra può sembrare a prima vista un po 'criptico, si può riconoscere che lo stato del server è costituito da un model (i parametri del modello iniziale per MNIST che verranno distribuiti a tutti i dispositivi), e optimizer_state (ulteriori informazioni gestito dal server, come il numero di cicli da utilizzare per le pianificazioni degli iperparametri, ecc.).

Invochiamo l' initialize di calcolo per costruire lo stato del server.

state = iterative_process.initialize()

Il secondo della coppia di calcoli federati, next , rappresenta un singolo giro di federati della media, che consiste nello spingere lo stato del server (inclusi i parametri del modello) per i clienti, on-dispositivo di addestramento sui propri dati locali, la raccolta e l'aggiornamento del modello di media e producendo un nuovo modello aggiornato sul server.

Concettualmente, si può pensare di next ad avere una firma di tipo funzionale che appare come segue.

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

In particolare, si dovrebbe pensare next() non come una funzione che gira su un server, ma piuttosto essere una rappresentazione funzionale dichiarativa dell'intero calcolo decentrata - alcuni ingressi sono forniti dal server ( SERVER_STATE ), ma ogni partecipante dispositivo fornisce il proprio set di dati locale.

Eseguiamo un singolo ciclo di allenamento e visualizziamo i risultati. Possiamo utilizzare i dati federati che abbiamo già generato sopra per un campione di utenti.

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12345679), ('loss', 3.1193738)])), ('stat', OrderedDict([('num_examples', 4860)]))])

Facciamo qualche altro giro. Come notato in precedenza, in genere a questo punto si sceglie un sottoinsieme dei dati di simulazione da un nuovo campione di utenti selezionato casualmente per ogni round al fine di simulare una distribuzione realistica in cui gli utenti vanno e vengono continuamente, ma in questo taccuino interattivo, per a scopo dimostrativo riutilizzeremo semplicemente gli stessi utenti, in modo che il sistema converga rapidamente.

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13518518), ('loss', 2.9834728)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14382716), ('loss', 2.861665)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.17407407), ('loss', 2.7957022)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.19917695), ('loss', 2.6146567)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.21975309), ('loss', 2.529761)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2409465), ('loss', 2.4053504)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2611111), ('loss', 2.315389)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.30823046), ('loss', 2.1240263)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.33312756), ('loss', 2.1164262)])), ('stat', OrderedDict([('num_examples', 4860)]))])

La perdita di formazione sta diminuendo dopo ogni ciclo di formazione federata, indicando che il modello sta convergendo. Ci sono alcune avvertenze importanti con queste metriche di formazione, tuttavia, vedere la sezione successiva valutazione in questo tutorial.

Visualizzazione delle metriche del modello in TensorBoard

Successivamente, visualizziamo le metriche di questi calcoli federati utilizzando Tensorboard.

Iniziamo creando la directory e il corrispondente writer di riepilogo in cui scrivere le metriche.

logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

Tracciare le metriche scalari rilevanti con lo stesso writer di riepilogo.

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics['train'].items():
      tf.summary.scalar(name, value, step=round_num)

Avvia TensorBoard con la directory del registro principale specificata sopra. Il caricamento dei dati può richiedere alcuni secondi.

!ls {logdir}
%tensorboard --logdir {logdir} --port=0
events.out.tfevents.1629557449.ebe6e776479e64ea-4903924a278.borgtask.google.com.458912.1.v2
Launching TensorBoard...
Reusing TensorBoard on port 50681 (pid 292785), started 0:30:30 ago. (Use '!kill 292785' to kill it.)
<IPython.core.display.Javascript at 0x7fd6617e02d0>
# Uncomment and run this this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if 
# you do a "Runtime > Run all" you don't lose your results.

# !rm -R /tmp/logs/scalars/*

Per visualizzare le metriche di valutazione allo stesso modo, puoi creare una cartella eval separata, come "logs/scalars/eval", da scrivere su TensorBoard.

Personalizzazione dell'implementazione del modello

Keras è consigliata API modello alto livello per tensorflow e incoraggiamo utilizzando modelli KERAS (via tff.learning.from_keras_model ) in TFF quando possibile.

Tuttavia, tff.learning fornisce un'interfaccia modello di livello inferiore, tff.learning.Model , che espone la funzionalità minima necessaria per l'utilizzo di un modello per l'apprendimento federata. Direttamente implementa questa interfaccia (eventualmente ancora utilizzando blocchi da costruzione come tf.keras.layers ) consente la massima personalizzazione senza modificare la struttura interna di algoritmi di apprendimento federati.

Quindi ripetiamo tutto da capo.

Definire le variabili del modello, il passaggio in avanti e le metriche

Il primo passo è identificare le variabili TensorFlow con cui lavoreremo. Per rendere più leggibile il seguente codice, definiamo una struttura dati per rappresentare l'intero insieme. Ciò includerà variabili quali weights e bias che saremo formare, così come le variabili che conterranno diverse statistiche cumulative e contatori aggiorneremo durante l'allenamento, come loss_sum , accuracy_sum e num_examples .

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

Ecco un metodo che crea le variabili. Per semplicità, si rappresenta tutte le statistiche come tf.float32 , come che eliminerà la necessità di conversioni di tipo in una fase successiva. Wrapping inizializzatori variabili come lambda è un requisito imposto da variabili di risorse .

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

Con le variabili per i parametri del modello e le statistiche cumulative in atto, ora possiamo definire il metodo forward pass che calcola la perdita, emette previsioni e aggiorna le statistiche cumulative per un singolo batch di dati di input, come segue.

def predict_on_batch(variables, x):
  return tf.nn.softmax(tf.matmul(x, variables.weights) + variables.bias)

def mnist_forward_pass(variables, batch):
  y = predict_on_batch(variables, batch['x'])
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

Successivamente, definiamo una funzione che restituisce un insieme di metriche locali, sempre utilizzando TensorFlow. Questi sono i valori (oltre agli aggiornamenti del modello, che vengono gestiti automaticamente) che possono essere aggregati al server in un processo di apprendimento o valutazione federato.

Qui, abbiamo semplicemente ritorniamo alla media loss e accuracy , nonché i num_examples , che abbiamo bisogno di pesare correttamente i contributi provenienti da diversi utenti nel calcolo degli aggregati federati.

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

Infine, è necessario determinare come aggregare le metriche locali emessi da ciascun dispositivo tramite get_local_mnist_metrics . Questa è l'unica parte del codice che non è scritto in tensorflow - è un calcolo federata espresso in TFF. Se vuoi scavare più a fondo, sfiorano il algoritmi personalizzati tutorial, ma nella maggior parte delle applicazioni, non sarà davvero bisogno di; le varianti del modello mostrato di seguito dovrebbero essere sufficienti. Ecco come appare:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))

L'ingresso metrics corrisponde argomento alla OrderedDict restituito da get_local_mnist_metrics di cui sopra, ma in modo critico i valori non sono più tf.Tensors - sono "in scatola", come tff.Value s, per mettere in chiaro non è più possibile manipolarli usando tensorflow, ma solo utilizzando gli operatori federate del TFF come tff.federated_mean e tff.federated_sum . Il dizionario restituito degli aggregati globali definisce l'insieme di metriche che saranno disponibili sul server.

Costruire un'istanza di tff.learning.Model

Con tutto quanto sopra in atto, siamo pronti per costruire una rappresentazione del modello da utilizzare con TFF simile a quella che viene generata per te quando lasci che TFF ingerisca un modello Keras.

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def predict_on_batch(self, x, training=True):
    del training
    return predict_on_batch(self._variables, x)

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

Come si può vedere, i metodi astratti e le proprietà definite da tff.learning.Model corrisponde ai frammenti di codice nella sezione precedente che ha introdotto le variabili e ha definito la perdita e le statistiche.

Ecco alcuni punti che vale la pena evidenziare:

  • Tutto lo stato che il modello utilizzerà devono essere catturati come variabili tensorflow, come TFF non usa Python a runtime (ricordate il codice dovrebbe essere scritto in modo tale che possa essere distribuito ai dispositivi mobili; vedere la personalizzato algoritmi tutorial per una più approfondita commento alle motivazioni).
  • Il modello dovrebbe descrivere ciò sotto forma di dati accetta ( input_spec ), come, in generale, TFF è un ambiente fortemente tipizzato e vuole determinare il tipo di firme per tutti i componenti. Dichiarare il formato dell'input del tuo modello ne è una parte essenziale.
  • Anche se tecnicamente non richiesto, si consiglia di avvolgere ogni logica tensorflow (in avanti, computi metrici, etc.) come tf.function s, come questo aiuta a garantire la tensorflow può essere serializzato, ed elimina la necessità per le dipendenze controllo esplicito.

Quanto sopra è sufficiente per la valutazione e algoritmi come Federated SGD. Tuttavia, per la media federata, è necessario specificare come il modello deve essere addestrato localmente su ciascun batch. Specifichiamo un ottimizzatore locale durante la creazione dell'algoritmo di media federata.

Simulazione dell'allenamento federato con il nuovo modello

Con tutto quanto sopra in atto, il resto del processo assomiglia a quello che abbiamo già visto: basta sostituire il costruttore del modello con il costruttore della nostra nuova classe del modello e utilizzare i due calcoli federati nel processo iterativo che hai creato per scorrere turni di allenamento.

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.0708053), ('accuracy', 0.12777779)])), ('stat', OrderedDict([('num_examples', 4860)]))])
for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.011699), ('accuracy', 0.13024691)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.7408307), ('accuracy', 0.15576132)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.6761012), ('accuracy', 0.17921811)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.675567), ('accuracy', 0.1855967)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.5664043), ('accuracy', 0.20329218)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4179392), ('accuracy', 0.24382716)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.3237286), ('accuracy', 0.26687244)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1861682), ('accuracy', 0.28209877)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.046388), ('accuracy', 0.32037038)])), ('stat', OrderedDict([('num_examples', 4860)]))])

Per visualizzare queste metriche all'interno di TensorBoard, fare riferimento ai passaggi sopra elencati in "Visualizzazione delle metriche del modello in TensorBoard".

Valutazione

Tutti i nostri esperimenti finora hanno presentato solo metriche di addestramento federate: le metriche medie su tutti i batch di dati addestrati per tutti i clienti a turno. Ciò introduce le normali preoccupazioni sull'overfitting, soprattutto perché abbiamo utilizzato lo stesso set di client in ogni round per semplicità, ma c'è un'ulteriore nozione di overfitting nelle metriche di allenamento specifiche per l'algoritmo di media federata. Questo è più facile da vedere se immaginiamo che ogni client abbia un singolo batch di dati e ci alleniamo su quel batch per molte iterazioni (epoche). In questo caso, il modello locale si adatterà rapidamente esattamente a quel batch, quindi la metrica di accuratezza locale di cui facciamo la media si avvicinerà a 1.0. Pertanto, queste metriche di formazione possono essere prese come un segno che la formazione sta progredendo, ma non molto di più.

Per eseguire la valutazione sui dati federati, è possibile costruire un altro calcolo federata progettata proprio per questo scopo, utilizzando la tff.learning.build_federated_evaluation funzione e passando nel costruttore del modello come argomento. Si noti che a differenza di federati della media, dove abbiamo usato MnistTrainableModel , è sufficiente passare il MnistModel . La valutazione non esegue la discesa del gradiente e non è necessario costruire ottimizzatori.

Di sperimentazione e ricerca, quando un insieme di dati di test centralizzato è disponibile, Federated Learning for Text Generation dimostra un'altra opzione di valutazione: prendendo i pesi addestrati da apprendimento federata, applicandole a un modello standard Keras, e poi semplicemente chiamando tf.keras.models.Model.evaluate() su un insieme di dati centralizzata.

evaluation = tff.learning.build_federated_evaluation(MnistModel)

È possibile esaminare la firma del tipo astratto della funzione di valutazione come segue.

str(evaluation.type_signature)
'(<server_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <eval=<num_examples=float32,loss=float32,accuracy=float32>,stat=<num_examples=int64>>@SERVER)'

Non c'è bisogno di essere preoccupato per i dettagli, a questo punto, basta essere consapevoli che ci vuole la seguente forma generale, simile a tff.templates.IterativeProcess.next ma con due importanti differenze. Innanzitutto, non stiamo restituendo lo stato del server, poiché la valutazione non modifica il modello o qualsiasi altro aspetto dello stato: puoi considerarlo senza stato. In secondo luogo, la valutazione richiede solo il modello e non richiede altre parti dello stato del server che potrebbero essere associate all'addestramento, come le variabili dell'ottimizzatore.

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

Invochiamo la valutazione sull'ultimo stato a cui siamo arrivati ​​durante l'addestramento. Per estrarre l'ultimo modello addestrati dallo stato del server, è sufficiente accedere al .model membro, ossia.

train_metrics = evaluation(state.model, federated_train_data)

Ecco cosa otteniamo. Nota che i numeri sembrano leggermente migliori di quelli riportati dall'ultimo round di allenamento sopra. Per convenzione, le metriche di addestramento riportate dal processo di addestramento iterativo generalmente riflettono le prestazioni del modello all'inizio del ciclo di addestramento, quindi le metriche di valutazione saranno sempre un passo avanti.

str(train_metrics)
"OrderedDict([('eval', OrderedDict([('num_examples', 4860.0), ('loss', 1.7510437), ('accuracy', 0.2788066)])), ('stat', OrderedDict([('num_examples', 4860)]))])"

Ora, compiliamo un campione di test di dati federati ed eseguiamo nuovamente la valutazione sui dati di test. I dati proverranno dallo stesso campione di utenti reali, ma da un set di dati distinto.

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
"OrderedDict([('eval', OrderedDict([('num_examples', 580.0), ('loss', 1.8361608), ('accuracy', 0.2413793)])), ('stat', OrderedDict([('num_examples', 580)]))])"

Questo conclude il tutorial. Ti invitiamo a giocare con i parametri (ad es. dimensioni batch, numero di utenti, epoche, tassi di apprendimento, ecc.), a modificare il codice sopra per simulare l'addestramento su campioni casuali di utenti in ogni round e ad esplorare gli altri tutorial abbiamo sviluppato.