![]() | ![]() | ![]() | ![]() |
Помимо того, что он является частью экосистемы TensorFlow, TFF направлен на обеспечение взаимодействия с другими внешними и внутренними платформами машинного обучения. На данный момент поддержка других фреймворков машинного обучения все еще находится на стадии инкубации, и поддерживаемые API и функциональность могут измениться (в основном в зависимости от спроса со стороны пользователей TFF). В этом руководстве описывается, как использовать TFF с JAX в качестве альтернативного внешнего интерфейса ML и компилятор XLA в качестве альтернативного внутреннего интерфейса. Показанные здесь примеры основаны на полностью нативном стеке JAX / XLA, сквозном. Возможность смешивания кода из разных фреймворков (например, JAX с TensorFlow) будет обсуждаться в одном из будущих руководств.
Как всегда, мы приветствуем ваш вклад. Если для вас важна поддержка JAX / XLA или возможность взаимодействия с другими фреймворками машинного обучения, подумайте о том, чтобы помочь нам развить эти возможности для обеспечения паритета с остальной частью TFF.
Прежде, чем мы начнем
Пожалуйста, обратитесь к основной части документации TFF, чтобы узнать, как настроить вашу среду. В зависимости от того, где вы запускаете это руководство, вы можете раскомментировать и запустить часть или весь приведенный ниже код.
# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()
В этом руководстве также предполагается, что вы ознакомились с основными руководствами TFF по TensorFlow и знакомы с основными концепциями TFF. Если вы еще этого не сделали, рассмотрите возможность рассмотрения хотя бы одного из них.
Вычисления JAX
Поддержка JAX в TFF разработана таким образом, чтобы быть симметричной тому, каким образом TFF взаимодействует с TensorFlow, начиная с импорта:
import jax
import numpy as np
import tensorflow_federated as tff
Кроме того, как и в случае с TensorFlow, основой для выражения любого кода TFF является локальная логика. Вы можете выразить эту логику в JAX, как показано ниже, с помощью @tff.experimental.jax_computation
обертки. Он ведет себя подобно @tff.tf_computation
, что теперь ваши знакомы. Начнем с чего-нибудь простого, например, с вычисления, которое складывает два целых числа:
@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
return jax.numpy.add(x, y)
Вы можете использовать вычисление JAX, определенное выше, так же, как вы обычно используете вычисление TFF. Например, вы можете проверить его подпись типа следующим образом:
str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'
Обратите внимание , что мы использовали np.int32
для определения типа аргументов. ПТФ не делает различия между типами Numpy (например, np.int32
) и типом TensorFlow (например, tf.int32
). С точки зрения TFF, это просто способы обозначить одно и то же.
Теперь помните, что TFF - это не Python (и если это не повод для беспокойства, просмотрите некоторые из наших предыдущих руководств, например, о пользовательских алгоритмах). Вы можете использовать @tff.experimental.jax_computation
обертку с любой JAX код , который может быть прослежен и сериализован, то есть, с кодом , который вы бы нормально аннотировать с @jax.jit
должен быть скомпилирован в XLA (но не нужно на самом деле использовать @jax.jit
аннотаций , чтобы вставлять код JAX в TFF).
Действительно, под капотом TFF мгновенно компилирует вычисления JAX в XLA. Вы можете проверить это сами, вручную извлекая и печать сериализован XLA код из add_numbers
, следующим образом :
comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7 ENTRY xla_computation_add_numbers.7 { constant.4 = pred[] constant(false) parameter.1 = (s32[], s32[]) parameter(0) get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0 get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1 add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3) ROOT tuple.6 = (s32[]) tuple(add.5) }
Подумайте о представлении JAX вычислений как XLA кода как функциональный эквивалент tf.GraphDef
для вычислений , выраженных в TensorFlow. Это портативное и исполняемое в различных средах , которые поддерживают XLA, так же , как tf.GraphDef
может быть выполнен на любом TensorFlow выполнения.
TFF предоставляет стек среды выполнения на основе компилятора XLA в качестве бэкэнда. Активировать его можно следующим образом:
tff.backends.xla.set_local_python_execution_context()
Теперь вы можете выполнить вычисление, которое мы определили выше:
add_numbers(2, 3)
5
Достаточно просто. Давайте перейдем к делу и сделаем что-нибудь посложнее, например MNIST.
Пример обучения MNIST с помощью стандартного API
Как обычно, мы начинаем с определения группы типов TFF для пакетов данных и для модели (помните, что TFF - это строго типизированная структура).
import collections
BATCH_TYPE = collections.OrderedDict([
('pixels', tff.TensorType(np.float32, (50, 784))),
('labels', tff.TensorType(np.int32, (50,)))
])
MODEL_TYPE = collections.OrderedDict([
('weights', tff.TensorType(np.float32, (784, 10))),
('bias', tff.TensorType(np.float32, (10,)))
])
Теперь давайте определим функцию потерь для модели в JAX, взяв модель и отдельный пакет данных в качестве параметра:
def loss(model, batch):
y = jax.nn.softmax(
jax.numpy.add(
jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))
Теперь один из способов - использовать готовый API. Вот пример того, как вы можете использовать наш API для создания процесса обучения на основе только что определенной функции потерь.
STEP_SIZE = 0.001
trainer = tff.experimental.learning.build_jax_federated_averaging_process(
BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)
Вы можете использовать выше так же , как вы бы использовать тренажер сборку из tf.Keras
модели в TensorFlow. Например, вот как вы можете создать начальную модель для обучения:
initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])
Чтобы выполнить собственное обучение, нам нужны некоторые данные. Давайте сделаем случайные данные, чтобы было проще. Поскольку данные являются случайными, мы собираемся оценивать данные обучения, поскольку в противном случае со случайными данными eval было бы трудно ожидать, что модель будет работать. Кроме того, для этой небольшой демонстрации мы не будем беспокоиться о случайной выборке клиентов (мы оставляем это упражнение для пользователя, чтобы изучить эти типы изменений, следуя шаблонам из других руководств):
def random_batch():
pixels = np.random.uniform(
low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
return collections.OrderedDict([('pixels', pixels), ('labels', labels)])
NUM_CLIENTS = 2
NUM_BATCHES = 10
train_data = [
[random_batch() for _ in range(NUM_BATCHES)]
for _ in range(NUM_CLIENTS)]
При этом мы можем выполнить один шаг обучения следующим образом:
trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05, 2.54597180e-05, ..., 5.61640409e-05, -5.32875274e-05, -4.62881755e-04], [ 7.30908650e-05, 4.67643113e-05, 2.03352147e-06, ..., 3.77510623e-05, 3.52839161e-05, -4.59865667e-04], [ 8.14835730e-05, 3.03147244e-05, -1.89143739e-05, ..., 1.12527239e-04, 4.09212225e-06, -4.59960109e-04], ..., [ 9.23552434e-05, 2.44302555e-06, -2.20817346e-05, ..., 7.61375341e-05, 1.76906979e-05, -4.43495519e-04], [ 1.17451040e-04, 2.47748958e-05, 1.04728279e-05, ..., 5.26388249e-07, 7.21131510e-05, -4.67137404e-04], [ 3.75041491e-05, 6.58061981e-05, 1.14522081e-05, ..., 2.52584141e-05, 3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04, 2.6502126e-05, -1.9462314e-05, 8.1269856e-05, 2.1832302e-04, 1.6636557e-04, 1.2815947e-04, 9.0642272e-05, 7.7109929e-05, -9.1987278e-04], dtype=float32))])
Оценим результат обучающего шага. Чтобы упростить задачу, мы можем оценить ее централизованно:
import itertools
eval_data = list(itertools.chain.from_iterable(train_data))
def average_loss(model, data):
return np.mean([loss(model, batch) for batch in data])
print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854 2.282762
Убыток уменьшается. Здорово! Теперь давайте запустим это в несколько раундов:
NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
trained_model = trainer.next(trained_model, train_data)
print(average_loss(trained_model, eval_data))
2.2685437 2.257856 2.2495182 2.2428129 2.2372835 2.2326245 2.2286277 2.2251441 2.2220676 2.219318 2.2168345 2.2145717 2.2124937 2.2105706 2.2087805 2.2071042 2.2055268 2.2040353 2.2026198 2.2012706
Как видите, использование JAX с TFF не сильно отличается, хотя экспериментальные API-интерфейсы еще не соответствуют функциональным возможностям API-интерфейсов TensorFlow.
Под капотом
Если вы предпочитаете не использовать наш стандартный API, вы можете реализовать свои собственные вычисления, во многом так же, как вы видели это в учебниках по пользовательским алгоритмам для TensorFlow, за исключением того, что вы будете использовать механизм JAX для градиентного спуска. Например, ниже показано, как вы можете определить вычисление JAX, которое обновляет модель в одном мини-пакете:
@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
grads = jax.grad(loss)(model, batch)
return collections.OrderedDict([
(k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
])
Вот как вы можете проверить, что это работает:
sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854 2.2977567
Один нюанс работы с JAX является то , что он не предлагает эквивалент tf.data.Dataset
. Таким образом, чтобы перебирать наборы данных, вам нужно будет использовать декларативные конструкции TFF для операций с последовательностями, например показанную ниже:
@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
return tff.sequence_reduce(batches, model, train_on_one_batch)
Посмотрим, что работает:
sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854 2.2284968
Вычисление, которое выполняет один раунд обучения, выглядит точно так же, как то, которое вы, возможно, видели в учебниках TensorFlow:
@tff.federated_computation(
tff.FederatedType(MODEL_TYPE, tff.SERVER),
tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
locally_trained_models = tff.federated_map(
train_on_one_client,
collections.OrderedDict([
('model', tff.federated_broadcast(model)),
('batches', federated_data)]))
return tff.federated_mean(locally_trained_models)
Посмотрим, что работает:
trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854 2.282762
Как видите, использование JAX в TFF, будь то через стандартные API-интерфейсы или напрямую с использованием низкоуровневых конструкций TFF, аналогично использованию TFF с TensorFlow. Следите за обновлениями в будущем, и если вы хотите увидеть лучшую поддержку взаимодействия между фреймворками машинного обучения, не стесняйтесь отправить нам запрос на включение!