![]() | ![]() | ![]() | ![]() |
始める前に
開始する前に、以下を実行して、環境が正しくセットアップされていることを確認してください。あいさつが表示されない場合は、インストールガイドを参照して手順を確認してください。
!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
import nest_asyncio
nest_asyncio.apply()
import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
np.random.seed(0)
画像分類とテキスト生成のチュートリアルでは、フェデレーションラーニング(FL)のモデルとデータパイプラインを設定する方法を学び、TFFのtff.learning
レイヤーを介してフェデレーショントレーニングを実行しました。
これは、FLの研究に関しては氷山の一角にすぎません。このチュートリアルでは、 tff.learning
使用せずにフェデレーション学習アルゴリズムを実装する方法について説明します。私たちは以下を達成することを目指しています:
目標:
- 連合学習アルゴリズムの一般的な構造を理解します。
- TFFのフェデレーションコアを探索してください。
- Federated Coreを使用して、FederatedAveragingを直接実装します。
このチュートリアルは自己完結型ですが、最初に画像分類とテキスト生成のチュートリアルを読むことをお勧めします。
入力データの準備
まず、TFFに含まれているEMNISTデータセットをロードして前処理します。詳細については、画像分類チュートリアルを参照してください。
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
データセットをモデルにフィードするために、データをフラット化し、各例を形式(flattened_image_vector, label)
タプルに変換します。
NUM_CLIENTS = 10
BATCH_SIZE = 20
def preprocess(dataset):
def batch_format_fn(element):
"""Flatten a batch of EMNIST data and return a (features, label) tuple."""
return (tf.reshape(element['pixels'], [-1, 784]),
tf.reshape(element['label'], [-1, 1]))
return dataset.batch(BATCH_SIZE).map(batch_format_fn)
ここで、少数のクライアントをサンプリングし、上記の前処理をそれらのデータセットに適用します。
client_ids = np.random.choice(emnist_train.client_ids, size=NUM_CLIENTS, replace=False)
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
for x in client_ids
]
モデルの準備
画像分類チュートリアルと同じモデルを使用します。このモデル( tf.keras
を介してtf.keras
)には、単一の非表示レイヤーがあり、その後にtf.keras
レイヤーが続きます。
def create_keras_model():
return tf.keras.models.Sequential([
tf.keras.layers.Input(shape=(784,)),
tf.keras.layers.Dense(10, kernel_initializer='zeros'),
tf.keras.layers.Softmax(),
])
このモデルをTFFで使用するために、 tff.learning.Model
モデルをtff.learning.Model
としてラップします。これにより、TFF内でモデルのフォワードパスを実行し、 モデル出力を抽出できます。詳細については、画像分類チュートリアルも参照してください。
def model_fn():
keras_model = create_keras_model()
return tff.learning.from_keras_model(
keras_model,
input_spec=federated_train_data[0].element_spec,
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
tf.keras
を使用してtf.keras
を作成しtff.learning.Model
、TFFははるかに一般的なモデルをサポートしています。これらのモデルには、モデルの重みを取得する次の関連属性があります。
-
trainable_variables
:訓練可能な層に対応するテンソルの反復可能。 -
non_trainable_variables
:訓練不可能な層に対応するテンソルの反復可能。
ここでは、 trainable_variables
のみを使用します。 (私たちのモデルにはそれらしかありません!)。
独自の連合学習アルゴリズムの構築
tff.learning
APIを使用すると、フェデレーション平均化の多くのバリアントを作成できますが、このフレームワークにうまく適合しない他のフェデレーションアルゴリズムがあります。たとえば、正則化、クリッピング、またはフェデレーションGANトレーニングなどのより複雑なアルゴリズムを追加したい場合があります。代わりに、フェデレーション分析に興味があるかもしれません。
これらのより高度なアルゴリズムでは、TFFを使用して独自のカスタムアルゴリズムを作成する必要があります。多くの場合、フェデレーションアルゴリズムには4つの主要なコンポーネントがあります。
- サーバーからクライアントへのブロードキャストステップ。
- ローカルクライアントの更新手順。
- クライアントからサーバーへのアップロード手順。
- サーバーの更新手順。
TFFでは、通常、フェデレーションアルゴリズムをtff.templates.IterativeProcess
(全体を通して単にIterativeProcess
とtff.templates.IterativeProcess
ます)として表します。これは、 initialize
関数とnext
関数を含むクラスです。ここで、 initialize
はサーバーを初期化するために使用され、 next
はフェデレーションアルゴリズムの1回の通信ラウンドを実行します。 FedAvgの反復プロセスがどのようになるかについてのスケルトンを書いてみましょう。
まず、 tff.learning.Model
作成し、そのトレーニング可能な重みを返す初期化関数があります。
def initialize_fn():
model = model_fn()
return model.trainable_variables
この関数は見栄えがしますが、後で説明するように、「TFF計算」にするために小さな変更を加える必要があります。
next_fn
もスケッチしたいとnext_fn
ます。
def next_fn(server_weights, federated_dataset):
# Broadcast the server weights to the clients.
server_weights_at_client = broadcast(server_weights)
# Each client computes their updated weights.
client_weights = client_update(federated_dataset, server_weights_at_client)
# The server averages these updates.
mean_client_weights = mean(client_weights)
# The server updates its model.
server_weights = server_update(mean_client_weights)
return server_weights
これらの4つのコンポーネントを個別に実装することに焦点を当てます。まず、純粋なTensorFlowで実装できる部分、つまりクライアントとサーバーの更新手順に焦点を当てます。
TensorFlowブロック
クライアントの更新
tff.learning.Model
を使用して、TensorFlowモデルをトレーニングするのと基本的に同じ方法でクライアントトレーニングを実行します。特に、我々は、使用するtf.GradientTape
、その後使用してこれらの勾配を適用し、データのバッチで勾配を計算するためにclient_optimizer
。トレーニング可能なウェイトのみに焦点を当てています。
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
"""Performs training (using the server model weights) on the client's dataset."""
# Initialize the client model with the current server weights.
client_weights = model.trainable_variables
# Assign the server weights to the client model.
tf.nest.map_structure(lambda x, y: x.assign(y),
client_weights, server_weights)
# Use the client_optimizer to update the local model.
for batch in dataset:
with tf.GradientTape() as tape:
# Compute a forward pass on the batch of data
outputs = model.forward_pass(batch)
# Compute the corresponding gradient
grads = tape.gradient(outputs.loss, client_weights)
grads_and_vars = zip(grads, client_weights)
# Apply the gradient using a client optimizer.
client_optimizer.apply_gradients(grads_and_vars)
return client_weights
サーバーの更新
FedAvgのサーバー更新は、クライアント更新よりも簡単です。サーバーモデルの重みをクライアントモデルの重みの平均に置き換えるだけの「バニラ」フェデレーション平均を実装します。繰り返しになりますが、トレーニング可能なウェイトのみに焦点を当てています。
@tf.function
def server_update(model, mean_client_weights):
"""Updates the server model weights as the average of the client model weights."""
model_weights = model.trainable_variables
# Assign the mean client weights to the server model.
tf.nest.map_structure(lambda x, y: x.assign(y),
model_weights, mean_client_weights)
return model_weights
スニペットは、 mean_client_weights
返すだけで簡略化できます。ただし、Federated mean_client_weights
より高度な実装では、勢いや適応性などのより高度な手法でmean_client_weights
を使用します。
課題:サーバーの重みをmodel_weightsとmean_client_weightsの中間点になるように更新するバージョンのserver_update
を実装します。 (注:この種の「中間」アプローチは、先読みオプティマイザーに関する最近の作業に類似しています!)。
これまでのところ、純粋なTensorFlowコードのみを記述しました。 TFFを使用すると、既に使い慣れているTensorFlowコードの多くを使用できるため、これは仕様によるものです。ただし、ここで、オーケストレーションロジック、つまり、サーバーがクライアントにブロードキャストするものと、クライアントがサーバーにアップロードするものを指示するロジックを指定する必要があります。
これには、TFFのフェデレーションコアが必要になります。
フェデレーションコアの概要
Federated Core(FC)は、 tff.learning
基盤として機能する低レベルのインターフェイスのセットです。ただし、これらのインターフェースは学習に限定されません。実際、分散データの分析やその他の多くの計算に使用できます。
大まかに言えば、フェデレーションコアは、コンパクトに表現されたプログラムロジックがTensorFlowコードを分散通信演算子(分散合計やブロードキャストなど)と組み合わせることができるようにする開発環境です。目標は、システム実装の詳細(ポイントツーポイントネットワークメッセージ交換の指定など)を必要とせずに、研究者や実務家がシステム内の分散通信を明示的に制御できるようにすることです。
重要な点の1つは、TFFがプライバシー保護のために設計されていることです。したがって、データが存在する場所を明示的に制御して、中央のサーバーの場所にデータが不要に蓄積されるのを防ぐことができます。
連合データ
TFFの重要な概念は「連合データ」です。これは、分散システム内のデバイスのグループ全体でホストされているデータ項目のコレクションを指します(クライアントデータセットやサーバーモデルの重みなど)。すべてのデバイスにわたるデータ項目のコレクション全体を単一のフェデレーション値としてモデル化します。
たとえば、センサーの温度を表すフロートがそれぞれにあるクライアントデバイスがあるとします。によってフェデレーションフロートとして表すことができます
federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)
フェデレーションタイプは、そのメンバー構成要素のタイプT
( tf.float32
)とデバイスのグループG
によって指定されます。 G
がtff.CLIENTS
またはtff.SERVER
いずれかである場合に焦点を当てます。このようなフェデレーションタイプは、以下に示すように{T}@G
として表されます。
str(federated_float_on_clients)
'{float32}@CLIENTS'
なぜ私たちは配置をそれほど気にするのですか? TFFの主な目標は、実際の分散システムにデプロイできるコードを記述できるようにすることです。これは、デバイスのどのサブセットがどのコードを実行し、さまざまなデータがどこにあるかを推論することが重要であることを意味します。
TFFは三つのことに焦点を当て:データ、データが配置され、データがどのように変換されます。最初の2つはフェデレーション型にカプセル化され、最後の2つはフェデレーション計算にカプセル化されます。
連合計算
TFFは、基本単位がフェデレーション計算である、強く型付けされた関数型プログラミング環境です。これらは、フェデレーション値を入力として受け入れ、フェデレーション値を出力として返すロジックの一部です。
たとえば、クライアントセンサーの温度を平均したいとします。以下を定義できます(フェデレーションフロートを使用)。
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
return tff.federated_mean(client_temperatures)
これはtf.function
デコレータとどう違うのですか?重要な答えは、 tff.federated_computation
によって生成されたコードはTensorFlowでもPythonコードでもないということです。これは、内部プラットフォームに依存しないグルー言語での分散システムの仕様です。
これは複雑に聞こえるかもしれませんが、TFF計算は、明確に定義された型シグネチャを持つ関数と考えることができます。これらのタイプシグネチャは直接クエリできます。
str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'
このtff.federated_computation
、連合型の引数を受け付け{float32}@CLIENTS
、そして連合型の戻り値{float32}@SERVER
。フェデレーション計算は、サーバーからクライアントへ、クライアントからクライアントへ、またはサーバーからサーバーへと移動する場合もあります。フェデレーション計算は、タイプシグネチャが一致する限り、通常の関数のように構成することもできます。
開発をサポートするために、TFFではtff.federated_computation
をPython関数として呼び出すことができます。たとえば、
get_average_temperature([68.5, 70.3, 69.8])
69.53334
非熱心な計算とTensorFlow
注意すべき2つの重要な制限があります。まず、Pythonインタープリターがtff.federated_computation
デコレーターに遭遇すると、関数は1回トレースされ、将来の使用のためにシリアル化されます。フェデレーションラーニングは分散型であるため、この将来の使用は、リモート実行環境など、他の場所で発生する可能性があります。したがって、TFFの計算は基本的に熱心ではありません。この動作は、 tf.function
デコレータの動作と多少似ています。
次に、フェデレーション計算はフェデレーション演算子( tff.federated_mean
など)のみで構成でき、TensorFlow操作を含めることはできません。 TensorFlowコードは、 tff.tf_computation
装飾されたブロックに限定する必要があります。ほとんどの通常のTensorFlowコードは、数値を取り、それに0.5
を追加する次の関数のように、直接装飾できます。
@tff.tf_computation(tf.float32)
def add_half(x):
return tf.add(x, 0.5)
これらにもタイプシグネチャがありますが、配置はありません。たとえば、
str(add_half.type_signature)
'(float32 -> float32)'
ここでは、 tff.federated_computation
とtff.tf_computation
重要な違いがtff.tf_computation
ます。前者には明示的な配置がありますが、後者にはありません。
配置を指定することにより、フェデレーション計算でtff.tf_computation
ブロックを使用できます。半分を追加する関数を作成しましょう。ただし、クライアントのフェデレーションフロートにのみ追加します。これを行うには、 tff.federated_map
を使用しtff.federated_map
。これは、配置を維持しながら、特定のtff.tf_computation
を適用します。
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
return tff.federated_map(add_half, x)
この関数は、 tff.CLIENTS
に配置された値のみを受け入れ、同じ配置の値を返すことを除いて、 add_half
とほぼ同じです。これは、タイプシグネチャで確認できます。
str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'
要約すれば:
- TFFはフェデレーション値で動作します。
- 各連合値はタイプ(例えば。で、連合型を持つ
tf.float32
)と配置(例えば。tff.CLIENTS
)。 - フェデレーション値は、フェデレーション計算を使用して変換できます。フェデレーション計算は、
tff.federated_computation
とフェデレーション型シグネチャでtff.federated_computation
する必要があります。 - TensorFlowコードは、
tff.tf_computation
デコレータをtff.tf_computation
してブロックにtff.tf_computation
必要があります。 - これらのブロックは、フェデレーション計算に組み込むことができます。
独自の連合学習アルゴリズムの構築、再検討
フェデレーションコアを垣間見ることができたので、独自のフェデレーション学習アルゴリズムを構築できます。上記で、アルゴリズムにinitialize_fn
とnext_fn
を定義したことをnext_fn
てください。 next_fn
は、純粋なTensorFlowコードを使用して定義したclient_update
とserver_update
を利用します。
しかし、私たちのアルゴリズム連合計算を行うために、我々は両方が必要になりますnext_fn
とinitialize_fn
もそれぞれにtff.federated_computation
。
TensorFlowフェデレーションブロック
初期化計算の作成
初期化関数は非常に単純ですmodel_fn
を使用してモデルを作成します。ただし、tff.tf_computationを使用してTensorFlowコードを分離する必要があることに注意してtff.tf_computation
。
@tff.tf_computation
def server_init():
model = model_fn()
return model.trainable_variables
次に、 tff.federated_value
を使用して、これをフェデレーション計算に直接渡すことができます。
@tff.federated_computation
def initialize_fn():
return tff.federated_value(server_init(), tff.SERVER)
next_fn
作成
ここで、クライアントとサーバーの更新コードを使用して、実際のアルゴリズムを記述します。まず、 client_update
をtff.tf_computation
します。これは、クライアントデータセットとサーバーの重みを受け入れ、更新されたクライアントの重みテンソルを出力します。
関数を適切に装飾するには、対応するタイプが必要になります。幸い、サーバーの重みのタイプは、モデルから直接抽出できます。
dummy_model = model_fn()
tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
データセットタイプのシグネチャを見てみましょう。 28 x 28の画像(整数ラベル付き)を取り、それらを平坦化したことを思い出してください。
str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'
上記のserver_init
関数を使用して、モデルの重みタイプを抽出することもできます。
model_weights_type = server_init.type_signature.result
タイプシグニチャを調べると、モデルのアーキテクチャを確認できます。
str(model_weights_type)
'<float32[784,10],float32[10]>'
tff.tf_computation
で、クライアント更新用のtff.tf_computation
を作成できます。
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
model = model_fn()
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
return client_update(model, tf_dataset, server_weights, client_optimizer)
サーバー更新のtff.tf_computation
バージョンは、すでに抽出した型を使用して、同様の方法で定義できます。
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
model = model_fn()
return server_update(model, mean_client_weights)
最後になりましたが、これをすべてまとめるtff.federated_computation
を作成する必要があります。この関数は、2つのフェデレーション値を受け入れます。1つはサーバーの重み(配置tff.SERVER
)に対応し、もう1つはクライアントデータセット(配置tff.CLIENTS
)に対応します。
これらのタイプは両方とも上記で定義されていることに注意してください。 tff.FederatedType
を使用して、適切な配置を与える必要があります。
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
FLアルゴリズムの4つの要素を覚えていますか?
- サーバーからクライアントへのブロードキャストステップ。
- ローカルクライアントの更新手順。
- クライアントからサーバーへのアップロード手順。
- サーバーの更新手順。
上記を構築したので、各部分を1行のTFFコードとしてコンパクトに表すことができます。この単純さから、フェデレーションタイプなどを指定するために特別な注意を払う必要がありました。
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
# Broadcast the server weights to the clients.
server_weights_at_client = tff.federated_broadcast(server_weights)
# Each client computes their updated weights.
client_weights = tff.federated_map(
client_update_fn, (federated_dataset, server_weights_at_client))
# The server averages these updates.
mean_client_weights = tff.federated_mean(client_weights)
# The server updates its model.
server_weights = tff.federated_map(server_update_fn, mean_client_weights)
return server_weights
これで、アルゴリズムの初期化とアルゴリズムの1つのステップの実行の両方にtff.federated_computation
ました。アルゴリズムを終了するには、これらをtff.templates.IterativeProcess
ます。
federated_algorithm = tff.templates.IterativeProcess(
initialize_fn=initialize_fn,
next_fn=next_fn
)
反復プロセスのinitialize
関数とnext
関数の型シグネチャを見てみましょう。
str(federated_algorithm.initialize.type_signature)
'( -> <float32[784,10],float32[10]>@SERVER)'
これは、 federated_algorithm.initialize
が(784行10列の重み行列と10個のバイアス単位を持つ)単層モデルを返す引数なしの関数であるという事実を反映しています。
str(federated_algorithm.next.type_signature)
'(<<float32[784,10],float32[10]>@SERVER,{<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'
ここでは、 federated_algorithm.next
がサーバーモデルとクライアントデータを受け入れ、更新されたサーバーモデルを返すことがわかります。
アルゴリズムの評価
いくつかのラウンドを実行して、損失がどのように変化するかを見てみましょう。最初に、2番目のチュートリアルで説明した集中型アプローチを使用して評価関数を定義します。
最初に一元化された評価データセットを作成し、次にトレーニングデータに使用したのと同じ前処理を適用します。
計算効率の理由から最初の1000要素のみをtake
しますが、通常はテストデータセット全体を使用することに注意してください。
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients().take(1000)
central_emnist_test = preprocess(central_emnist_test)
次に、サーバーの状態を受け入れ、Kerasを使用してテストデータセットを評価する関数を記述します。 tf.Keras
に精通している場合、 set_weights
の使用に注意してくださいが、これはすべて馴染みがあるようにset_weights
ます。
def evaluate(server_state):
keras_model = create_keras_model()
keras_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
)
keras_model.set_weights(server_state)
keras_model.evaluate(central_emnist_test)
それでは、アルゴリズムを初期化して、テストセットで評価してみましょう。
server_state = federated_algorithm.initialize()
evaluate(server_state)
50/50 [==============================] - 0s 2ms/step - loss: 2.3026 - sparse_categorical_accuracy: 0.0910
数ラウンドトレーニングして、何かが変わるかどうか見てみましょう。
for round in range(15):
server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
50/50 [==============================] - 0s 1ms/step - loss: 2.1706 - sparse_categorical_accuracy: 0.2440
損失関数がわずかに減少していることがわかります。ジャンプは小さいですが、10回のトレーニングラウンドのみを実行し、クライアントの小さなサブセットで実行しました。より良い結果を得るには、数千回ではないにしても数百回のラウンドを行う必要があるかもしれません。
アルゴリズムの変更
この時点で、立ち止まって、私たちが達成したことについて考えてみましょう。純粋なTensorFlowコード(クライアントとサーバーの更新用)をTFFのフェデレーションコアからのフェデレーション計算と組み合わせることにより、フェデレーション平均化を直接実装しました。
より洗練された学習を実行するために、上記の内容を変更するだけです。特に、上記の純粋なTFコードを編集することで、クライアントがトレーニングを実行する方法、またはサーバーがモデルを更新する方法を変更できます。
課題: client_update
関数にグラデーションクリッピングを追加します。
より大きな変更を加えたい場合は、サーバーにさらに多くのデータを保存してブロードキャストさせることもできます。たとえば、サーバーはクライアントの学習率を保存し、時間の経過とともに減衰させることもできます。これには、上記のtff.tf_computation
呼び出しで使用される型シグネチャの変更が必要になることに注意してください。
より難しい課題:クライアントで学習率の低下を伴うフェデレーション平均を実装します。
この時点で、このフレームワークで実装できるものにどれほどの柔軟性があるかを理解し始めるかもしれません。アイデア(上記のより難しい課題への回答を含む)については、 tff.learning.build_federated_averaging_process
ソースコードを参照するか、TFFを使用してさまざまな研究プロジェクトを確認してください。