날짜를 저장하십시오! Google I / O가 5 월 18 일부터 20 일까지 반환됩니다. 지금 등록
이 페이지는 Cloud Translation API를 통해 번역되었습니다.
Switch to English

TFF에서 JAX에 대한 실험적 지원

TensorFlow.org에서보기 Google Colab에서 실행 GitHub에서보기 노트북 다운로드

TensorFlow 에코 시스템의 일부인 것 외에도 TFF는 다른 프런트 엔드 및 백엔드 ML 프레임 워크와의 상호 운용성을 지원하는 것을 목표로합니다. 현재 다른 ML 프레임 워크에 대한 지원은 아직 인큐베이션 단계에 있으며 지원되는 API 및 기능은 변경 될 수 있습니다 (대부분 TFF 사용자의 요구에 따라). 이 자습서에서는 대체 ML 프런트 엔드로 JAX와 함께 TFF를 사용하고 대체 백엔드로 XLA 컴파일러를 사용하는 방법을 설명합니다. 여기에 표시된 예제는 완전한 기본 JAX / XLA 스택, 엔드 투 엔드를 기반으로합니다. 프레임 워크 (예 : JAX와 TensorFlow)간에 코드를 혼합 할 수있는 가능성은 향후 튜토리얼 중 하나에서 논의 될 것입니다.

언제나 그렇듯이 여러분의 기여를 환영합니다. JAX / XLA에 대한 지원 또는 다른 ML 프레임 워크와의 상호 운용 기능이 중요한 경우 이러한 기능을 나머지 TFF와 동등하게 발전시킬 수 있도록 도와 주시기 바랍니다.

시작하기 전에

환경을 구성하는 방법은 TFF 문서의 본문을 참조하십시오. 이 자습서를 실행하는 위치에 따라 아래 코드의 일부 또는 전체를 주석 해제하고 실행할 수 있습니다.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

또한이 자습서에서는 TFF의 기본 TensorFlow 자습서를 검토했으며 핵심 TFF 개념에 익숙하다고 가정합니다. 아직이 작업을 수행하지 않은 경우 적어도 하나를 검토해보십시오.

JAX 계산

TFF에서 JAX에 대한 지원은 TFF가 가져 오기부터 시작하여 TensorFlow와 상호 운용되는 방식과 대칭되도록 설계되었습니다.

import jax
import numpy as np
import tensorflow_federated as tff

또한 TensorFlow와 마찬가지로 TFF 코드를 표현하기위한 기반은 로컬에서 실행되는 로직입니다. @tff.experimental.jax_computation 래퍼를 사용하여 아래와 같이 JAX에서이 로직을 표현할 수 있습니다. 지금까지 익숙한 @tff.tf_computation 과 유사하게 작동합니다. 두 개의 정수를 더하는 계산과 같은 간단한 것부터 시작해 보겠습니다.

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

일반적으로 TFF 계산을 사용하는 것처럼 위에서 정의한 JAX 계산을 사용할 수 있습니다. 예를 들어 다음과 같이 유형 서명을 확인할 수 있습니다.

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

인수 유형을 정의하기 위해 np.int32 를 사용했습니다. TFF는 Numpy 유형 (예 : np.int32 )과 TensorFlow 유형 (예 : tf.int32 )을 구분하지 않습니다. TFF의 관점에서 그들은 단지 같은 것을 참조하는 방법 일뿐입니다.

이제 TFF는 Python이 아니라는 점을 기억하십시오 (그리고 이것이 종을 울리지 않는다면, 예를 들어 사용자 정의 알고리즘에 대한 이전 튜토리얼 중 일부를 검토하십시오). @tff.experimental.jax_computation 래퍼를 추적 및 직렬화 할 수있는 모든 JAX 코드와 함께 사용할 수 있습니다. 즉, 일반적으로 XLA로 컴파일 될 것으로 예상되는 @jax.jit 주석을다는 코드와 함께 사용할 수 있습니다 (하지만 그럴 필요는 없습니다). 실제로 @jax.jit 주석을 사용하여 JAX 코드를 TFF에 포함).

실제로 내부적으로 TFF는 JAX 계산을 XLA로 즉시 컴파일합니다. 다음과 같이 add_numbers 에서 직렬화 된 XLA 코드를 수동으로 추출하고 인쇄하여 직접 확인할 수 있습니다.

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

