日付を保存! Google I / Oが5月18日から20日に戻ってきます今すぐ登録
このページは Cloud Translation API によって翻訳されました。
Switch to English

概要概要

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示ノートブックをダウンロード

始める前に

開始する前に、以下を実行して、環境が正しくセットアップされていることを確認してください。あいさつが表示されない場合は、インストールガイドを参照して手順を確認してください。

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

画像分類テキスト生成のチュートリアルでは、フェデレーションラーニング(FL)のモデルとデータパイプラインを設定する方法を学び、TFFのtff.learningレイヤーを介してフェデレーショントレーニングを実行しました。

これは、FLの研究に関しては氷山の一角にすぎません。このチュートリアルでは、 tff.learning使用せずにフェデレーション学習アルゴリズムを実装する方法について説明します。私たちは以下を達成することを目指しています:

目標:

  • 連合学習アルゴリズムの一般的な構造を理解します。
  • TFFのフェデレーションコアを探索してください。
  • Federated Coreを使用して、FederatedAveragingを直接実装します。

このチュートリアルは自己完結型ですが、最初に画像分類テキスト生成のチュートリアルを読むことをお勧めします

入力データの準備

まず、TFFに含まれているEMNISTデータセットをロードして前処理します。詳細については、画像分類チュートリアルを参照してください。

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

データセットをモデルにフィードするために、データをフラット化し、各例を形式(flattened_image_vector, label)タプルに変換します。

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

ここで、少数のクライアントをサンプリングし、上記の前処理をそれらのデータセットに適用します。

client_ids = np.random.choice(emnist_train.client_ids, size=NUM_CLIENTS, replace=False)

federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]

モデルの準備

画像分類チュートリアルと同じモデルを使用します。このモデル( tf.kerasを介してtf.keras )には、単一の非表示レイヤーがあり、その後にtf.kerasレイヤーが続きます。

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

このモデルをTFFで使用するために、 tff.learning.Modelモデルをtff.learning.Modelとしてラップします。これにより、TFF内でモデルのフォワードパスを実行し、 モデル出力抽出できます。詳細については、画像分類チュートリアルも参照してください。

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

tf.kerasを使用してtf.kerasを作成しtff.learning.Model 、TFFははるかに一般的なモデルをサポートしています。これらのモデルには、モデルの重みを取得する次の関連属性があります。

  • trainable_variables :訓練可能な層に対応するテンソルの反復可能。
  • non_trainable_variables :訓練不可能な層に対応するテンソルの反復可能。

ここでは、 trainable_variablesのみを使用します。 (私たちのモデルにはそれらしかありません!)。

独自の連合学習アルゴリズムの構築

tff.learning APIを使用すると、フェデレーション平均化の多くのバリアントを作成できますが、このフレームワークにうまく適合しない他のフェデレーションアルゴリズムがあります。たとえば、正則化、クリッピング、またはフェデレーションGANトレーニングなどのより複雑なアルゴリズムを追加したい場合があります。代わりに、フェデレーション分析に興味があるかもしれません。

これらのより高度なアルゴリズムでは、TFFを使用して独自のカスタムアルゴリズムを作成する必要があります。多くの場合、フェデレーションアルゴリズムには4つの主要なコンポーネントがあります。

  1. サーバーからクライアントへのブロードキャストステップ。
  2. ローカルクライアントの更新手順。
  3. クライアントからサーバーへのアップロード手順。
  4. サーバーの更新手順。

TFFでは、通常、フェデレーションアルゴリズムをtff.templates.IterativeProcess (全体を通して単にIterativeProcesstff.templates.IterativeProcessます)として表します。これは、 initialize関数とnext関数を含むクラスです。ここで、 initializeはサーバーを初期化するために使用され、 nextはフェデレーションアルゴリズムの1回の通信ラウンドを実行します。 FedAvgの反復プロセスがどのようになるかについてのスケルトンを書いてみましょう。

まず、 tff.learning.Model作成し、そのトレーニング可能な重みを返す初期化関数があります。

def initialize_fn():
  model = model_fn()
  return model.trainable_variables

この関数は見栄えがしますが、後で説明するように、「TFF計算」にするために小さな変更を加える必要があります。

next_fnもスケッチしたいとnext_fnます。

def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights

これらの4つのコンポーネントを個別に実装することに焦点を当てます。まず、純粋なTensorFlowで実装できる部分、つまりクライアントとサーバーの更新手順に焦点を当てます。

TensorFlowブロック

クライアントの更新

