このページは Cloud Translation API によって翻訳されました。
Switch to English

画像分類のための連合学習

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示

このチュートリアルでは、連合学習(FL)TFFのAPI層、導入する古典MNISTトレーニング例を使用tff.learning連合学習課題の一般的なタイプを実行するために使用することができる高レベルのインターフェイスのセット、などの- TensorFlowに実装されたユーザー提供モデルに対するフェデレーショントレーニング。

このチュートリアルとFederatedLearning APIは、主に独自のTensorFlowモデルをTFFにプラグインし、TFFをほとんどブラックボックスとして扱いたいユーザーを対象としています。 TFFの詳細と、独自のフェデレーション学習アルゴリズムの実装方法については、FC CoreAPIのチュートリアル-カスタムフェデレーションアルゴリズムパート1およびパート2を参照してください。

tff.learning詳細については、テキスト生成のためフェデレーションラーニングのチュートリアルを続けてください。このチュートリアルでは、反復モデルをカバーするだけでなく、事前にトレーニングされたシリアル化されたケラスモデルをロードして、ケラスを使用した評価と組み合わせたフェデレーションラーニングで改良する方法も示します。

始める前に

開始する前に、以下を実行して、環境が正しくセットアップされていることを確認してください。あいさつが表示されない場合は、インストールガイドを参照して手順を確認してください。

# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly

!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio
!pip install --quiet tb-nightly  # or tensorboard, but not both

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
Fetching TensorBoard MPM... done.

import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

入力データの準備

データから始めましょう。連合学習には、連合データセット、つまり複数のユーザーからのデータの収集が必要です。連合データは通常、非iidであり、固有の一連の課題があります。

実験を容易にするために、いくつかのデータセットをTFFリポジトリにシードしました。これには、 Leafを使用して再処理された元のNISTデータセットのバージョンを含むMNISTのフェデレーションバージョンが含まれ、データは数字。各ライターには固有のスタイルがあるため、このデータセットは、フェデレーションデータセットに期待される種類の非iid動作を示します。

ロードする方法は次のとおりです。

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

load_data()によって返されるデータセットは、 load_data()インスタンスです。 tff.simulation.ClientDataは、ユーザーのセットを列挙したり、特定のユーザーのデータを表すtf.data.Datasetを作成したり、クエリを実行したりできるインターフェイスです。個々の要素の構造。このインターフェースを使用してデータセットのコンテンツを探索する方法は次のとおりです。このインターフェースではクライアントIDを反復処理できますが、これはシミュレーションデータの機能にすぎないことに注意してください。すぐにわかるように、クライアントIDは、フェデレーション学習フレームワークでは使用されません。その唯一の目的は、シミュレーション用のデータのサブセットを選択できるようにすることです。

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None)), ('label', TensorSpec(shape=(), dtype=tf.int32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

連合データの異質性の調査

連合データは通常、非iidであり、ユーザーは通常、使用パターンに応じてデータの分布が異なります。一部のクライアントは、ローカルでデータが不足しているため、デバイスでのトレーニング例が少ない場合がありますが、一部のクライアントは、十分な数のトレーニング例を持っている場合があります。利用可能なEMNISTデータを使用して、連合システムに典型的なデータの不均一性のこの概念を調べてみましょう。これはすべてのデータがローカルで利用できるシミュレーション環境であるため、クライアントのデータのこの詳細な分析は私たちだけが利用できることに注意することが重要です。実際の本番フェデレーション環境では、単一のクライアントのデータを検査することはできません。

まず、1つのクライアントのデータのサンプリングを取得して、1つのシミュレートされたデバイスでの例の感触をつかみましょう。使用しているデータセットは一意のライターによってキー設定されているため、1人のクライアントのデータは、1人のユーザーの一意の「使用パターン」をシミュレートする0から9までの数字のサンプルに対する1人の手書きを表します。

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

次に、各MNISTディジットラベルの各クライアントの例の数を視覚化してみましょう。フェデレーション環境では、各クライアントの例の数は、ユーザーの動作に応じてかなり異なる場合があります。

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

次に、各MNISTラベルのクライアントごとの平均画像を視覚化します。このコードは、1つのラベルのすべてのユーザーの例の各ピクセル値の平均を生成します。あるクライアントの1桁の平均画像は、各人の独自の手書きスタイルにより、同じ桁の別のクライアントの平均画像とは異なって見えることがわかります。そのローカルラウンドでのユーザー独自のデータから学習しているため、各ローカルトレーニングラウンドが各クライアントでモデルを異なる方向にナッジする方法について考えることができます。チュートリアルの後半では、すべてのクライアントからモデルの各更新を取得し、それらを集約して、クライアント独自の各データから学習した新しいグローバルモデルにする方法を説明します。

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

ユーザーデータはノイズが多く、信頼性の低いラベルが付けられる可能性があります。たとえば、上記のクライアント#2のデータを見ると、ラベル2の場合、よりノイズの多い平均画像を作成するいくつかの誤ったラベルの例があった可能性があることがわかります。

入力データの前処理

データはすでにtf.data.Datasetであるため、データセット変換を使用して前処理を実行できます。ここでは、 28x28画像を784要素の配列にフラット化し、個々の例をシャッフルし、それらをバッチに整理し、 28x28使用するためにフィーチャの名前をpixelslabelからxyに変更します。また、データセットに対してrepeatをスローして、いくつかのエポックを実行します。

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

これが機能したことを確認しましょう。

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[0],
       [5],
       [0],
       [1],
       [3],
       [0],
       [5],
       [4],
       [1],
       [7],
       [0],
       [4],
       [0],
       [1],
       [7],
       [2],
       [2],
       [0],
       [7],
       [1]], dtype=int32))])