JAX 계산을 XLA 코드로 표현하는 것을 TensorFlow로 표현 된 계산을위한 tf.GraphDef 와 기능적으로 동일하다고 생각하십시오. tf.GraphDef 가 모든 TensorFlow 런타임에서 실행될 수있는 것처럼 XLA를 지원하는 다양한 환경에서 이식 가능하고 실행 가능합니다.

TFF는 XLA 컴파일러를 기반으로하는 런타임 스택을 백엔드로 제공합니다. 다음과 같이 활성화 할 수 있습니다.

tff.backends.xla.set_local_execution_context()

이제 위에서 정의한 계산을 실행할 수 있습니다.

add_numbers(2, 3)
5

충분히 쉽습니다. 타격을 가하고 MNIST와 같은 더 복잡한 작업을 수행합시다.

미리 준비된 API를 사용한 MNIST 교육의 예

평소와 같이 데이터 일괄 처리와 모델에 대해 TFF 유형을 정의하는 것으로 시작합니다 (TFF는 강력한 유형의 프레임 워크임을 기억하십시오).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

이제 모델과 단일 데이터 배치를 매개 변수로 사용하여 JAX에서 모델에 대한 손실 함수를 정의 해 보겠습니다.

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

이제 한 가지 방법은 미리 준비된 API를 사용하는 것입니다. 다음은 API를 사용하여 방금 정의한 손실 함수를 기반으로 훈련 프로세스를 생성하는 방법의 예입니다.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

TensorFlow의 tf.Keras 모델에서 트레이너 빌드를 사용하는 것처럼 위를 사용할 수 있습니다. 예를 들어 학습을위한 초기 모델을 만드는 방법은 다음과 같습니다.

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

실제 훈련을 수행하려면 데이터가 필요합니다. 단순하게 유지하기 위해 임의의 데이터를 만들어 봅시다. 데이터가 무작위이기 때문에 우리는 훈련 데이터에서 평가할 것입니다. 그렇지 않으면 무작위 평가 데이터를 사용하면 모델이 수행 될 것이라고 기대하기 어렵 기 때문입니다. 또한이 소규모 데모에서는 클라이언트를 무작위로 샘플링하는 것에 대해 걱정하지 않습니다 (다른 자습서의 템플릿을 따라 이러한 유형의 변경 사항을 탐색하는 사용자의 연습으로 남겨 둡니다).

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

이를 통해 다음과 같은 단일 단계의 교육을 수행 할 수 있습니다.

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

훈련 단계의 결과를 평가 해 봅시다. 쉽게하기 위해 중앙 집중식으로 평가할 수 있습니다.

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

손실이 감소하고 있습니다. 큰! 이제 여러 라운드에 걸쳐 실행 해 봅시다.

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

보시다시피, TFF와 함께 JAX를 사용하는 것은 그다지 다르지 않지만 실험적 API는 기능면에서 아직 TensorFlow API와 동등하지 않습니다.

후드

미리 준비된 API를 사용하지 않으려면 기울기 하강에 JAX의 메커니즘을 사용한다는 점을 제외하고 TensorFlow의 사용자 지정 알고리즘 자습서에서 확인한 것과 동일한 방식으로 사용자 지정 계산을 구현할 수 있습니다. 예를 들어 다음은 단일 미니 배치에서 모델을 업데이트하는 JAX 계산을 정의하는 방법입니다.

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.api.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

작동하는지 테스트하는 방법은 다음과 같습니다.

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

JAX로 작업 할 때의 한 가지주의 사항은tf.data.Dataset 해당하는 것을 제공하지 않는다는tf.data.Dataset 입니다. 따라서 데이터 세트를 반복하려면 아래 표시된 것과 같은 시퀀스 작업에 TFF의 선언적 구조를 사용해야합니다.

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

작동하는지 봅시다 :

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

단일 라운드의 학습을 수행하는 계산은 TensorFlow 자습서에서 볼 수있는 것과 유사합니다.

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

작동하는지 봅시다 :

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

보시다시피 미리 준비된 API를 통하든 저수준 TFF 구조를 직접 사용하든 TFF에서 JAX를 사용하는 것은 TensorFlow에서 TFF를 사용하는 것과 유사합니다. 향후 업데이트를 계속 지켜봐 주시기 바랍니다. ML 프레임 워크 간의 상호 운용성에 대한 더 나은 지원을보고 싶다면 언제든지 풀 요청을 보내주세요!