tff.learning.Modelを使用して、TensorFlowモデルをトレーニングするのと基本的に同じ方法でクライアントトレーニングを実行します。特に、我々は、使用するtf.GradientTape 、その後使用してこれらの勾配を適用し、データのバッチで勾配を計算するためにclient_optimizer 。トレーニング可能なウェイトのみに焦点を当てています。

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights

サーバーの更新

FedAvgのサーバー更新は、クライアント更新よりも簡単です。サーバーモデルの重みをクライアントモデルの重みの平均に置き換えるだけの「バニラ」フェデレーション平均を実装します。繰り返しになりますが、トレーニング可能なウェイトのみに焦点を当てています。

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.trainable_variables
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights

スニペットは、 mean_client_weights返すだけで簡略化できます。ただし、Federated mean_client_weightsより高度な実装では、勢いや適応性などのより高度な手法でmean_client_weightsを使用します。

課題:サーバーの重みをmodel_weightsとmean_client_weightsの中間点になるように更新するバージョンのserver_updateを実装します。 (注:この種の「中間」アプローチは、先読みオプティマイザーに関する最近の作業に類似しています!)。

これまでのところ、純粋なTensorFlowコードのみを記述しました。 TFFを使用すると、既に使い慣れているTensorFlowコードの多くを使用できるため、これは仕様によるものです。ただし、ここで、オーケストレーションロジック、つまり、サーバーがクライアントにブロードキャストするものと、クライアントがサーバーにアップロードするものを指示するロジックを指定する必要があります。

これには、TFFのフェデレーションコアが必要になります。

フェデレーションコアの概要

Federated Core(FC)は、 tff.learning基盤として機能する低レベルのインターフェイスのセットです。ただし、これらのインターフェースは学習に限定されません。実際、分散データの分析やその他の多くの計算に使用できます。

大まかに言えば、フェデレーションコアは、コンパクトに表現されたプログラムロジックがTensorFlowコードを分散通信演算子(分散合計やブロードキャストなど)と組み合わせることができるようにする開発環境です。目標は、システム実装の詳細(ポイントツーポイントネットワークメッセージ交換の指定など)を必要とせずに、研究者や実務家がシステム内の分散通信を明示的に制御できるようにすることです。

重要な点の1つは、TFFがプライバシー保護のために設計されていることです。したがって、データが存在する場所を明示的に制御して、中央のサーバーの場所にデータが不要に蓄積されるのを防ぐことができます。

連合データ

TFFの重要な概念は「連合データ」です。これは、分散システム内のデバイスのグループ全体でホストされているデータ項目のコレクションを指します(クライアントデータセットやサーバーモデルの重みなど)。すべてのデバイスにわたるデータ項目のコレクション全体を単一のフェデレーション値としてモデル化します

たとえば、センサーの温度を表すフロートがそれぞれにあるクライアントデバイスがあるとします。によってフェデレーションフロートとして表すことができます

federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)

フェデレーションタイプは、そのメンバー構成要素のタイプTtf.float32 )とデバイスのグループGによって指定されます。 Gtff.CLIENTSまたはtff.SERVERいずれかである場合に焦点を当てます。このようなフェデレーションタイプは、以下に示すように{T}@Gとして表されます。

str(federated_float_on_clients)
'{float32}@CLIENTS'

なぜ私たちは配置をそれほど気にするのですか? TFFの主な目標は、実際の分散システムにデプロイできるコードを記述できるようにすることです。これは、デバイスのどのサブセットがどのコードを実行し、さまざまなデータがどこにあるかを推論することが重要であることを意味します。

TFFは三つのことに焦点を当て:データ、データが配置され、データがどのように変換されます。最初の2つはフェデレーション型にカプセル化され、最後の2つはフェデレーション計算にカプセル化されます。

連合計算

TFFは、基本単位がフェデレーション計算である、強く型付けされた関数型プログラミング環境です。これらは、フェデレーション値を入力として受け入れ、フェデレーション値を出力として返すロジックの一部です。

たとえば、クライアントセンサーの温度を平均したいとします。以下を定義できます(フェデレーションフロートを使用)。

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

これはtf.functionデコレータとどう違うのですか?重要な答えは、 tff.federated_computationによって生成されたコードはTensorFlowでもPythonコードでもないということです。これは、内部プラットフォームに依存しないグルー言語での分散システムの仕様です

これは複雑に聞こえるかもしれませんが、TFF計算は、明確に定義された型シグネチャを持つ関数と考えることができます。これらのタイプシグネチャは直接クエリできます。

str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'

このtff.federated_computation 、連合型の引数を受け付け{float32}@CLIENTS 、そして連合型の戻り値{float32}@SERVER 。フェデレーション計算は、サーバーからクライアントへ、クライアントからクライアントへ、またはサーバーからサーバーへと移動する場合もあります。フェデレーション計算は、タイプシグネチャが一致する限り、通常の関数のように構成することもできます。

