Zapisz datę! Google I / O powraca w dniach 18-20 maja Zarejestruj się teraz
Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

Eksperymentalne wsparcie dla JAX w TFF

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl w serwisie GitHub Pobierz notatnik

Oprócz bycia częścią ekosystemu TensorFlow, TFF ma na celu umożliwienie współdziałania z innymi platformami ML frontend i backend. W tej chwili wsparcie dla innych frameworków ML jest wciąż w fazie inkubacji, a API i obsługiwana funkcjonalność mogą ulec zmianie (w dużej mierze w funkcji popytu ze strony użytkowników TFF). W tym samouczku opisano, jak używać TFF z JAX jako alternatywnej nakładki ML oraz kompilatora XLA jako alternatywnego zaplecza. Przedstawione tutaj przykłady są oparte na całkowicie natywnym stosie JAX / XLA, od początku do końca. Możliwość mieszania kodu w różnych frameworkach (np. JAX z TensorFlow) zostanie omówiona w jednym z przyszłych tutoriali.

Jak zawsze, czekamy na Twój wkład. Jeśli obsługa JAX / XLA lub możliwość współdziałania z innymi frameworkami ML jest dla Ciebie ważna, rozważ pomoc w rozwijaniu tych możliwości w kierunku równorzędności z pozostałą częścią TFF.

Zanim zaczniemy

Zapoznaj się z główną częścią dokumentacji TFF, aby dowiedzieć się, jak skonfigurować środowisko. W zależności od tego, gdzie uruchamiasz ten samouczek, możesz odkomentować i uruchomić część lub całość poniższego kodu.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

W tym samouczku założono również, że zapoznałeś się z podstawowymi samouczkami TensorFlow TFF i znasz podstawowe pojęcia dotyczące TFF. Jeśli jeszcze tego nie zrobiłeś, rozważ przejrzenie przynajmniej jednego z nich.

Obliczenia JAX

Obsługa JAX w TFF została zaprojektowana tak, aby była symetryczna ze sposobem, w jaki TFF współpracuje z TensorFlow, począwszy od importu:

import jax
import numpy as np
import tensorflow_federated as tff

Ponadto, podobnie jak w przypadku TensorFlow, podstawą wyrażania dowolnego kodu TFF jest logika działająca lokalnie. Możesz wyrazić tę logikę w JAX, jak pokazano poniżej, używając opakowania @tff.experimental.jax_computation . Zachowuje się podobnie do @tff.tf_computation które już znasz. Zacznijmy od czegoś prostego, np. Obliczenia, które dodaje dwie liczby całkowite:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

Możesz użyć obliczenia JAX zdefiniowanego powyżej, tak jak normalnie używasz obliczenia TFF. Na przykład możesz sprawdzić jego podpis typu w następujący sposób:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

Zauważ, że użyliśmy np.int32 do zdefiniowania typu argumentów. TFF nie rozróżnia typów Numpy (takich jak np.int32 ) i TensorFlow (takich jak tf.int32 ). Z perspektywy TFF to tylko sposoby na odniesienie się do tego samego.

Teraz pamiętaj, że TFF to nie Python (a jeśli to nie zadziała, przejrzyj niektóre z naszych wcześniejszych tutoriali, np. O niestandardowych algorytmach). Możesz użyć opakowania @tff.experimental.jax_computation z dowolnym kodem JAX, który można prześledzić i serializować, tj. Z kodem, do którego normalnie @jax.jit adnotację @jax.jit powinien zostać skompilowany do XLA (ale nie musisz faktycznie użyj adnotacji @jax.jit aby osadzić swój kod JAX w TFF).

Rzeczywiście, pod maską TFF natychmiast kompiluje obliczenia JAX do XLA. Możesz to sprawdzić samodzielnie, ręcznie wyodrębniając i drukując serializowany kod XLA z add_numbers w następujący sposób:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

Pomyśl o reprezentacji obliczeń JAX jako kodu XLA jako funkcjonalnego odpowiednika tf.GraphDef dla obliczeń wyrażonych w TensorFlow. Jest przenośny i wykonywalny w różnych środowiskach obsługujących XLA, podobnie jak tf.GraphDef może być wykonywany w dowolnym środowisku wykonawczym TensorFlow.