連合データセットを構築するためのほぼすべてのビルディングブロックが用意されています。

シミュレーションで連合データをTFFにフィードする方法の1つは、単純にPythonリストとして、リストの各要素がリストとして、またはtf.data.Datasetとして、個々のユーザーのデータを保持することtf.data.Dataset 。後者を提供するインターフェースがすでにあるので、それを使用しましょう。

これは、トレーニングまたは評価のラウンドへの入力として、指定されたユーザーのセットからデータセットのリストを作成する単純なヘルパー関数です。

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

では、どのようにクライアントを選択するのでしょうか。

典型的なフェデレーショントレーニングシナリオでは、非常に多数のユーザーデバイスを扱っている可能性があり、特定の時点でトレーニングに使用できるのはその一部にすぎません。これは、たとえば、クライアントデバイスが、電源に接続されている場合、従量制ネットワークから離れている場合、またはアイドル状態の場合にのみトレーニングに参加する携帯電話の場合です。

もちろん、私たちはシミュレーション環境にあり、すべてのデータはローカルで利用できます。通常、シミュレーションを実行するときは、トレーニングの各ラウンドに関与するクライアントのランダムなサブセットをサンプリングするだけで、通常は各ラウンドで異なります。

とはいえ、 Federated Averagingアルゴリズムに関する論文を研究することでわかるように、各ラウンドでランダムにサンプリングされたクライアントのサブセットを使用してシステムで収束を達成するには時間がかかる可能性があり、このインタラクティブなチュートリアル。

代わりに、クライアントのセットを1回サンプリングし、ラウンド全体で同じセットを再利用して、収束を高速化します(これらの少数のユーザーのデータに意図的に過剰適合します)。読者がこのチュートリアルを変更してランダムサンプリングをシミュレートするための演習として残します-実行するのはかなり簡単です(一度実行すると、モデルを収束させるのに時間がかかる場合があることに注意してください)。

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

Kerasでモデルを作成する

Kerasを使用している場合は、Kerasモデルを構築するコードがすでにある可能性があります。これは、私たちのニーズに十分な単純なモデルの例です。

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

