TFFでのJAXの実験的サポート

TensorFlow.orgで表示 GoogleColabで実行 GitHubで表示 ノートブックをダウンロード

TFFは、TensorFlowエコシステムの一部であることに加えて、他のフロントエンドおよびバックエンドMLフレームワークとの相互運用性を実現することを目的としています。現時点では、他のMLフレームワークのサポートはまだインキュベーション段階にあり、サポートされるAPIと機能は変更される可能性があります(主にTFFのユーザーからの要求に応じて)。このチュートリアルでは、代替MLフロントエンドとしてJAXを使用してTFFを使用し、代替バックエンドとしてXLAコンパイラを使用する方法について説明します。ここに示す例は、エンドツーエンドの完全にネイティブなJAX / XLAスタックに基づいています。フレームワーク間でコードを混在させる可能性(たとえば、JAXとTensorFlow)については、今後のチュートリアルの1つで説明します。

いつものように、私たちはあなたの貢献を歓迎します。 JAX / XLAのサポート、または他のMLフレームワークと相互運用する機能が重要な場合は、これらの機能をTFFの他の部分と同等に進化させることを検討してください。

始める前に

環境の構成方法については、TFFドキュメントの本文を参照してください。このチュートリアルを実行している場所によっては、コメントを外して、以下のコードの一部またはすべてを実行することをお勧めします。

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

このチュートリアルでは、TFFの主要なTensorFlowチュートリアルを確認し、TFFのコアコンセプトに精通していることも前提としています。まだこれを行っていない場合は、少なくとも1つを確認することを検討してください。

JAX計算

TFFでのJAXのサポートは、インポートから始めて、TFFがTensorFlowと相互運用する方法と対称になるように設計されています。

import jax
import numpy as np
import tensorflow_federated as tff

また、TensorFlowの場合と同様に、TFFコードを表現するための基盤は、ローカルで実行されるロジックです。使用して、以下に示すようにあなたは、JAXでこのロジックを表現することができます@tff.experimental.jax_computationラッパーを。これは、と同様に動作@tff.tf_computation今ではあなたに精通していること。簡単なことから始めましょう。たとえば、2つの整数を加算する計算です。

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

通常TFF計算を使用するのと同じように、上記で定義したJAX計算を使用できます。たとえば、次のように型シグネチャを確認できます。

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

私たちが使用していることに注意してくださいnp.int32引数の型を定義します。 TFFは、(例えば、numpyのタイプを区別しないnp.int32 (など)とTensorFlowタイプtf.int32 )。 TFFの観点からは、これらは同じことを参照するための単なる方法です。

ここで、TFFはPythonではないことを忘れないでください(これで問題が解決しない場合は、カスタムアルゴリズムなど、以前のチュートリアルのいくつかを確認してください)。あなたは使用することができます@tff.experimental.jax_computationあなたと正常に注釈することをコードで、トレースし、シリアル化できる任意のJAXコード、すなわちでラッパーを@jax.jit XLAにコンパイルされると予想(しかし、あなたがする必要はありません実際に使用@jax.jit )TFFであなたのJAXコードを埋め込む注釈を。

実際、内部では、TFFはJAX計算をXLAに即座にコンパイルします。手動で抽出してから連載XLAコード印刷して、自分のためにこれを確認することができadd_numbers次のように、:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

機能的同等物であるとしてXLAコードとしてJAX計算の表現を考えるtf.GraphDef TensorFlowで発現計算のため。それはちょうど同じように、XLAをサポートし、様々な環境でのポータブルおよび実行可能であるtf.GraphDefどのTensorFlowランタイム上で実行することができます。

TFFは、バックエンドとしてXLAコンパイラに基づくランタイムスタックを提供します。次のようにアクティブ化できます。

tff.backends.xla.set_local_python_execution_context()

これで、上記で定義した計算を実行できます。

add_numbers(2, 3)
5

簡単です。打撃を与えて、MNISTなどのより複雑なことをしてみましょう。

既定のAPIを使用したMNISTトレーニングの例

いつものように、データのバッチとモデルに対して一連のTFFタイプを定義することから始めます(TFFは強く型付けされたフレームワークであることを忘れないでください)。

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

次に、モデルとデータの単一バッチをパラメーターとして使用して、JAXでモデルの損失関数を定義しましょう。

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

さて、行く方法の1つは、既定のAPIを使用することです。これは、APIを使用して、定義した損失関数に基づいてトレーニングプロセスを作成する方法の例です。

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

あなたがからトレーナーのビルドを使用するのと同じように上記使用することができますtf.Keras TensorFlowでモデル。たとえば、トレーニング用の初期モデルを作成する方法は次のとおりです。

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

実際のトレーニングを行うには、いくつかのデータが必要です。単純にするためにランダムデータを作成しましょう。データはランダムであるため、トレーニングデータで評価します。そうしないと、ランダムな評価データでは、モデルのパフォーマンスを期待するのが困難になります。また、この小規模なデモでは、クライアントをランダムにサンプリングすることについて心配する必要はありません(他のチュートリアルのテンプレートに従って、これらのタイプの変更を調査するための演習としてユーザーに任せます)。

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

これで、次のように1ステップのトレーニングを実行できます。

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

トレーニングステップの結果を評価してみましょう。簡単にするために、一元化された方法で評価できます。

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

損失は​​減少しています。素晴らしい!それでは、これを複数のラウンドで実行してみましょう。

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

ご覧のとおり、TFFでJAXを使用することはそれほど違いはありませんが、実験的なAPIは機能的にはまだTensorFlowAPIと同等ではありません。

フードの下

既定のAPIを使用したくない場合は、勾配降下法にJAXのメカニズムを使用することを除いて、TensorFlowのカスタムアルゴリズムチュートリアルで行ったのとほぼ同じ方法で、独自のカスタム計算を実装できます。たとえば、以下は、単一のミニバッチでモデルを更新するJAX計算を定義する方法です。

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

それが機能することをテストする方法は次のとおりです。

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

JAXでの作業の1つの警告は、それが同等提供していないということであるtf.data.Dataset 。したがって、データセットを反復処理するには、以下に示すようなシーケンスの操作にTFFの宣言型構造を使用する必要があります。

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

それが機能することを見てみましょう:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

1ラウンドのトレーニングを実行する計算は、TensorFlowチュートリアルで見たものと同じように見えます。

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

それが機能することを見てみましょう:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

ご覧のとおり、TFFでJAXを使用することは、定型APIを介する場合でも、低レベルのTFF構造を直接使用する場合でも、TensorFlowでTFFを使用する場合と似ています。今後のアップデートにご期待ください。MLフレームワーク間の相互運用性のサポートを強化したい場合は、プルリクエストを送信してください。