TFF zapewnia stos środowiska uruchomieniowego oparty na kompilatorze XLA jako zaplecze. Możesz go aktywować w następujący sposób:

tff.backends.xla.set_local_execution_context()

Teraz możesz wykonać obliczenia, które zdefiniowaliśmy powyżej:

add_numbers(2, 3)
5

Wystarczająco łatwe. Chodźmy za ciosem i zróbmy coś bardziej skomplikowanego, na przykład MNIST.

Przykład szkolenia MNIST z gotowym interfejsem API

Jak zwykle zaczynamy od zdefiniowania kilku typów TFF dla partii danych i dla modelu (pamiętaj, że TFF jest frameworkiem o silnym typie).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

Teraz zdefiniujmy funkcję straty dla modelu w JAX, biorąc model i pojedynczą partię danych jako parametr:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

Teraz jednym ze sposobów jest użycie gotowego interfejsu API. Oto przykład, w jaki sposób możesz użyć naszego API do stworzenia procesu szkoleniowego w oparciu o właśnie zdefiniowaną funkcję utraty.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

Można użyć wyżej tak jak byłoby użyć kompilacji trener z tf.Keras modelu w TensorFlow. Na przykład, oto jak możesz stworzyć początkowy model do treningu:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

Aby przeprowadzić właściwe szkolenie, potrzebujemy pewnych danych. Zróbmy losowe dane, aby było to proste. Ponieważ dane są losowe, będziemy oceniać je na danych uczących, ponieważ w przeciwnym razie, w przypadku losowych danych ewaluacyjnych, trudno byłoby oczekiwać, że model będzie działał. Ponadto w przypadku tego demo na małą skalę nie będziemy się martwić losowym próbkowaniem klientów (pozostawiamy to użytkownikowi jako ćwiczenie, aby zbadał tego typu zmiany, postępując zgodnie z szablonami z innych samouczków):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

Dzięki temu możemy wykonać jeden krok treningu w następujący sposób:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

Oceńmy wynik etapu szkolenia. Aby było to łatwe, możemy to ocenić w sposób scentralizowany:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

Strata maleje. Świetny! Teraz przeprowadźmy to przez wiele rund:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

Jak widać, używanie JAX z TFF nie różni się zbytnio, chociaż eksperymentalne interfejsy API nie są jeszcze na równi z funkcjami API TensorFlow.

Pod maską

Jeśli wolisz nie korzystać z naszego gotowego interfejsu API, możesz zaimplementować własne niestandardowe obliczenia, podobnie jak w przypadku samouczków niestandardowych algorytmów dla TensorFlow, z wyjątkiem tego, że będziesz używać mechanizmu JAX do zejścia gradientowego. Na przykład poniżej opisano, w jaki sposób można zdefiniować obliczenia JAX, które aktualizują model w pojedynczej minibatchu:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.api.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

Oto, jak możesz sprawdzić, czy to działa:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

Jedynym zastrzeżeniem dotyczącym współpracy z JAX jest to, że nie oferuje on odpowiednikatf.data.Dataset . Dlatego, aby iterować po zestawach danych, będziesz musiał użyć deklaratywnych konstrukcji TFF do operacji na sekwencjach, takich jak ta pokazana poniżej:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

Zobaczmy, że to działa:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

Obliczenie, które wykonuje pojedynczą rundę szkolenia, wygląda tak samo, jak to, które mogłeś zobaczyć w samouczkach TensorFlow:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

Zobaczmy, że to działa:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

Jak widać, używanie JAX w TFF, czy to za pośrednictwem gotowych interfejsów API, czy bezpośrednio przy użyciu niskopoziomowych konstrukcji TFF, jest podobne do używania TFF z TensorFlow. Bądź na bieżąco z przyszłymi aktualizacjami, a jeśli chcesz zobaczyć lepszą obsługę interoperacyjności we wszystkich platformach ML, wyślij nam żądanie ściągnięcia!