Google I/O — это обертка! Наверстать упущенное в сеансах TensorFlow Просмотреть сеансы

Экспериментальная поддержка JAX в TFF

Посмотреть на TensorFlow.org Запускаем в Google Colab Посмотреть на GitHub Скачать блокнот

Помимо того, что он является частью экосистемы TensorFlow, TFF направлен на обеспечение взаимодействия с другими внешними и внутренними платформами машинного обучения. На данный момент поддержка других фреймворков машинного обучения все еще находится на стадии инкубации, и поддерживаемые API и функциональность могут измениться (в основном в зависимости от спроса со стороны пользователей TFF). В этом руководстве описывается, как использовать TFF с JAX в качестве альтернативного внешнего интерфейса ML и компилятор XLA в качестве альтернативного внутреннего интерфейса. Показанные здесь примеры основаны на полностью нативном стеке JAX / XLA, сквозном. Возможность смешивания кода из разных фреймворков (например, JAX с TensorFlow) будет обсуждаться в одном из будущих руководств.

Как всегда, мы приветствуем ваш вклад. Если для вас важна поддержка JAX / XLA или возможность взаимодействия с другими фреймворками машинного обучения, подумайте о том, чтобы помочь нам развить эти возможности для обеспечения паритета с остальной частью TFF.

Прежде, чем мы начнем

Пожалуйста, обратитесь к основной части документации TFF, чтобы узнать, как настроить вашу среду. В зависимости от того, где вы запускаете это руководство, вы можете раскомментировать и запустить часть или весь приведенный ниже код.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

В этом руководстве также предполагается, что вы ознакомились с основными руководствами TFF по TensorFlow и знакомы с основными концепциями TFF. Если вы еще этого не сделали, рассмотрите возможность рассмотрения хотя бы одного из них.

Вычисления JAX

Поддержка JAX в TFF разработана таким образом, чтобы быть симметричной тому, каким образом TFF ​​взаимодействует с TensorFlow, начиная с импорта:

import jax
import numpy as np
import tensorflow_federated as tff

Кроме того, как и в случае с TensorFlow, основой для выражения любого кода TFF является локальная логика. Вы можете выразить эту логику в JAX, как показано ниже, с помощью @tff.experimental.jax_computation обертки. Он ведет себя подобно @tff.tf_computation , что теперь ваши знакомы. Начнем с чего-нибудь простого, например, с вычисления, которое складывает два целых числа:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

Вы можете использовать вычисление JAX, определенное выше, так же, как вы обычно используете вычисление TFF. Например, вы можете проверить его подпись типа следующим образом:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

Обратите внимание , что мы использовали np.int32 для определения типа аргументов. ПТФ не делает различия между типами Numpy (например, np.int32 ) и типом TensorFlow (например, tf.int32 ). С точки зрения TFF, это просто способы обозначить одно и то же.

Теперь помните, что TFF - это не Python (и если это не повод для беспокойства, просмотрите некоторые из наших предыдущих руководств, например, о пользовательских алгоритмах). Вы можете использовать @tff.experimental.jax_computation обертку с любой JAX код , который может быть прослежен и сериализован, то есть, с кодом , который вы бы нормально аннотировать с @jax.jit должен быть скомпилирован в XLA (но не нужно на самом деле использовать @jax.jit аннотаций , чтобы вставлять код JAX в TFF).

Действительно, под капотом TFF ​​мгновенно компилирует вычисления JAX в XLA. Вы можете проверить это сами, вручную извлекая и печать сериализован XLA код из add_numbers , следующим образом :

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

Подумайте о представлении JAX вычислений как XLA кода как функциональный эквивалент tf.GraphDef для вычислений , выраженных в TensorFlow. Это портативное и исполняемое в различных средах , которые поддерживают XLA, так же , как tf.GraphDef может быть выполнен на любом TensorFlow выполнения.

TFF предоставляет стек среды выполнения на основе компилятора XLA в качестве бэкэнда. Активировать его можно следующим образом:

tff.backends.xla.set_local_python_execution_context()

Теперь вы можете выполнить вычисление, которое мы определили выше:

add_numbers(2, 3)
5

Достаточно просто. Давайте перейдем к делу и сделаем что-нибудь посложнее, например MNIST.

Пример обучения MNIST с помощью стандартного API

Как обычно, мы начинаем с определения группы типов TFF для пакетов данных и для модели (помните, что TFF - это строго типизированная структура).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

Теперь давайте определим функцию потерь для модели в JAX, взяв модель и отдельный пакет данных в качестве параметра:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

Теперь один из способов - использовать готовый API. Вот пример того, как вы можете использовать наш API для создания процесса обучения на основе только что определенной функции потерь.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

Вы можете использовать выше так же , как вы бы использовать тренажер сборку из tf.Keras модели в TensorFlow. Например, вот как вы можете создать начальную модель для обучения:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

Чтобы выполнить собственное обучение, нам нужны некоторые данные. Давайте сделаем случайные данные, чтобы было проще. Поскольку данные являются случайными, мы собираемся оценивать данные обучения, поскольку в противном случае со случайными данными eval было бы трудно ожидать, что модель будет работать. Кроме того, для этой небольшой демонстрации мы не будем беспокоиться о случайной выборке клиентов (мы оставляем это упражнение для пользователя, чтобы изучить эти типы изменений, следуя шаблонам из других руководств):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

При этом мы можем выполнить один шаг обучения следующим образом:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

Оценим результат обучающего шага. Чтобы упростить задачу, мы можем оценить ее централизованно:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

Убыток уменьшается. Здорово! Теперь давайте запустим это в несколько раундов:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

Как видите, использование JAX с TFF не сильно отличается, хотя экспериментальные API-интерфейсы еще не соответствуют функциональным возможностям API-интерфейсов TensorFlow.

Под капотом

Если вы предпочитаете не использовать наш стандартный API, вы можете реализовать свои собственные вычисления, во многом так же, как вы видели это в учебниках по пользовательским алгоритмам для TensorFlow, за исключением того, что вы будете использовать механизм JAX для градиентного спуска. Например, ниже показано, как вы можете определить вычисление JAX, которое обновляет модель в одном мини-пакете:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

Вот как вы можете проверить, что это работает:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

Один нюанс работы с JAX является то , что он не предлагает эквивалент tf.data.Dataset . Таким образом, чтобы перебирать наборы данных, вам нужно будет использовать декларативные конструкции TFF для операций с последовательностями, например показанную ниже:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

Посмотрим, что работает:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

Вычисление, которое выполняет один раунд обучения, выглядит точно так же, как то, которое вы, возможно, видели в учебниках TensorFlow:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

Посмотрим, что работает:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

Как видите, использование JAX в TFF, будь то через стандартные API-интерфейсы или напрямую с использованием низкоуровневых конструкций TFF, аналогично использованию TFF с TensorFlow. Следите за обновлениями в будущем, и если вы хотите увидеть лучшую поддержку взаимодействия между фреймворками машинного обучения, не стесняйтесь отправить нам запрос на включение!