TFFでモデルを使用するには、 tff.learning.Modelインターフェースのインスタンスでラップする必要があります。これは、 tff.learning.Modelと同様に、モデルのフォワードパス、メタデータプロパティなどをスタンプするメソッドを公開しますが、追加のフェデレーションメトリックを計算するプロセスを制御する方法などの要素。今のところ、これについて心配する必要はありません。上で定義したようなtff.learning.from_keras_modelモデルがある場合は、以下に示すように、 tff.learning.from_keras_modelを呼び出し、モデルとサンプルデータバッチを引数として渡すことで、TFFにラップさせることができます。

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

連合データに関するモデルのトレーニング

TFFで使用するモデルをtff.learning.Modelとしてtff.learning.Modelしたtff.learning.build_federated_averaging_process 、次のようにヘルパー関数tff.learning.build_federated_averaging_processを呼び出すことにより、TFFにフェデレーション平均化アルゴリズムを構築させることができます。

モデルの構築がTFFによって制御されるコンテキストで行われるように、引数は既に構築されたインスタンスではなくコンストラクター(上記のmodel_fnなど)である必要があることに注意してください(理由についてmodel_fn場合)これについては、カスタムアルゴリズムのフォローアップチュートリアルを読むことをお勧めします)。

以下のフェデレーション平均化アルゴリズムに関する重要な注意点として、 2つのオプティマイザーがあります。_clientオプティマイザーと_serverオプティマイザーです。 _clientオプティマイザーは、各クライアントのローカルモデルの更新を計算するためにのみ使用されます。 _serverオプティマイザーは、平均された更新をサーバーのグローバルモデルに適用します。特に、これは、使用されるオプティマイザーと学習率の選択が、標準のiidデータセットでモデルをトレーニングするために使用したものとは異なる必要がある可能性があることを意味します。通常のSGDから始めることをお勧めします。おそらく、通常よりも学習率が低くなります。私たちが使用する学習率は慎重に調整されていません。自由に実験してください。

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

今何があったの? TFFは、フェデレーション計算のペアを構築し、それらをtff.templates.IterativeProcessにパッケージ化しました。これらの計算は、プロパティのペアがinitializeおよびnextとして使用できます。

一言で言えば、フェデレーション計算は、さまざまなフェデレーションアルゴリズムを表現できるTFFの内部言語のプログラムです(これについて詳しくは、カスタムアルゴリズムのチュートリアルを参照してください)。この場合、生成されてiterative_processパックされた2つの計算は、 FederatedAveragingを実装します。

TFFの目標は、実際のフェデレーション学習設定で実行できるように計算を定義することですが、現在はローカル実行シミュレーションランタイムのみが実装されています。シミュレーターで計算を実行するには、Python関数のように呼び出すだけです。このデフォルトの解釈された環境は、高性能を目的として設計されていませんが、このチュートリアルには十分です。将来のリリースで大規模な研究を促進するために、より高性能なシミュレーションランタイムを提供する予定です。

initialize計算から始めましょう。すべてのフェデレーション計算の場合と同様に、それは関数と考えることができます。計算は引数をとらず、1つの結果(サーバー上のFederated Averagingプロセスの状態の表現)を返します。 TFFの詳細については詳しく説明したくありませんが、この状態がどのように見えるかを確認することは有益かもしれません。次のように視覚化できます。

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER)'

上記のタイプシグネチャは最初は少しわかりにくいように見えるかもしれませんが、サーバーの状態はmodel (すべてのデバイスに配布されるMNISTの初期モデルパラメーター)とoptimizer_state (サーバーによって維持される追加情報)で構成されていることがわかります。ハイパーパラメータスケジュールに使用するラウンド数など)。

initialize計算を呼び出して、サーバーの状態を構築しましょう。

state = iterative_process.initialize()

フェデレーション計算のペアの2番目のnextは、サーバーの状態(モデルパラメーターを含む)をクライアントにプッシュすること、ローカルデータのデバイス上トレーニング、モデルの更新の収集と平均化で構成される、フェデレーション平均化の1回のラウンドを表します。 、サーバーで新しい更新モデルを作成します。

概念的には、 nextは、 nextような機能タイプのシグネチャを持つと考えることができます。

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

