Lavorare con ClientData di tff.

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica taccuino

La nozione di un set di dati codificato dai client (ad esempio gli utenti) è essenziale per il calcolo federato come modellato in TFF. TFF fornisce l'interfaccia tff.simulation.datasets.ClientData a astratto su questo concetto, e le serie di dati che ospita TFF ( StackOverflow , shakespeare , emnist , cifar100 e gldv2 ) tutto implementare questa interfaccia.

Se si sta lavorando sulla formazione federata con il proprio set di dati, TFF incoraggia vivamente di implementare sia la ClientData interfaccia o l'uso una delle funzioni di supporto del TFF per generare un ClientData che rappresenta i dati sul disco, ad esempio tff.simulation.datasets.ClientData.from_clients_and_fn .

Come la maggior parte di esempi end-to-end del TFF Iniziamo con ClientData oggetti, l'attuazione del ClientData di interfaccia con il vostro set di dati personalizzato renderà più facile per spelunk tramite codice esistente scritto con TFF. Inoltre, i tf.data.Datasets che ClientData costrutti possono essere ripetuti su direttamente per produrre strutture di numpy array, in modo ClientData oggetti possono essere utilizzati con qualsiasi quadro ML basato su Python prima di passare al TFF.

Esistono diversi modelli con cui puoi semplificarti la vita se intendi estendere le tue simulazioni a molte macchine o distribuirle. Di seguito ci sarà a piedi attraverso alcuni dei modi in cui possiamo usare ClientData e TFF per rendere la nostra piccola scala iterazione a larga scala sperimentazione-produzione esperienza distribuzione quanto più agevole possibile.

Quale modello dovrei usare per passare ClientData in TFF?

Discuteremo due usi dell'unico del TFF ClientData di profondità; se rientri in una delle due categorie seguenti, preferirai chiaramente l'una rispetto all'altra. In caso contrario, potresti aver bisogno di una comprensione più dettagliata dei pro e dei contro di ciascuno per fare una scelta più sfumata.

  • Voglio eseguire l'iterazione il più rapidamente possibile su una macchina locale; Non ho bisogno di poter sfruttare facilmente il runtime distribuito di TFF.

    • Si vuole passare tf.data.Datasets in al TFF direttamente.
    • Questo consente di programmare imperativamente con tf.data.Dataset oggetti, ed elaborarli in modo arbitrario.
    • Fornisce una maggiore flessibilità rispetto all'opzione di seguito; il push della logica ai client richiede che questa logica sia serializzabile.
  • Voglio eseguire il mio calcolo federato nel runtime remoto di TFF o ho intenzione di farlo presto.

    • In questo caso si desidera mappare la costruzione e la preelaborazione del set di dati ai client.
    • Questo si traduce in si passa semplicemente un elenco di client_ids direttamente sul tuo calcolo federata.
    • Spingere la costruzione e la pre-elaborazione del set di dati ai client evita i colli di bottiglia nella serializzazione e aumenta significativamente le prestazioni con centinaia di migliaia di client.

Configura un ambiente open source

Importa pacchetti

Manipolare un oggetto ClientData

Cominciamo da carico e l'esplorazione del TFF EMNIST ClientData :

client_data, _ = tff.simulation.datasets.emnist.load_data()
Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:19<00:00, 8831921.67it/s]
2021-10-01 11:17:58.718735: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Controllo il primo set di dati può dirci che tipo di esempi sono nel ClientData .

first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

Si noti che le rese del set di dati collections.OrderedDict oggetti che hanno pixels e label chiavi, dove i pixel è un tensore di forma [28, 28] . Supponiamo di voler appiattire i nostri ingressi fuori a forma [784] . Un possibile modo per farlo sarebbe quello di applicare una funzione di pre-elaborazione per la nostra ClientData oggetto.

def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)


preprocessed_client_data = client_data.preprocess(preprocess_dataset)

# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

