도움말 Kaggle에 TensorFlow과 그레이트 배리어 리프 (Great Barrier Reef)를 보호하기 도전에 참여

TFF에서 JAX에 대한 실험적 지원

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 보기 노트북 다운로드

TensorFlow 생태계의 일부가 되는 것 외에도 TFF는 다른 프론트엔드 및 백엔드 ML 프레임워크와의 상호 운용성을 지원하는 것을 목표로 합니다. 현재 다른 ML 프레임워크에 대한 지원은 아직 인큐베이션 단계에 있으며 지원되는 API 및 기능은 변경될 수 있습니다(대부분 TFF 사용자의 요구에 따라). 이 튜토리얼에서는 JAX와 함께 TFF를 대체 ML 프론트엔드로 사용하고 XLA 컴파일러를 대체 백엔드로 사용하는 방법을 설명합니다. 여기에 표시된 예제는 완전히 기본 JAX/XLA 스택을 기반으로 합니다. 프레임워크(예: JAX와 TensorFlow) 간에 코드를 혼합하는 가능성은 향후 자습서 중 하나에서 논의될 것입니다.

언제나처럼 여러분의 기여를 환영합니다. JAX/XLA에 대한 지원 또는 다른 ML 프레임워크와 상호 운용하는 기능이 중요한 경우 TFF의 나머지 부분과 동등하도록 이러한 기능을 발전시키는 데 도움을 주는 것이 좋습니다.

시작하기 전에

환경을 구성하는 방법은 TFF 문서의 본문을 참조하십시오. 이 자습서를 실행하는 위치에 따라 아래 코드의 일부 또는 전체의 주석 처리를 제거하고 실행할 수 있습니다.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

이 튜토리얼은 또한 TFF의 기본 TensorFlow 튜토리얼을 검토했으며 핵심 TFF 개념에 익숙하다고 가정합니다. 아직 이 작업을 수행하지 않았다면 그 중 하나 이상을 검토하는 것이 좋습니다.

JAX 계산

TFF의 JAX 지원은 가져오기부터 시작하여 TFF가 TensorFlow와 상호 운용되는 방식과 대칭되도록 설계되었습니다.

import jax
import numpy as np
import tensorflow_federated as tff

또한 TensorFlow와 마찬가지로 TFF 코드를 표현하기 위한 기반은 로컬에서 실행되는 논리입니다. 사용하여, 아래 그림과 같이, JAX에서이 논리를 표현할 수 @tff.experimental.jax_computation 래퍼. 그것은 유사하게 동작 @tff.tf_computation 지금 당신이 잘 알고있는 것으로. 두 개의 정수를 더하는 계산과 같이 간단한 것부터 시작해 보겠습니다.

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

일반적으로 TFF 계산을 사용하는 것처럼 위에서 정의한 JAX 계산을 사용할 수 있습니다. 예를 들어 다음과 같이 유형 서명을 확인할 수 있습니다.

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

우리가 사용합니다 np.int32 인수의 유형을 정의 할 수 있습니다. TFF는 (예 NumPy와 유형을 구분하지 않습니다 np.int32 ) 및 (예 TensorFlow 유형 tf.int32 ). TFF의 관점에서, 그것들은 단지 같은 것을 참조하는 방법일 뿐입니다.

이제 TFF가 Python이 아니라는 것을 기억하십시오(그리고 이것이 벨이 울리지 않는다면, 예를 들어 사용자 정의 알고리즘에 대한 이전 자습서 중 일부를 검토하십시오). 당신은 사용할 수 있습니다 @tff.experimental.jax_computation 코드, 즉, 추적 및 직렬화 할 수있는 래퍼 어떤 JAX 코드를 당신 것으로 일반적으로 주석 @jax.jit (하지만 당신은 할 필요가 없습니다 XLA로 컴파일 될 것으로 예상 실제로 사용 @jax.jit ) TFF에 주석을 삽입하여 JAX 코드를.

실제로 TFF는 JAX 계산을 XLA로 즉시 컴파일합니다. 수동으로 추출에서 직렬화 XLA 코드를 인쇄하여 자신이 확인할 수 있습니다 add_numbers 다음과 같이 :

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