開発をサポートするために、TFFではtff.federated_computationをPython関数として呼び出すことができます。たとえば、

get_average_temperature([68.5, 70.3, 69.8])
69.53334

非熱心な計算とTensorFlow

注意すべき2つの重要な制限があります。まず、Pythonインタープリターがtff.federated_computationデコレーターに遭遇すると、関数は1回トレースされ、将来の使用のためにシリアル化されます。フェデレーションラーニングは分散型であるため、この将来の使用は、リモート実行環境など、他の場所で発生する可能性があります。したがって、TFFの計算は基本的に熱心ではありません。この動作は、 tf.functionデコレータの動作と多少似ています。

次に、フェデレーション計算はフェデレーション演算子( tff.federated_meanなど)のみで構成でき、TensorFlow操作を含めることはできません。 TensorFlowコードは、 tff.tf_computation装飾されたブロックに限定する必要があります。ほとんどの通常のTensorFlowコードは、数値を取り、それに0.5を追加する次の関数のように、直接装飾できます。

@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

これらにもタイプシグネチャがありますが、配置はありませ。たとえば、

str(add_half.type_signature)
'(float32 -> float32)'

ここでは、 tff.federated_computationtff.tf_computation重要な違いがtff.tf_computationます。前者には明示的な配置がありますが、後者にはありません。

配置を指定することにより、フェデレーション計算でtff.tf_computationブロックを使用できます。半分を追加する関数を作成しましょう。ただし、クライアントのフェデレーションフロートにのみ追加します。これを行うには、 tff.federated_mapを使用しtff.federated_map 。これは、配置を維持しながら、特定のtff.tf_computationを適用します。

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

この関数は、 tff.CLIENTSに配置された値のみを受け入れ、同じ配置の値を返すことを除いて、 add_halfとほぼ同じです。これは、タイプシグネチャで確認できます。

str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'

要約すれば:

  • TFFはフェデレーション値で動作します。
  • 各連合値はタイプ(例えば。で、連合型を持つtf.float32 )と配置(例えば。 tff.CLIENTS )。
  • フェデレーション値は、フェデレーション計算を使用して変換できます。フェデレーション計算はtff.federated_computationとフェデレーション型シグネチャでtff.federated_computationする必要があります。
  • TensorFlowコードは、 tff.tf_computationデコレータをtff.tf_computationしてブロックにtff.tf_computation必要があります。
  • これらのブロックは、フェデレーション計算に組み込むことができます。

独自の連合学習アルゴリズムの構築、再検討

フェデレーションコアを垣間見ることができたので、独自のフェデレーション学習アルゴリズムを構築できます。上記で、アルゴリズムにinitialize_fnnext_fnを定義したことをnext_fnてください。 next_fnは、純粋なTensorFlowコードを使用して定義したclient_updateserver_updateを利用します。

しかし、私たちのアルゴリズム連合計算を行うために、我々は両方が必要になりますnext_fninitialize_fnもそれぞれにtff.federated_computation

TensorFlowフェデレーションブロック

初期化計算の作成

初期化関数は非常に単純ですmodel_fnを使用してモデルを作成します。ただし、tff.tf_computationを使用してTensorFlowコードを分離する必要があることに注意してtff.tf_computation

@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables

次に、 tff.federated_valueを使用して、これをフェデレーション計算に直接渡すことができます。

@tff.federated_computation
def initialize_fn():
  return tff.federated_value(server_init(), tff.SERVER)

next_fn作成

ここで、クライアントとサーバーの更新コードを使用して、実際のアルゴリズムを記述します。まず、 client_updatetff.tf_computationします。これは、クライアントデータセットとサーバーの重みを受け入れ、更新されたクライアントの重みテンソルを出力します。

関数を適切に装飾するには、対応するタイプが必要になります。幸い、サーバーの重みのタイプは、モデルから直接抽出できます。

dummy_model = model_fn()
tf_dataset_type = tff.SequenceType(dummy_model.input_spec)

データセットタイプのシグネチャを見てみましょう。 28 x 28の画像(整数ラベル付き)を取り、それらを平坦化したことを思い出してください。

str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'

上記のserver_init関数を使用して、モデルの重みタイプを抽出することもできます。

model_weights_type = server_init.type_signature.result

タイプシグニチャを調べると、モデルのアーキテクチャを確認できます。

str(model_weights_type)
'<float32[784,10],float32[10]>'

tff.tf_computationで、クライアント更新用のtff.tf_computationを作成できます。

@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

