Google I / O tornerà dal 18 al 20 maggio! Prenota lo spazio e costruisci il tuo programma Registrati ora

Supporto sperimentale per JAX in TFF

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza su GitHub Scarica taccuino

Oltre a far parte dell'ecosistema TensorFlow, TFF mira a consentire l'interoperabilità con altri framework ML frontend e backend. Al momento, il supporto per altri framework ML è ancora in fase di incubazione e le API e le funzionalità supportate potrebbero cambiare (in gran parte in funzione della domanda da parte degli utenti di TFF). Questo tutorial descrive come utilizzare TFF con JAX come frontend ML alternativo e il compilatore XLA come backend alternativo. Gli esempi mostrati qui sono basati su uno stack JAX / XLA completamente nativo, end-to-end. La possibilità di mescolare il codice tra i framework (ad esempio, JAX con TensorFlow) sarà discussa in uno dei futuri tutorial.

Come sempre, accogliamo con favore i vostri contributi. Se il supporto per JAX / XLA o la capacità di interoperare con altri framework ML è importante per te, considera di aiutarci a far evolvere queste funzionalità verso la parità con il resto del TFF.

Prima di iniziare

Si prega di consultare il corpo principale della documentazione TFF per come configurare il proprio ambiente. A seconda di dove stai eseguendo questo tutorial, potresti voler rimuovere il commento ed eseguire parte o tutto il codice di seguito.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

Questo tutorial presuppone inoltre che tu abbia esaminato i principali tutorial di TensorFlow di TFF e che tu abbia familiarità con i concetti fondamentali di TFF. Se non l'hai ancora fatto, considera di rivederne almeno uno.

Calcoli JAX

Il supporto per JAX in TFF è progettato per essere simmetrico con il modo in cui TFF interagisce con TensorFlow, a partire dalle importazioni:

import jax
import numpy as np
import tensorflow_federated as tff

Inoltre, proprio come con TensorFlow, la base per esprimere qualsiasi codice TFF è la logica che viene eseguita localmente. È possibile esprimere questa logica in JAX, come mostrato di seguito, utilizzando il wrapper @tff.experimental.jax_computation . Si comporta in modo simile alla @tff.tf_computation che ormai conosci. Cominciamo con qualcosa di semplice, ad esempio, un calcolo che aggiunge due numeri interi:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

Puoi usare il calcolo JAX definito sopra proprio come faresti normalmente con un calcolo TFF. Ad esempio, puoi controllare la sua firma del tipo, come segue:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

Nota che abbiamo usato np.int32 per definire il tipo di argomenti. TFF non distingue tra i tipi Numpy (come np.int32 ) e il tipo TensorFlow (come tf.int32 ). Dal punto di vista di TFF, sono solo modi per riferirsi alla stessa cosa.

Ora, ricorda che TFF non è Python (e se questo non ti suona un campanello, rivedi alcuni dei nostri tutorial precedenti, ad esempio, sugli algoritmi personalizzati). È possibile utilizzare il wrapper @tff.experimental.jax_computation con qualsiasi codice JAX che può essere tracciato e serializzato, ovvero con il codice che normalmente si @jax.jit con @jax.jit dovrebbe essere compilato in XLA (ma non è necessario effettivamente usa l'annotazione @jax.jit per incorporare il tuo codice JAX in TFF).

Infatti, sotto il cofano, TFF compila istantaneamente i calcoli JAX in XLA. Puoi verificarlo tu stesso estraendo e stampando manualmente il codice XLA serializzato da add_numbers , come segue:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

Pensa alla rappresentazione dei calcoli JAX come codice XLA come l'equivalente funzionale di tf.GraphDef per i calcoli espressi in TensorFlow. È portatile ed eseguibile in una varietà di ambienti che supportano XLA, proprio come tf.GraphDef può essere eseguito su qualsiasi runtime TensorFlow.

TFF fornisce uno stack di runtime basato sul compilatore XLA come back-end. Puoi attivarlo come segue:

tff.backends.xla.set_local_execution_context()

Ora puoi eseguire il calcolo che abbiamo definito sopra:

add_numbers(2, 3)
5

Abbastanza facile. Andiamo con il colpo e facciamo qualcosa di più complicato, come MNIST.

Esempio di formazione MNIST con API in scatola

Come al solito, iniziamo definendo un gruppo di tipi di TFF per batch di dati e per il modello (ricorda, TFF è un framework fortemente tipizzato).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

Ora, definiamo una funzione di perdita per il modello in JAX, prendendo come parametro il modello e un singolo batch di dati:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

Ora, un modo è utilizzare un'API in scatola. Ecco un esempio di come puoi utilizzare la nostra API per creare un processo di addestramento basato sulla funzione di perdita appena definita.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

È possibile utilizzare quanto sopra proprio come si utilizzerebbe una build del trainer da un modello tf.Keras in TensorFlow. Ad esempio, ecco come creare il modello iniziale per l'addestramento:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

Per eseguire l'addestramento effettivo, abbiamo bisogno di alcuni dati. Facciamo dati casuali per mantenerlo semplice. Poiché i dati sono casuali, valuteremo sui dati di addestramento, poiché altrimenti, con dati di valutazione casuali, sarebbe difficile aspettarsi che il modello funzioni. Inoltre, per questa demo su piccola scala, non ci preoccuperemo di campionare casualmente i client (lasciamo all'utente l'esplorazione di questi tipi di modifiche seguendo i modelli di altri tutorial come esercizio):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

Con ciò, possiamo eseguire un singolo passaggio dell'allenamento, come segue:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

Valutiamo il risultato della fase di formazione. Per semplificare, possiamo valutarlo in modo centralizzato:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

La perdita sta diminuendo. Grande! Ora, eseguiamolo su più round:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

Come puoi vedere, l'utilizzo di JAX con TFF non è così diverso, anche se le API sperimentali non sono ancora alla pari con le funzionalità delle API TensorFlow.

Sotto il cappuccio

Se preferisci non utilizzare la nostra API predefinita, puoi implementare i tuoi calcoli personalizzati, più o meno allo stesso modo di come l'hai visto fare nei tutorial sugli algoritmi personalizzati per TensorFlow, tranne per il fatto che utilizzerai il meccanismo di JAX per la discesa del gradiente. Ad esempio, di seguito viene illustrato come definire un calcolo JAX che aggiorna il modello su un singolo minibatch:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.api.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

Ecco come puoi verificare che funzioni:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

Un avvertimento nel lavorare con JAX è che non offre l'equivalente ditf.data.Dataset . Pertanto, per iterare su set di dati, sarà necessario utilizzare i controlli dichiarativi di TFF per le operazioni sulle sequenze, come quella mostrata di seguito:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

Vediamo che funziona:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

Il calcolo che esegue un singolo ciclo di addestramento è simile a quello che potresti aver visto nei tutorial di TensorFlow:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

Vediamo che funziona:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

Come puoi vedere, l'utilizzo di JAX in TFF, tramite API predefinite o direttamente utilizzando i costrutti TFF di basso livello, è simile all'utilizzo di TFF con TensorFlow. Resta sintonizzato per futuri aggiornamenti e se desideri vedere un supporto migliore per l'interoperabilità tra framework ML, non esitare a inviarci una richiesta pull!