特に、 next()はサーバー上で実行される関数ではなく、分散型計算全体の宣言型関数表現であると考える必要があります。一部の入力はサーバー( SERVER_STATE )によって提供されますが、それぞれが参加します。デバイスは独自のローカルデータセットを提供します。

トレーニングを1回実行して、結果を視覚化してみましょう。上記で生成したフェデレーションデータをユーザーのサンプルに使用できます。

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11502057), ('loss', 3.244929)]))])

さらにいくつかのラウンドを実行してみましょう。前述のように、通常、この時点で、ユーザーが継続的に行き来する現実的な展開をシミュレートするために、各ラウンドでランダムに選択された新しいユーザーのサンプルからシミュレーションデータのサブセットを選択しますが、このインタラクティブノートブックでは、デモンストレーションのために、同じユーザーを再利用するだけで、システムがすばやく収束します。

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14609054), ('loss', 2.9141645)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.15205762), ('loss', 2.9237952)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.18600823), ('loss', 2.7629454)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.20884773), ('loss', 2.622908)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.21872428), ('loss', 2.543587)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2372428), ('loss', 2.4210362)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.28209877), ('loss', 2.2297976)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2685185), ('loss', 2.195803)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.33868313), ('loss', 2.0523348)]))])

フェデレーショントレーニングの各ラウンド後にトレーニング損失が減少しており、モデルが収束していることを示しています。これらのトレーニング指標にはいくつかの重要な注意事項がありますが、このチュートリアルで後述する評価のセクションを参照してください。

TensorBoardでのモデルメトリックの表示

次に、Tensorboardを使用して、これらのフェデレーション計算からのメトリックを視覚化します。

まず、メトリックを書き込むディレクトリと対応するサマリーライターを作成します。

logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

同じサマリーライターを使用して、関連するスカラーメトリックをプロットします。

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics['train'].items():
      tf.summary.scalar(name, value, step=round_num)

上記で指定したルートログディレクトリを使用してTensorBoardを起動します。データの読み込みには数秒かかる場合があります。

!ls {logdir}
%tensorboard --logdir {logdir} --port=0
events.out.tfevents.1604020204.isim77-20020ad609500000b02900f40f27a5f6.prod.google.com.686098.10633.v2
events.out.tfevents.1604020602.isim77-20020ad609500000b02900f40f27a5f6.prod.google.com.794554.10607.v2

Launching TensorBoard...
<IPython.core.display.Javascript at 0x7fc5e8d3c128>
# Uncomment and run this this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if 
# you do a "Runtime > Run all" you don't lose your results.

# !rm -R /tmp/logs/scalars/*

評価指標を同じように表示するために、「logs / scalars / eval」のような別のevalフォルダーを作成してTensorBoardに書き込むことができます。

モデル実装のカスタマイズ

KerasはTensorFlow推奨される高レベルモデルAPIであり、可能な限りTFFでtff.learning.from_keras_modelモデル( tff.learning.from_keras_modelを介して)を使用することをお勧めします。

しかし、 tff.learning下位モデルインタフェース、提供tff.learning.Model連合学習のためのモデルを使用するために必要な最小限の機能を公開し、。このインターフェースを直接実装すると(おそらくtf.keras.layersなどのビルディングブロックを使用して)、フェデレーション学習アルゴリズムの内部を変更することなく、最大限のカスタマイズが可能になります。

それでは、最初からやり直しましょう。

モデル変数、フォワードパス、およびメトリックの定義

最初のステップは、使用するTensorFlow変数を特定することです。次のコードを読みやすくするために、セット全体を表すデータ構造を定義しましょう。これは、次のような変数に含まれるweightsbias我々が訓練すること、などなど、私たちがトレーニング中に更新される様々な累積統計やカウンタを保持する変数、 loss_sumaccuracy_sum 、およびnum_examples

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

変数を作成するメソッドは次のとおりです。簡単にするために、すべての統計をtf.float32として表します。これにより、後の段階で型変換を行う必要がなくなります。変数初期化子をラムダとしてラップすることは、リソース変数によって課せられる要件です

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

