Есть вопрос? Присоединяйтесь к сообществу на форуме TensorFlow. Посетите форум.

Экспериментальная поддержка JAX в TFF

Посмотреть на TensorFlow.org Запускаем в Google Colab Посмотреть на GitHub Скачать блокнот

Помимо того, что он является частью экосистемы TensorFlow, TFF стремится обеспечить взаимодействие с другими внешними и внутренними платформами машинного обучения. На данный момент поддержка других фреймворков машинного обучения все еще находится на стадии инкубации, а API и поддерживаемые функции могут измениться (в основном в зависимости от спроса со стороны пользователей TFF). В этом руководстве описывается, как использовать TFF с JAX в качестве альтернативного внешнего интерфейса ML и компилятора XLA в качестве альтернативного внутреннего интерфейса. Показанные здесь примеры основаны на полностью нативном стеке JAX / XLA, сквозном. Возможность смешивания кода из разных фреймворков (например, JAX с TensorFlow) будет обсуждаться в одном из будущих руководств.

Как всегда, мы приветствуем ваш вклад. Если для вас важна поддержка JAX / XLA или возможность взаимодействия с другими фреймворками машинного обучения, подумайте о том, чтобы помочь нам развить эти возможности для обеспечения паритета с остальной частью TFF.

Прежде чем мы начнем

Пожалуйста, обратитесь к основной части документации TFF, чтобы узнать, как настроить вашу среду. В зависимости от того, где вы запускаете это руководство, вы можете раскомментировать и запустить часть или весь код ниже.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

В этом руководстве также предполагается, что вы ознакомились с основными руководствами TFF по TensorFlow и знакомы с основными концепциями TFF. Если вы еще этого не сделали, рассмотрите возможность рассмотрения хотя бы одного из них.

Вычисления JAX

Поддержка JAX в TFF разработана так, чтобы быть симметричной тому, как TFF взаимодействует с TensorFlow, начиная с импорта:

import jax
import numpy as np
import tensorflow_federated as tff

Кроме того, как и в случае с TensorFlow, основой для выражения любого кода TFF является локальная логика. Вы можете выразить эту логику в JAX, как показано ниже, с @tff.experimental.jax_computation оболочки @tff.experimental.jax_computation . Он ведет себя аналогично @tff.tf_computation , с которым вы уже знакомы. Начнем с чего-нибудь простого, например, с вычисления, которое складывает два целых числа:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

Вы можете использовать вычисление JAX, определенное выше, так же, как вы обычно используете вычисление TFF. Например, вы можете проверить его подпись типа следующим образом:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

Обратите внимание, что мы использовали np.int32 для определения типа аргументов. TFF не различает типы Numpy (например, np.int32 ) и тип tf.int32 (например, tf.int32 ). С точки зрения TFF, это просто способы обозначить одно и то же.

Теперь помните, что TFF - это не Python (и если это не повод для беспокойства, просмотрите некоторые из наших предыдущих руководств, например, о пользовательских алгоритмах). Вы можете использовать оболочку @tff.experimental.jax_computation с любым JAX-кодом, который можно отследить и сериализовать, то есть с кодом, который вы обычно аннотируете с помощью @jax.jit ожидается, будет скомпилирован в XLA (но вам не нужно на самом деле используйте аннотацию @jax.jit для встраивания вашего JAX-кода в TFF).

Действительно, под капотом TFF ​​мгновенно компилирует вычисления JAX в XLA. Вы можете проверить это сами, вручную извлекая и распечатав сериализованный код add_numbers из add_numbers , как add_numbers ниже:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

Представление вычислений JAX в виде кода XLA можно рассматривать как функциональный эквивалент tf.GraphDef для вычислений, выраженных в TensorFlow. Он переносимый и исполняемый в различных средах, поддерживающих XLA, точно так же, как tf.GraphDef может выполняться в любой среде выполнения TensorFlow.

TFF предоставляет стек среды выполнения на основе компилятора XLA в качестве бэкэнда. Активировать его можно следующим образом:

tff.backends.xla.set_local_execution_context()

Теперь вы можете выполнить вычисление, которое мы определили выше:

add_numbers(2, 3)
5

Достаточно просто. Пойдем с ударом и сделаем что-нибудь посложнее, например MNIST.

Пример обучения MNIST с помощью стандартного API

Как обычно, мы начинаем с определения группы типов TFF для пакетов данных и для модели (помните, TFF - это строго типизированная структура).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

Теперь давайте определим функцию потерь для модели в JAX, взяв модель и отдельный пакет данных в качестве параметра:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

Теперь один из способов - использовать готовый API. Вот пример того, как вы можете использовать наш API для создания процесса обучения на основе только что определенной функции потерь.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

Вы можете использовать это так же, как сборку тренера из модели tf.Keras в TensorFlow. Например, вот как вы можете создать начальную модель для обучения:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

Чтобы выполнить собственное обучение, нам нужны некоторые данные. Давайте сделаем случайные данные, чтобы было проще. Поскольку данные являются случайными, мы собираемся оценивать данные обучения, поскольку в противном случае со случайными данными eval было бы трудно ожидать, что модель будет работать. Кроме того, в этой небольшой демонстрации мы не будем беспокоиться о случайной выборке клиентов (мы оставляем это упражнение для пользователя, чтобы изучить эти типы изменений, следуя шаблонам из других руководств):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

При этом мы можем выполнить один шаг обучения следующим образом:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

Оценим результат обучающего шага. Чтобы упростить задачу, мы можем оценить ее централизованно:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

Убыток уменьшается. Большой! Теперь давайте запустим это в несколько раундов:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

Как видите, использование JAX с TFF не сильно отличается, хотя экспериментальные API-интерфейсы еще не соответствуют функциональным возможностям API-интерфейсов TensorFlow.

Под капотом

Если вы предпочитаете не использовать наш стандартный API, вы можете реализовать свои собственные вычисления, во многом так же, как вы видели это в обучающих материалах по настраиваемым алгоритмам для TensorFlow, за исключением того, что вы будете использовать механизм JAX для градиентного спуска. Например, ниже показано, как вы можете определить вычисление JAX, которое обновляет модель в одном мини-пакете:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.api.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

Вот как вы можете проверить, что это работает:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

Одно из предостережений при работе с JAX заключается в том, что он не предлагает эквивалентаtf.data.Dataset . Таким образом, чтобы перебирать наборы данных, вам нужно будет использовать декларативные конструкции TFF для операций с последовательностями, например показанную ниже:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

Посмотрим, что работает:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

Вычисление, которое выполняет один раунд обучения, выглядит так же, как то, которое вы, возможно, видели в учебниках TensorFlow:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

Посмотрим, что работает:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

Как видите, использование JAX в TFF, будь то через стандартные API-интерфейсы или напрямую с использованием низкоуровневых конструкций TFF, аналогично использованию TFF с TensorFlow. Следите за обновлениями в будущем, и если вы хотите увидеть лучшую поддержку взаимодействия между фреймворками машинного обучения, не стесняйтесь присылать нам запрос на включение!