Potremmo inoltre voler eseguire alcune pre-elaborazione più complesse (e possibilmente stateful), ad esempio lo shuffling.

def preprocess_and_shuffle(dataset):
  """Applies `preprocess_dataset` above and shuffles the result."""
  preprocessed = preprocess_dataset(dataset)
  return preprocessed.shuffle(buffer_size=5)

preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)

# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

L'interfacciamento con un tff.Computation

Ora che siamo in grado di eseguire alcune manipolazioni di base con ClientData oggetti, siamo pronti a dati di alimentazione ad un tff.Computation . Definiamo un tff.templates.IterativeProcess che implementa federati della media , ed esplorare diversi metodi di passarlo dati.

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
  ])
  return tff.learning.from_keras_model(
      model,
      # Note: input spec is the _batched_ shape, and includes the 
      # label tensor which will be passed to the loss function. This model is
      # therefore configured to accept data _after_ it has been preprocessed.
      input_spec=collections.OrderedDict(
          x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
          y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

trainer = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))

Prima di iniziare a lavorare con questo IterativeProcess , un commento su semantica di ClientData è in ordine. Un ClientData oggetto rappresenta la totalità della popolazione disposizione per la formazione federata, che in genere è disponibile per l'ambiente di esecuzione di un sistema di produzione FL ed è specifico per simulazione. ClientData dà infatti all'utente la capacità di bypass di calcolo federata del tutto e semplicemente addestrare un modello di server-side come al solito tramite ClientData.create_tf_dataset_from_all_clients .

L'ambiente di simulazione di TFF mette il ricercatore in controllo completo del ciclo esterno. In particolare, ciò implica considerazioni sulla disponibilità del client, sull'abbandono del client, ecc., che devono essere affrontate dall'utente o dallo script del driver Python. Si potrebbe per esempio modello client dropout regolando la distribuzione campionaria sui tuoi ClientData's client_ids tali che gli utenti con più dati (e corrispondentemente più lungo in esecuzione calcoli locali) sarebbero selezionato con minore probabilità.

In un vero sistema federato, tuttavia, i clienti non possono essere selezionati in modo esplicito dal formatore del modello; la selezione dei client è delegata al sistema che esegue la computazione federata.

Passando tf.data.Datasets direttamente al TFF

Una possibilità abbiamo per l'interfacciamento fra una ClientData ed un IterativeProcess è quella di costruire tf.data.Datasets in Python, e passando questi set di dati al TFF.

Si noti che, se usiamo i nostri preelaborate ClientData i dataset cediamo sono del tipo appropriato previsto dal nostro modello definito in precedenza.

selected_client_ids = preprocessed_and_shuffled.client_ids[:10]

preprocessed_data_for_clients = [
    preprocessed_and_shuffled.create_tf_dataset_for_client(
        selected_client_ids[i]) for i in range(10)
]

state = trainer.initialize()
for _ in range(5):
  t1 = time.time()
  state, metrics = trainer.next(state, preprocessed_data_for_clients)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
loss 2.9005744457244873, round time 4.576513767242432
loss 3.113278388977051, round time 0.49641919136047363
loss 2.7581865787506104, round time 0.4904160499572754
loss 2.87259578704834, round time 0.48976993560791016
loss 3.1202380657196045, round time 0.6724586486816406

Se prendiamo questa strada, però, non saremo in grado di muoversi banalmente alla simulazione multimachine. I set di dati costruiamo nel runtime tensorflow locale può acquisire lo stato dall'ambiente circostante pitone e sicuro in serializzazione o deserializzazione quando tentano di stato di riferimento che non è più a loro è. Questo può manifestarsi ad esempio l'errore imperscrutabile dal di tensorflow tensor_util.cc :

Check failed: DT_VARIANT == input.dtype() (21 vs. 20)

Costruzione della mappatura e pre-elaborazione sui clienti