モデルパラメータと累積統計の変数を配置したら、次のように、損失を計算し、予測を出力し、入力データの単一バッチの累積統計を更新するフォワードパスメソッドを定義できます。

def mnist_forward_pass(variables, batch):
  y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

次に、再びTensorFlowを使用して、ローカルメトリックのセットを返す関数を定義します。これらは、(自動的に処理されるモデルの更新に加えて)フェデレーション学習または評価プロセスでサーバーに集約するのに適格な値です。

ここでは、平均lossaccuracy 、およびnum_examplesだけnum_examples 。これは、フェデレーション集計を計算するときに、さまざまなユーザーからの寄与を正しく重み付けする必要があります。

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

最後に、 get_local_mnist_metricsを介して各デバイスからget_local_mnist_metricsされたローカルメトリックを集約する方法を決定する必要があります。これは、TensorFlowで記述されていないコードの唯一の部分です。これは、TFFで表現されたフェデレーション計算です。さらに深く掘り下げたい場合は、カスタムアルゴリズムのチュートリアルをざっと読んでください。ただし、ほとんどのアプリケーションでは、実際にそうする必要はありません。以下に示すパターンのバリエーションで十分です。外観は次のとおりです。

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))

入力metrics引数に対応OrderedDictによって返さget_local_mnist_metrics以上、しかし批判的に値がされなくなりましたtf.Tensors -彼らはとして「箱入り」ですtff.Valueだけ、それはあなたがもはやTensorFlowを使用してそれらを操作することができますクリアしない作るために、S tff.federated_meantff.federated_sumなどのTFFのフェデレーション演算子を使用します。返されるグローバル集計のディクショナリは、サーバーで使用できるメトリックのセットを定義します。

tff.learning.Modelインスタンスをtff.learning.Model

上記のすべてが整ったので、TFFにKerasモデルを取り込んだときに生成されるものと同様の、TFFで使用するためのモデル表現を構築する準備が整いました。

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

ご覧のとおり、 tff.learning.Modelによって定義された抽象メソッドとプロパティは、変数を導入し、損失と統計を定義した前のセクションのコードスニペットに対応しています。

ここに強調する価値のあるいくつかのポイントがあります:

  • TFFは実行時にPythonを使用しないため、モデルが使用するすべての状態をTensorFlow変数としてキャプチャする必要があります(モバイルデバイスにデプロイできるようにコードを記述する必要があることに注意してください。詳細については、カスタムアルゴリズムのチュートリアルを参照してください)。理由の解説)。
  • 一般に、TFFは強く型付けされた環境であり、すべてのコンポーネントの型シグネチャを決定する必要があるため、モデルは受け入れるデータの形式( input_spec )を記述する必要があります。モデルの入力の形式を宣言することは、その重要な部分です。
  • 技術的には必須ではありませんが、すべてのTensorFlowロジック(フォワードパス、メトリック計算など)をtf.functionとしてラップすることをお勧めします。これにより、TensorFlowをシリアル化できるようになり、明示的な制御依存関係が不要になります。

Federated SGDのような評価やアルゴリズムには、上記で十分です。ただし、Federated Averagingの場合、モデルが各バッチでローカルにトレーニングする方法を指定する必要があります。 Federated Averagingアルゴリズムを構築するときに、ローカルオプティマイザーを指定します。

新しいモデルを使用したフェデレーショントレーニングのシミュレーション

上記のすべてが整った状態で、プロセスの残りの部分はすでに見たもののようになります-モデルコンストラクターを新しいモデルクラスのコンストラクターに置き換え、作成した反復プロセスで2つのフェデレーション計算を使用して循環しますトレーニングラウンド。

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.1527398), ('accuracy', 0.12469136)]))])

for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.941014), ('accuracy', 0.14218107)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.9052832), ('accuracy', 0.14444445)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.7491086), ('accuracy', 0.17962962)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.5129666), ('accuracy', 0.19526748)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4175923), ('accuracy', 0.23600823)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4273515), ('accuracy', 0.24176955)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.2426176), ('accuracy', 0.2802469)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1567981), ('accuracy', 0.295679)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1092515), ('accuracy', 0.30843621)]))])

