Ayuda a proteger la Gran Barrera de Coral con TensorFlow en Kaggle Únete Challenge

Soporte experimental para JAX en TFF

Ver en TensorFlow.org Ejecutar en Google Colab Ver en GitHub Descargar cuaderno

Además de ser parte del ecosistema de TensorFlow, TFF tiene como objetivo permitir la interoperabilidad con otros marcos de ML frontend y backend. Por el momento, el soporte para otros marcos de ML aún se encuentra en la fase de incubación, y las API y la funcionalidad soportada pueden cambiar (en gran parte en función de la demanda de los usuarios de TFF). Este tutorial describe cómo usar TFF con JAX como un frontend alternativo de ML y el compilador XLA como un backend alternativo. Los ejemplos que se muestran aquí se basan en una pila JAX / XLA completamente nativa, de un extremo a otro. La posibilidad de mezclar código entre marcos (por ejemplo, JAX con TensorFlow) se discutirá en uno de los tutoriales futuros.

Como siempre, agradecemos sus contribuciones. Si el soporte para JAX / XLA o la capacidad de interoperar con otros marcos de ML es importante para usted, considere ayudarnos a desarrollar estas capacidades hacia la paridad con el resto de TFF.

Antes de que comencemos

Consulte el cuerpo principal de la documentación de TFF para saber cómo configurar su entorno. Dependiendo de dónde esté ejecutando este tutorial, es posible que desee descomentar y ejecutar parte o todo el código a continuación.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

Este tutorial también asume que ha revisado los tutoriales principales de TensorFlow de TFF y que está familiarizado con los conceptos básicos de TFF. Si aún no lo ha hecho, considere revisar al menos uno de ellos.

Cálculos JAX

La compatibilidad con JAX en TFF está diseñada para ser simétrica con la forma en que TFF interopera con TensorFlow, comenzando con las importaciones:

import jax
import numpy as np
import tensorflow_federated as tff

Además, al igual que con TensorFlow, la base para expresar cualquier código TFF es la lógica que se ejecuta localmente. Usted puede expresar esta lógica en JAX, como se muestra a continuación, utilizando el @tff.experimental.jax_computation envoltura. Se comporta de manera similar al @tff.tf_computation que a estas alturas su están familiarizados. Comencemos con algo simple, por ejemplo, un cálculo que suma dos números enteros:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

Puede usar el cálculo JAX definido anteriormente tal como lo haría normalmente con un cálculo TFF. Por ejemplo, puede verificar su tipo de firma, de la siguiente manera:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

Tenga en cuenta que utilizamos np.int32 para definir el tipo de argumentos. TFF no distingue entre tipos NumPy (tales como np.int32 ) y el tipo TensorFlow (como tf.int32 ). Desde la perspectiva de TFF, son solo formas de referirse a lo mismo.

Ahora, recuerde que TFF no es Python (y si esto no le suena, revise algunos de nuestros tutoriales anteriores, por ejemplo, sobre algoritmos personalizados). Se puede utilizar el @tff.experimental.jax_computation envoltorio con cualquier JAX código que puede ser rastreado y serializado, es decir, con el código que lo haría normalmente con anotaciones @jax.jit espera que sea compilado en XLA (pero que no es necesario en realidad utilizar el @jax.jit anotación para incrustar el código JAX en TFF).

De hecho, bajo el capó, TFF compila instantáneamente los cálculos JAX en XLA. Esto se puede comprobar por sí mismo mediante la extracción e imprimir el código XLA serializado de forma manual add_numbers , de la siguiente manera:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

Piense en representación de los cálculos JAX como código XLA como el equivalente funcional de tf.GraphDef para los cálculos expresados en TensorFlow. Es portátil y ejecutable en una variedad de entornos que apoyan XLA, al igual que el tf.GraphDef puede ser ejecutado en cualquier tiempo de ejecución TensorFlow.

TFF proporciona una pila de tiempo de ejecución basada en el compilador XLA como backend. Puede activarlo de la siguiente manera:

tff.backends.xla.set_local_python_execution_context()

Ahora, puede ejecutar el cálculo que definimos anteriormente:

add_numbers(2, 3)
5

Suficientemente fácil. Vayamos con el golpe y hagamos algo más complicado, como MNIST.

Ejemplo de entrenamiento MNIST con API enlatada

Como de costumbre, comenzamos por definir un montón de tipos de TFF para lotes de datos y para el modelo (recuerde, TFF es un marco fuertemente tipado).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

Ahora, definamos una función de pérdida para el modelo en JAX, tomando el modelo y un solo lote de datos como parámetro:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

Ahora, una forma de hacerlo es usar una API enlatada. A continuación, se muestra un ejemplo de cómo puede utilizar nuestra API para crear un proceso de entrenamiento basado en la función de pérdida que acaba de definir.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

Puede utilizar la anterior del mismo modo que usaría una acumulación de un entrenador tf.Keras modelo en TensorFlow. Por ejemplo, así es como puede crear el modelo inicial para el entrenamiento:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

Para realizar un entrenamiento real, necesitamos algunos datos. Hagamos datos aleatorios para que sea sencillo. Dado que los datos son aleatorios, vamos a evaluar los datos de entrenamiento, ya que de lo contrario, con datos de evaluación aleatorios, sería difícil esperar que el modelo funcione. Además, para esta demostración a pequeña escala, no nos preocuparemos por muestrear clientes al azar (lo dejamos como ejercicio para que el usuario explore esos tipos de cambios siguiendo las plantillas de otros tutoriales):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

Con eso, podemos realizar un solo paso de entrenamiento, de la siguiente manera:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

Evaluemos el resultado del paso de entrenamiento. Para que sea fácil, podemos evaluarlo de forma centralizada:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

La pérdida está disminuyendo. ¡Excelente! Ahora, ejecutemos esto en varias rondas:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

Como puede ver, usar JAX con TFF no es tan diferente, aunque las API experimentales aún no están a la par con la funcionalidad de las API de TensorFlow.

Bajo el capó

Si prefiere no usar nuestra API enlatada, puede implementar sus propios cálculos personalizados, de la misma manera que lo ha visto en los tutoriales de algoritmos personalizados para TensorFlow, excepto que usará el mecanismo de JAX para el descenso de gradientes. Por ejemplo, a continuación se muestra cómo puede definir un cálculo JAX que actualice el modelo en un solo minibatch:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

A continuación, le indicamos cómo puede probar que funciona:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

Una advertencia de trabajar con JAX es que no ofrece el equivalente de tf.data.Dataset . Por lo tanto, para iterar sobre conjuntos de datos, deberá usar las construcciones declarativas de TFF para operaciones en secuencias, como la que se muestra a continuación:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

Veamos que funciona:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

El cálculo que realiza una sola ronda de entrenamiento se parece al que pudo haber visto en los tutoriales de TensorFlow:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

Veamos que funciona:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

Como puede ver, usar JAX en TFF, ya sea a través de API enlatadas o directamente usando las construcciones TFF de bajo nivel, es similar a usar TFF con TensorFlow. Esté atento a las actualizaciones futuras, y si desea ver un mejor soporte para la interoperabilidad en los marcos de ML, no dude en enviarnos una solicitud de extracción.