Per evitare questo problema, TFF raccomanda agli utenti di prendere in considerazione insieme di dati di istanze e pre-elaborazione, come qualcosa che accade a livello locale su ogni client, e di utilizzare gli aiutanti di TFF o federated_map per eseguire in modo esplicito questo codice pre-elaborazione su ciascun client.

Concettualmente, il motivo per preferire questo è chiaro: nel runtime locale di TFF, i client hanno accesso solo "accidentalmente" all'ambiente Python globale a causa del fatto che l'intera orchestrazione federata avviene su una singola macchina. Vale la pena notare a questo punto che un pensiero simile dà origine alla filosofia funzionale multipiattaforma, sempre serializzabile di TFF.

TFF rende un tale cambiamento semplice via ClientData's attributo dataset_computation , un tff.Computation che prende un client_id e restituisce associati tf.data.Dataset .

Si noti che preprocess semplicemente funziona con dataset_computation ; la dataset_computation attributo del pre-elaborato ClientData incorpora l'intera pipeline di pre-elaborazione che abbiamo appena definito:

print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing:
(string -> <label=int32,pixels=float32[28,28]>*)


dataset computation with preprocessing:
(string -> <x=float32[?,784],y=int64[?,1]>*)

Potremmo invocare dataset_computation e ricevere un set di dati ansioso nel runtime Python, ma il vero potere di questo approccio è esercitata quando componiamo con un processo iterativo o un altro calcolo per evitare materializzare questi insiemi di dati in runtime ansioso globale a tutti. TFF offre una funzione di supporto tff.simulation.compose_dataset_computation_with_iterative_process che può essere usato per fare esattamente questo.

trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
    preprocessed_and_shuffled.dataset_computation, trainer)

Entrambi questi tff.templates.IterativeProcesses e quello sopra funzionare allo stesso modo; ma ex accetta set di dati client preelaborate, e quest'ultimo accetta stringhe che rappresentano gli ID cliente, movimentazione sia costruzione set di dati e pre-elaborazione nel suo corpo - infatti state può essere passato tra i due.

for _ in range(5):
  t1 = time.time()
  state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
loss 2.8417396545410156, round time 1.6707067489624023
loss 2.7670371532440186, round time 0.5207102298736572
loss 2.665048122406006, round time 0.5302855968475342
loss 2.7213189601898193, round time 0.5313887596130371
loss 2.580148935317993, round time 0.5283482074737549

Ridimensionamento per un gran numero di clienti

trainer_accepting_ids possono immediatamente essere utilizzati in fase di esecuzione multimachine di TFF, ed evita materializzando tf.data.Datasets e il controller (e quindi li serializzazione e li invio agli operai).

Ciò accelera notevolmente le simulazioni distribuite, in particolare con un numero elevato di client, e consente l'aggregazione intermedia per evitare un simile sovraccarico di serializzazione/deserializzazione.

Approfondimento opzionale: composizione manuale della logica di preelaborazione in TFF

TFF è progettato per la composizionalità da zero; il tipo di composizione appena eseguita dall'assistente di TFF è completamente sotto il nostro controllo come utenti. Potremmo avere manualmente comporre il calcolo di pre-elaborazione che abbiamo appena definito con il formatore proprio next molto semplicemente:

selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)

@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)
def new_next(server_state, selected_clients):
  preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)
  return trainer.next(server_state, preprocessed_data)

manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)

In effetti, questo è effettivamente ciò che l'helper che abbiamo usato sta facendo sotto il cofano (oltre a eseguire il controllo e la manipolazione del tipo appropriati). Potremmo anche avere espresso la stessa logica leggermente diverso, serializzando preprocess_and_shuffle in un tff.Computation , e decomponendo il federated_map in un passo che costruisce set di dati non-pretrattato ed un altro che corre preprocess_and_shuffle ad ogni cliente.

Possiamo verificare che questo percorso più manuale si traduce in calcoli con lo stesso tipo di firma dell'helper di TFF (nomi dei parametri del modulo):

print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,federated_dataset={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,selected_clients={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)