TensorBoard内でこれらのメトリックを確認するには、上記の「TensorBoardでのモデルメトリックの表示」の手順を参照してください。

評価

これまでのすべての実験では、フェデレーショントレーニングメトリックのみが提示されました。これは、ラウンド内のすべてのクライアントにわたってトレーニングされたデータのすべてのバッチの平均メトリックです。これにより、特に単純化のために各ラウンドで同じクライアントセットを使用したため、過剰適合に関する通常の懸念が生じますが、FederatedAveragingアルゴリズムに固有のトレーニングメトリックには過剰適合の概念が追加されています。これは、各クライアントに単一のデータバッチがあると想像すると最も簡単に確認でき、そのバッチで多くの反復(エポック)をトレーニングします。この場合、ローカルモデルはその1つのバッチにすばやく正確に適合するため、平均するローカル精度メトリックは1.0に近づきます。したがって、これらのトレーニングメトリックは、トレーニングが進行していることを示すものと見なすことができますが、それ以上ではありません。

フェデレーションデータの評価を実行するには、 tff.learning.build_federated_evaluation関数を使用し、モデルコンストラクターを引数として渡すことで、この目的のために設計された別のフェデレーション計算を構築できます。 MnistTrainableModelを使用したFederatedAveragingとは異なり、 MnistModelを渡すだけで十分であることに注意してください。評価は最急降下法を実行せず、オプティマイザーを作成する必要はありません。

実験と研究のために、一元化されたテストデータセットが利用可能な場合、テキスト生成のフェデレーションラーニングは別の評価オプションを示します。フェデレーションラーニングからトレーニングされた重みを取得し、それらを標準のKerasモデルに適用してから、単にtf.keras.models.Model.evaluate()呼び出しtf.keras.models.Model.evaluate()一元化されたデータセットでtf.keras.models.Model.evaluate()

evaluation = tff.learning.build_federated_evaluation(MnistModel)

評価関数の抽象型シグネチャは、次のように検査できます。

str(evaluation.type_signature)
'(<server_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <num_examples=float32@SERVER,loss=float32@SERVER,accuracy=float32@SERVER>)'

この時点で詳細を気にする必要はありませんtff.templates.IterativeProcess.nextに似ていtff.templates.IterativeProcess.nextが、2つの重要な違いがある次の一般的な形式をとることに注意してください。まず、評価によってモデルや状態の他の側面が変更されないため、サーバーの状態は返されません。ステートレスと考えることができます。第2に、評価にはモデルのみが必要であり、オプティマイザー変数など、トレーニングに関連する可能性のあるサーバー状態の他の部分は必要ありません。

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

トレーニング中に到達した最新の状態の評価を呼び出しましょう。サーバーの状態から最新のトレーニング済みモデルを抽出するには、次のように.modelメンバーにアクセスするだけです。

train_metrics = evaluation(state.model, federated_train_data)

これが私たちが得るものです。上記のトレーニングの最後のラウンドで報告された数値よりもわずかに良く見えることに注意してください。慣例により、反復トレーニングプロセスによって報告されるトレーニングメトリックは、通常、トレーニングラウンドの開始時のモデルのパフォーマンスを反映するため、評価メトリックは常に一歩先を行くことになります。

str(train_metrics)
'<num_examples=4860.0,loss=1.7142657041549683,accuracy=0.38683128356933594>'

それでは、連合データのテストサンプルをコンパイルし、テストデータに対して評価を再実行してみましょう。データは、実際のユーザーの同じサンプルから取得されますが、個別の差し出されたデータセットから取得されます。

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
'<num_examples=580.0,loss=1.861915111541748,accuracy=0.3362068831920624>'

これでチュートリアルは終了です。パラメータ(バッチサイズ、ユーザー数、エポック、学習率など)を試して、上記のコードを変更して、各ラウンドのユーザーのランダムサンプルのトレーニングをシミュレートし、他のチュートリアルを探索することをお勧めします。私たちは開発しました。