¡Google I / O regresa del 18 al 20 de mayo! Reserva espacio y crea tu horario Regístrate ahora
Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Soporte experimental para JAX en TFF

Ver en TensorFlow.org Ejecutar en Google Colab Ver en GitHub Descargar cuaderno

Además de ser parte del ecosistema de TensorFlow, TFF tiene como objetivo permitir la interoperabilidad con otros marcos de ML frontend y backend. Por el momento, el soporte para otros marcos de ML aún se encuentra en la fase de incubación, y las API y la funcionalidad soportada pueden cambiar (en gran parte en función de la demanda de los usuarios de TFF). Este tutorial describe cómo usar TFF con JAX como un frontend alternativo de ML y el compilador XLA como un backend alternativo. Los ejemplos que se muestran aquí se basan en una pila JAX / XLA completamente nativa, de un extremo a otro. La posibilidad de mezclar código entre marcos (por ejemplo, JAX con TensorFlow) se discutirá en uno de los tutoriales futuros.

Como siempre, agradecemos sus contribuciones. Si el soporte para JAX / XLA o la capacidad de interoperar con otros marcos de ML es importante para usted, considere ayudarnos a desarrollar estas capacidades hacia la paridad con el resto de TFF.

Antes de que comencemos

Consulte el cuerpo principal de la documentación de TFF para saber cómo configurar su entorno. Dependiendo de dónde esté ejecutando este tutorial, es posible que desee descomentar y ejecutar parte o todo el código a continuación.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

Este tutorial también asume que ha revisado los tutoriales principales de TensorFlow de TFF y que está familiarizado con los conceptos básicos de TFF. Si aún no lo ha hecho, considere revisar al menos uno de ellos.

Cálculos JAX

La compatibilidad con JAX en TFF está diseñada para ser simétrica con la forma en que TFF interopera con TensorFlow, comenzando con las importaciones:

import jax
import numpy as np
import tensorflow_federated as tff

Además, al igual que con TensorFlow, la base para expresar cualquier código TFF es la lógica que se ejecuta localmente. Puede expresar esta lógica en JAX, como se muestra a continuación, utilizando el contenedor @tff.experimental.jax_computation . Se comporta de manera similar a @tff.tf_computation que ya estás familiarizado. Comencemos con algo simple, por ejemplo, un cálculo que suma dos números enteros:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

Puede usar el cálculo JAX definido anteriormente tal como lo haría normalmente con un cálculo TFF. Por ejemplo, puede verificar su tipo de firma, de la siguiente manera:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

Tenga en cuenta que usamos np.int32 para definir el tipo de argumentos. TFF no distingue entre tipos Numpy (como np.int32 ) y tipo TensorFlow (como tf.int32 ). Desde la perspectiva de TFF, son solo formas de referirse a lo mismo.

Ahora, recuerde que TFF no es Python (y si esto no le suena, revise algunos de nuestros tutoriales anteriores, por ejemplo, sobre algoritmos personalizados). Puede usar la envoltura @tff.experimental.jax_computation con cualquier código JAX que se pueda rastrear y serializar, es decir, con código que normalmente @jax.jit con @jax.jit espera que se compile en XLA (pero no es necesario en realidad use la anotación @jax.jit para incrustar su código JAX en TFF).

De hecho, bajo el capó, TFF compila instantáneamente los cálculos JAX en XLA. Puede comprobarlo usted mismo extrayendo e imprimiendo manualmente el código XLA serializado de add_numbers , de la siguiente manera:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

Piense en la representación de cálculos JAX como código XLA como el equivalente funcional de tf.GraphDef para cálculos expresados ​​en TensorFlow. Es portátil y ejecutable en una variedad de entornos que admiten XLA, al igual que tf.GraphDef se puede ejecutar en cualquier tiempo de ejecución de TensorFlow.

TFF proporciona una pila de tiempo de ejecución basada en el compilador XLA como backend. Puede activarlo de la siguiente manera:

tff.backends.xla.set_local_execution_context()

Ahora, puede ejecutar el cálculo que definimos anteriormente:

add_numbers(2, 3)
5

Suficientemente fácil. Vayamos con el golpe y hagamos algo más complicado, como MNIST.

Ejemplo de entrenamiento MNIST con API enlatada

Como de costumbre, comenzamos por definir un montón de tipos de TFF para lotes de datos y para el modelo (recuerde, TFF es un marco fuertemente tipado).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

Ahora, definamos una función de pérdida para el modelo en JAX, tomando el modelo y un solo lote de datos como parámetro:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

Ahora, una forma de hacerlo es usar una API enlatada. Aquí hay un ejemplo de cómo puede usar nuestra API para crear un proceso de entrenamiento basado en la función de pérdida que acaba de definir.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

Puede usar lo anterior tal como lo haría con una compilación de entrenador a partir de un modelo tf.Keras en TensorFlow. Por ejemplo, así es como puede crear el modelo inicial para el entrenamiento:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

Para realizar un entrenamiento real, necesitamos algunos datos. Hagamos datos aleatorios para que sea sencillo. Dado que los datos son aleatorios, vamos a evaluar los datos de entrenamiento, ya que de lo contrario, con datos de evaluación aleatorios, sería difícil esperar que el modelo funcione. Además, para esta demostración a pequeña escala, no nos preocuparemos por muestrear clientes al azar (lo dejamos como ejercicio para que el usuario explore esos tipos de cambios siguiendo las plantillas de otros tutoriales):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

Con eso, podemos realizar un solo paso de entrenamiento, de la siguiente manera:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

Evaluemos el resultado del paso de entrenamiento. Para que sea fácil, podemos evaluarlo de forma centralizada:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

La pérdida está disminuyendo. ¡Estupendo! Ahora, ejecutemos esto en varias rondas:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

Como puede ver, usar JAX con TFF no es tan diferente, aunque las API experimentales aún no están a la par con la funcionalidad de las API de TensorFlow.

Bajo el capó

Si prefiere no usar nuestra API enlatada, puede implementar sus propios cálculos personalizados, de la misma manera que lo ha visto en los tutoriales de algoritmos personalizados para TensorFlow, excepto que usará el mecanismo de JAX para el descenso de gradientes. Por ejemplo, a continuación se muestra cómo puede definir un cálculo JAX que actualice el modelo en un solo minibatch:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.api.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

A continuación, le indicamos cómo puede probar que funciona:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

Una advertencia de trabajar con JAX es que no ofrece el equivalente detf.data.Dataset . Por lo tanto, para iterar sobre conjuntos de datos, deberá usar las construcciones declarativas de TFF para operaciones en secuencias, como la que se muestra a continuación:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

Veamos que funciona:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

El cálculo que realiza una sola ronda de entrenamiento se parece al que pudo haber visto en los tutoriales de TensorFlow:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

Veamos que funciona:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

Como puede ver, usar JAX en TFF, ya sea a través de API enlatadas o directamente usando las construcciones TFF de bajo nivel, es similar a usar TFF con TensorFlow. Esté atento a las actualizaciones futuras, y si desea ver un mejor soporte para la interoperabilidad en los marcos de ML, no dude en enviarnos una solicitud de extracción.