サーバー更新のtff.tf_computationバージョンは、すでに抽出した型を使用して、同様の方法で定義できます。

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

最後になりましたが、これをすべてまとめるtff.federated_computationを作成する必要があります。この関数は、2つのフェデレーション値を受け入れます。1つはサーバーの重み(配置tff.SERVER )に対応し、もう1つはクライアントデータセット(配置tff.CLIENTS )に対応します。

これらのタイプは両方とも上記で定義されていることに注意してください。 tff.FederatedTypeを使用して、適切な配置を与える必要があります。

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

FLアルゴリズムの4つの要素を覚えていますか?

  1. サーバーからクライアントへのブロードキャストステップ。
  2. ローカルクライアントの更新手順。
  3. クライアントからサーバーへのアップロード手順。
  4. サーバーの更新手順。

上記を構築したので、各部分を1行のTFFコードとしてコンパクトに表すことができます。この単純さから、フェデレーションタイプなどを指定するために特別な注意を払う必要がありました。

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

これで、アルゴリズムの初期化とアルゴリズムの1つのステップの実行の両方にtff.federated_computationました。アルゴリズムを終了するには、これらをtff.templates.IterativeProcessます。

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

反復プロセスのinitialize関数とnext関数の型シグネチャを見てみましょう。

str(federated_algorithm.initialize.type_signature)
'( -> <float32[784,10],float32[10]>@SERVER)'

これは、 federated_algorithm.initializeが(784行10列の重み行列と10個のバイアス単位を持つ)単層モデルを返す引数なしの関数であるという事実を反映しています。

str(federated_algorithm.next.type_signature)
'(<<float32[784,10],float32[10]>@SERVER,{<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'

ここでは、 federated_algorithm.nextがサーバーモデルとクライアントデータを受け入れ、更新されたサーバーモデルを返すことがわかります。

アルゴリズムの評価

いくつかのラウンドを実行して、損失がどのように変化するかを見てみましょう。最初に、2番目のチュートリアルで説明した集中型アプローチを使用して評価関数を定義します。

最初に一元化された評価データセットを作成し、次にトレーニングデータに使用したのと同じ前処理を適用します。

計算効率の理由から最初の1000要素のみをtakeしますが、通常はテストデータセット全体を使用することに注意してください。

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients().take(1000)
central_emnist_test = preprocess(central_emnist_test)

次に、サーバーの状態を受け入れ、Kerasを使用してテストデータセットを評価する関数を記述します。 tf.Kerasに精通している場合、 set_weightsの使用に注意してくださいが、これはすべて馴染みがあるようにset_weightsます。

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)

それでは、アルゴリズムを初期化して、テストセットで評価してみましょう。

server_state = federated_algorithm.initialize()
evaluate(server_state)
50/50 [==============================] - 0s 2ms/step - loss: 2.3026 - sparse_categorical_accuracy: 0.0910

数ラウンドトレーニングして、何かが変わるかどうか見てみましょう。

for round in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
50/50 [==============================] - 0s 1ms/step - loss: 2.1706 - sparse_categorical_accuracy: 0.2440

損失関数がわずかに減少していることがわかります。ジャンプは小さいですが、10回のトレーニングラウンドのみを実行し、クライアントの小さなサブセットで実行しました。より良い結果を得るには、数千回ではないにしても数百回のラウンドを行う必要があるかもしれません。

アルゴリズムの変更

この時点で、立ち止まって、私たちが達成したことについて考えてみましょう。純粋なTensorFlowコード(クライアントとサーバーの更新用)をTFFのフェデレーションコアからのフェデレーション計算と組み合わせることにより、フェデレーション平均化を直接実装しました。

より洗練された学習を実行するために、上記の内容を変更するだけです。特に、上記の純粋なTFコードを編集することで、クライアントがトレーニングを実行する方法、またはサーバーがモデルを更新する方法を変更できます。

課題: client_update関数にグラデーションクリッピングを追加します。

より大きな変更を加えたい場合は、サーバーにさらに多くのデータを保存してブロードキャストさせることもできます。たとえば、サーバーはクライアントの学習率を保存し、時間の経過とともに減衰させることもできます。これには、上記のtff.tf_computation呼び出しで使用される型シグネチャの変更が必要になることに注意してください。

より難しい課題:クライアントで学習率の低下を伴うフェデレーション平均を実装します。

この時点で、このフレームワークで実装できるものにどれほどの柔軟性があるかを理解し始めるかもしれません。アイデア(上記のより難しい課題への回答を含む)については、 tff.learning.build_federated_averaging_processソースコードを参照するか、TFFを使用してさまざまな研究プロジェクトを確認してください。