Support expérimental pour JAX dans TFF

Voir sur TensorFlow.org Exécuter dans Google Colab Voir sur GitHub Télécharger le cahier

En plus de faire partie de l'écosystème TensorFlow, TFF vise à permettre l'interopérabilité avec d'autres frameworks de ML frontend et backend. Pour le moment, la prise en charge d'autres frameworks de ML est encore en phase d'incubation, et les API et les fonctionnalités prises en charge peuvent changer (en grande partie en fonction de la demande des utilisateurs de TFF). Ce tutoriel décrit comment utiliser TFF avec JAX comme interface ML alternative et le compilateur XLA comme backend alternatif. Les exemples présentés ici sont basés sur une pile JAX/XLA entièrement native, de bout en bout. La possibilité de mélanger du code entre des frameworks (par exemple, JAX avec TensorFlow) sera abordée dans l'un des futurs didacticiels.

Comme toujours, nous apprécions vos contributions. Si la prise en charge de JAX/XLA ou la capacité d'interagir avec d'autres frameworks de ML est importante pour vous, veuillez envisager de nous aider à faire évoluer ces capacités vers la parité avec le reste de TFF.

Avant que nous commencions

Veuillez consulter le corps principal de la documentation TFF pour savoir comment configurer votre environnement. Selon l'endroit où vous exécutez ce didacticiel, vous pouvez supprimer les commentaires et exécuter tout ou partie du code ci-dessous.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

Ce didacticiel suppose également que vous avez examiné les principaux didacticiels TensorFlow de TFF et que vous connaissez les concepts de base de TFF. Si vous ne l'avez pas encore fait, pensez à en examiner au moins un.

calculs JAX

La prise en charge de JAX dans TFF est conçue pour être symétrique à la manière dont TFF interagit avec TensorFlow, en commençant par les importations :

import jax
import numpy as np
import tensorflow_federated as tff

De plus, tout comme avec TensorFlow, la base de l'expression de tout code TFF est la logique qui s'exécute localement. Vous pouvez exprimer cette logique dans JAX, comme indiqué ci - dessous, en utilisant le @tff.experimental.jax_computation emballage. Il se comporte de façon similaire au @tff.tf_computation que maintenant vous êtes au courant. Commençons par quelque chose de simple, par exemple, un calcul qui ajoute deux nombres entiers :

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

Vous pouvez utiliser le calcul JAX défini ci-dessus comme vous utiliseriez normalement un calcul TFF. Par exemple, vous pouvez vérifier sa signature de type, comme suit :

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

Notez que nous avons utilisé np.int32 pour définir le type d'arguments. TFF ne fait pas de distinction entre les types NumPy (tels que np.int32 ) et le type tensorflow (comme tf.int32 ). Du point de vue de TFF, ce ne sont que des façons de faire référence à la même chose.

Maintenant, rappelez-vous que TFF n'est pas Python (et si cela ne vous dit rien, veuillez consulter certains de nos didacticiels précédents, par exemple sur les algorithmes personnalisés). Vous pouvez utiliser le @tff.experimental.jax_computation emballage avec tout JAX code qui peut être tracé et sérialisés, ie avec le code que vous le feriez normalement annoter avec @jax.jit devrait être compilé dans XLA (mais vous n'avez pas besoin utiliser réellement le @jax.jit annotation à intégrer votre code dans JAX TFF).

En effet, sous le capot, TFF compile instantanément les calculs JAX en XLA. Vous pouvez le vérifier vous - même en extraire manuellement et imprimer le code sérialisé XLA à partir add_numbers , comme suit:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

Pensez à la représentation des calculs JAX que le code XLA comme étant l'équivalent fonctionnel de tf.GraphDef pour les calculs exprimés en tensorflow. Il est portable et exécutable dans une variété d'environnements qui prennent en charge XLA, tout comme le tf.GraphDef peut être exécuté sur une exécution tensorflow.

TFF fournit une pile d'exécution basée sur le compilateur XLA en tant que backend. Vous pouvez l'activer comme suit :

tff.backends.xla.set_local_python_execution_context()

Maintenant, vous pouvez exécuter le calcul que nous avons défini ci-dessus :

add_numbers(2, 3)
5

Assez facile. Allons avec le coup et faisons quelque chose de plus compliqué, comme MNIST.

Exemple de formation MNIST avec API standardisée

Comme d'habitude, nous commençons par définir un tas de types TFF pour les lots de données, et pour le modèle (rappelez-vous, TFF est un framework fortement typé).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

Maintenant, définissons une fonction de perte pour le modèle dans JAX, en prenant le modèle et un seul lot de données comme paramètre :

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

Maintenant, une solution consiste à utiliser une API standardisée. Voici un exemple de la façon dont vous pouvez utiliser notre API pour créer un processus de formation basé sur la fonction de perte que vous venez de définir.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

Vous pouvez utiliser ce qui précède comme vous utilisez une version d'un formateur tf.Keras modèle tensorflow. Par exemple, voici comment vous pouvez créer le modèle initial pour la formation :

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

Afin d'effectuer une formation réelle, nous avons besoin de certaines données. Faisons des données aléatoires pour rester simple. Étant donné que les données sont aléatoires, nous allons évaluer sur les données d'entraînement, car sinon, avec des données d'évaluation aléatoires, il serait difficile de s'attendre à ce que le modèle fonctionne. De plus, pour cette démo à petite échelle, nous ne nous soucierons pas de l'échantillonnage aléatoire des clients (nous laissons à l'utilisateur le soin d'explorer ces types de modifications en suivant les modèles d'autres tutoriels) :

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

Avec cela, nous pouvons effectuer une seule étape de formation, comme suit :

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

Évaluons le résultat de l'étape d'entraînement. Pour simplifier les choses, nous pouvons l'évaluer de manière centralisée :

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

La perte diminue. Super! Maintenant, exécutons ceci sur plusieurs tours :

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

Comme vous le voyez, l'utilisation de JAX avec TFF n'est pas si différente, bien que les API expérimentales ne soient pas encore à la hauteur des fonctionnalités des API TensorFlow.

Sous la capuche

Si vous préférez ne pas utiliser notre API standardisée, vous pouvez implémenter vos propres calculs personnalisés, de la même manière que vous l'avez vu dans les didacticiels sur les algorithmes personnalisés pour TensorFlow, sauf que vous utiliserez le mécanisme de JAX pour la descente de gradient. Par exemple, voici comment vous pouvez définir un calcul JAX qui met à jour le modèle sur un seul mini-lot :

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

Voici comment vous pouvez tester que cela fonctionne :

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

Une mise en garde de travailler avec JAX est qu'il ne propose pas l'équivalent de tf.data.Dataset . Ainsi, afin d'itérer sur des ensembles de données, vous devrez utiliser les constructions déclaratives de TFF pour les opérations sur les séquences, comme celle illustrée ci-dessous :

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

Voyons que cela fonctionne :

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

Le calcul qui effectue un seul cycle d'entraînement ressemble à celui que vous avez peut-être vu dans les didacticiels TensorFlow :

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

Voyons que cela fonctionne :

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

Comme vous le voyez, l'utilisation de JAX dans TFF, que ce soit via des API prédéfinies ou directement en utilisant les constructions TFF de bas niveau, est similaire à l'utilisation de TFF avec TensorFlow. Restez à l'écoute des futures mises à jour, et si vous souhaitez voir une meilleure prise en charge de l'interopérabilité entre les frameworks de ML, n'hésitez pas à nous envoyer une pull request !