의 기능적 등가물로서 XLA 코드로 JAX 계산 표현 생각 tf.GraphDef TensorFlow 표현 연산 용. 그것은처럼, XLA를 지원하는 다양한 환경에서 휴대용 및 실행 가능 tf.GraphDef 어떤 TensorFlow 런타임에서 실행할 수 있습니다.

TFF는 XLA 컴파일러를 기반으로 하는 런타임 스택을 백엔드로 제공합니다. 다음과 같이 활성화할 수 있습니다.

tff.backends.xla.set_local_python_execution_context()

이제 위에서 정의한 계산을 실행할 수 있습니다.

add_numbers(2, 3)
5

충분히 쉽습니다. 타격을 가하고 MNIST와 같은 더 복잡한 작업을 수행해 보겠습니다.

미리 준비된 API를 사용한 MNIST 교육의 예

평소와 같이 데이터 배치와 모델에 대해 TFF 유형을 정의하는 것으로 시작합니다(TFF는 강력한 형식의 프레임워크임을 기억하십시오).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

이제 모델과 단일 데이터 배치를 매개변수로 사용하여 JAX에서 모델에 대한 손실 함수를 정의해 보겠습니다.

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

이제 한 가지 방법은 미리 준비된 API를 사용하는 것입니다. 다음은 방금 정의한 손실 함수를 기반으로 하는 교육 프로세스를 생성하기 위해 API를 사용하는 방법의 예입니다.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

당신이에서 트레이너 빌드를 사용 하듯이 당신은 위의를 사용할 수 있습니다 tf.Keras TensorFlow에서 모델. 예를 들어 학습을 위한 초기 모델을 생성하는 방법은 다음과 같습니다.

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

실제 훈련을 수행하려면 몇 가지 데이터가 필요합니다. 단순하게 유지하기 위해 임의의 데이터를 만들어 보겠습니다. 데이터가 무작위이기 때문에 훈련 데이터에 대해 평가할 것입니다. 그렇지 않으면 무작위 평가 데이터를 사용하면 모델이 수행되는 것을 기대하기 어렵기 때문입니다. 또한 이 소규모 데모의 경우 클라이언트를 무작위로 샘플링하는 것에 대해 걱정하지 않을 것입니다.

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

이를 통해 다음과 같이 단일 단계의 교육을 수행할 수 있습니다.

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

훈련 단계의 결과를 평가해 봅시다. 쉽게 유지하기 위해 중앙 집중식으로 평가할 수 있습니다.

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

손실이 감소하고 있습니다. 엄청난! 이제 이것을 여러 라운드에 걸쳐 실행해 보겠습니다.

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

보시다시피 TFF와 함께 JAX를 사용하는 것은 크게 다르지 않습니다. 비록 실험적 API가 아직 기능면에서 TensorFlow API와 동등하지는 않지만.

후드

미리 준비된 API를 사용하지 않으려면 TensorFlow에 대한 사용자 지정 알고리즘 자습서에서 수행한 것과 같은 방식으로 자체 사용자 지정 계산을 구현할 수 있습니다. 단, 경사하강법에 JAX의 메커니즘을 사용한다는 점만 다릅니다. 예를 들어, 다음은 단일 미니배치에서 모델을 업데이트하는 JAX 계산을 정의하는 방법입니다.

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

작동하는지 테스트하는 방법은 다음과 같습니다.

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

JAX 작업 중 하나주의해야 할 점은 그것의 동등 제공하지 않는다는 것입니다 tf.data.Dataset . 따라서 데이터 세트를 반복하려면 아래와 같이 시퀀스 작업에 TFF의 선언적 구문을 사용해야 합니다.

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

작동하는지 봅시다:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

단일 훈련 라운드를 수행하는 계산은 TensorFlow 자습서에서 본 것과 같습니다.

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

작동하는지 봅시다:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

보시다시피 미리 준비된 API를 통해 또는 저수준 TFF 구성을 직접 사용하여 TFF에서 JAX를 사용하는 것은 TensorFlow와 함께 TFF를 사용하는 것과 유사합니다. 향후 업데이트를 기대해 주세요. ML 프레임워크 간의 상호 운용성에 대한 더 나은 지원을 원하시면 언제든지 풀 리퀘스트를 보내주세요!