ヘルプKaggleにTensorFlowグレートバリアリーフを保護チャレンジに参加

federated_selectとスパース集計によるクライアント効率の高い大規模モデルのフェデレーション学習

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示 ノートブックをダウンロード

TFFは、各クライアントデバイスのみがダウンロード非常に大規模なモデルを訓練するために使用し、使用して、モデルの小さな部分を更新することができますどのようにこのチュートリアルを示しtff.federated_selectとスパース凝集を。このチュートリアルではかなり自己完結型ですが、 tff.federated_selectチュートリアルカスタムFLアルゴリズムは、チュートリアル、ここで使用される技術のいくつかに良い紹介を提供しています。

具体的には、このチュートリアルでは、マルチラベル分類のロジスティック回帰を検討し、単語の袋の特徴表現に基づいて、どの「タグ」がテキスト文字列に関連付けられているかを予測します。重要なことに、通信およびクライアント側の計算コストは、固定された定数(によって制御さMAX_TOKENS_SELECTED_PER_CLIENT )、及び実用的な設定で非常に大きくなる可能性が全体の語彙のサイズでスケーリングしません

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import itertools
import numpy as np

from typing import Callable, List, Tuple

import tensorflow as tf
import tensorflow_federated as tff
tff.backends.native.set_local_python_execution_context()

各クライアントはなりfederated_selectせいぜいこの多くのユニークなトークンのためのモデルの重みの行を。これは、クライアントのローカルモデルのサイズとサーバの量を上位境界- >クライアント( federated_select )及びクライアント- >サーバ(federated_aggregate )通信を行います。

このチュートリアルは、これを1に設定した場合(各クライアントからのすべてのトークンが選択されていないことを確認)または大きい値に設定した場合でも正しく実行されますが、モデルの収束が影響を受ける可能性があります。

MAX_TOKENS_SELECTED_PER_CLIENT = 6

また、さまざまなタイプの定数をいくつか定義します。このコラボのために、トークンは、データセットを解析した後、特定の単語のための整数の識別子です。

# There are some constraints on types
# here that will require some explicit type conversions:
#    - `tff.federated_select` requires int32
#    - `tf.SparseTensor` requires int64 indices.
TOKEN_DTYPE = tf.int64
SELECT_KEY_DTYPE = tf.int32

# Type for counts of token occurences.
TOKEN_COUNT_DTYPE = tf.int32

# A sparse feature vector can be thought of as a map
# from TOKEN_DTYPE to FEATURE_DTYPE. 
# Our features are {0, 1} indicators, so we could potentially
# use tf.int8 as an optimization.
FEATURE_DTYPE = tf.int32

問題の設定:データセットとモデル

このチュートリアルでは、簡単に実験できるように小さなおもちゃのデータセットを作成します。しかし、データセットのフォーマットと互換性のあるフェデレーションStackOverflowの、及び前処理モデルアーキテクチャはのStackOverflowのタグ予測問題から採用される適応フェデレーション最適

データセットの解析と前処理

NUM_OOV_BUCKETS = 1

BatchType = collections.namedtuple('BatchType', ['tokens', 'tags'])

def build_to_ids_fn(word_vocab: List[str],
                    tag_vocab: List[str]) -> Callable[[tf.Tensor], tf.Tensor]:
  """Constructs a function mapping examples to sequences of token indices."""
  word_table_values = np.arange(len(word_vocab), dtype=np.int64)
  word_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  tag_table_values = np.arange(len(tag_vocab), dtype=np.int64)
  tag_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  def to_ids(example):
    """Converts a Stack Overflow example to a bag-of-words/tags format."""
    sentence = tf.strings.join([example['tokens'], example['title']],
                               separator=' ')

    # We represent that label (output tags) densely.
    raw_tags = example['tags']
    tags = tf.strings.split(raw_tags, sep='|')
    tags = tag_table.lookup(tags)
    tags, _ = tf.unique(tags)
    tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS)
    tags = tf.reduce_max(tags, axis=0)

    # We represent the features as a SparseTensor of {0, 1}s.
    words = tf.strings.split(sentence)
    tokens = word_table.lookup(words)
    tokens, _ = tf.unique(tokens)
    # Note:  We could choose to use the word counts as the feature vector
    # instead of just {0, 1} values (see tf.unique_with_counts).
    tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1))
    tokens_st = tf.SparseTensor(
        tokens,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,))
    tokens_st = tf.sparse.reorder(tokens_st)

    return BatchType(tokens_st, tags)

  return to_ids
def build_preprocess_fn(word_vocab, tag_vocab):

  @tf.function
  def preprocess_fn(dataset):
    to_ids = build_to_ids_fn(word_vocab, tag_vocab)
    # We *don't* shuffle in order to make this colab deterministic for
    # easier testing and reproducibility.
    # But real-world training should use `.shuffle()`.
    return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return preprocess_fn

小さなおもちゃのデータセット

12語と3クライアントのグローバル語彙で小さなおもちゃのデータセットを構築します。この小さな例では、エッジケースをテストするのに有用である(例えば、我々は、より少ない2つのクライアントがあるMAX_TOKENS_SELECTED_PER_CLIENT = 6 、コードを開発異なるトークン、および1つ以上を有します)。

ただし、このアプローチの実際のユースケースは、数千万以上のグローバルな語彙であり、各クライアントに数千の異なるトークンが表示される可能性があります。データの形式が同じであるため、より現実的なテストベッド上の問題への拡張は、例えばtff.simulation.datasets.stackoverflow.load_data()データセット、簡単です。

まず、単語とタグの語彙を定義します。

# Features
FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi']
VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas']
FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon']
WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS

# Labels
TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']

ここで、小さなローカルデータセットを使用して3つのクライアントを作成します。このチュートリアルをcolabで実行している場合は、「タブのミラーセル」機能を使用して、このセルとその出力を固定し、以下で開発する関数の出力を解釈/確認すると便利な場合があります。

preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB)


def make_dataset(raw):
  d = tf.data.Dataset.from_tensor_slices(
      # Matches the StackOverflow formatting
      collections.OrderedDict(
          tokens=tf.constant([t[0] for t in raw]),
          tags=tf.constant([t[1] for t in raw]),
          title=['' for _ in raw]))
  d = preprocess_fn(d)
  return d


# 4 distinct tokens
CLIENT1_DATASET = make_dataset([
    ('apple orange apple orange', 'FRUIT'),
    ('carrot trout', 'VEGETABLE|FISH'),
    ('orange apple', 'FRUIT'),
    ('orange', 'ORANGE|CITRUS')  # 2 OOV tag
])

# 6 distinct tokens
CLIENT2_DATASET = make_dataset([
    ('pear cod', 'FRUIT|FISH'),
    ('arugula peas', 'VEGETABLE'),
    ('kiwi pear', 'FRUIT'),
    ('sturgeon', 'FISH'),  # OOV word
    ('sturgeon bass', 'FISH')  # 2 OOV words
])

# A client with all possible words & tags (13 distinct tokens).
# With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model
# slices for all tokens that occur on this client.
CLIENT3_DATASET = make_dataset([
    (' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)),
    # Mathe the OOV token and 'salmon' occur in the largest number
    # of examples on this client:
    ('salmon oovword', 'FISH|OOVTAG')
])

print('Word vocab')
for i, word in enumerate(WORD_VOCAB):
  print(f'{i:2d} {word}')

print('\nTag vocab')
for i, tag in enumerate(TAG_VOCAB):
  print(f'{i:2d} {tag}')
Word vocab
 0 apple
 1 orange
 2 pear
 3 kiwi
 4 carrot
 5 broccoli
 6 arugula
 7 peas
 8 trout
 9 tuna
10 cod
11 salmon

Tag vocab
 0 FRUIT
 1 VEGETABLE
 2 FISH

入力特徴(トークン/単語)とラベル(ポストタグ)の生の数の定数を定義します。私たちの実際の入力/出力スペースがあるNUM_OOV_BUCKETS = 1 、我々はOOVトークン/タグを追加しているため、より大きな。

NUM_WORDS = len(WORD_VOCAB) 
NUM_TAGS = len(TAG_VOCAB)

WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS
TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS

データセットのバッチバージョンと個々のバッチを作成します。これは、コードのテストに役立ちます。

batched_dataset1 = CLIENT1_DATASET.batch(2)
batched_dataset2 = CLIENT2_DATASET.batch(3)
batched_dataset3 = CLIENT3_DATASET.batch(2)

batch1 = next(iter(batched_dataset1))
batch2 = next(iter(batched_dataset2))
batch3 = next(iter(batched_dataset3))

スパース入力でモデルを定義する

タグごとに単純な独立ロジスティック回帰モデルを使用します。

def create_logistic_model(word_vocab_size: int, vocab_tags_size: int):

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True),
      tf.keras.layers.Dense(
          vocab_tags_size,
          activation='sigmoid',
          kernel_initializer=tf.keras.initializers.zeros,
          # For simplicity, don't use a bias vector; this means the model
          # is a single tensor, and we only need sparse aggregation of
          # the per-token slices of the model. Generalizing to also handle
          # other model weights that are fully updated 
          # (non-dense broadcast and aggregate) would be a good exercise.
          use_bias=False),
  ])

  return model

まず、予測を行って、それが機能することを確認しましょう。

model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
p = model.predict(batch1.tokens)
print(p)
[[0.5 0.5 0.5 0.5]
 [0.5 0.5 0.5 0.5]]

そして、いくつかの簡単な集中トレーニング:

model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001),
              loss=tf.keras.losses.BinaryCrossentropy())
model.train_on_batch(batch1.tokens, batch1.tags)

フェデレーション計算のビルディングブロック

私たちは、単純なバージョンを実装します連合平均化各デバイスは、唯一のモデルの関連するサブセットをダウンロードすることを主な違いとアルゴリズムを、そして唯一のそのサブセットへの更新を貢献しています。

私たちは、使用Mための速記としてMAX_TOKENS_SELECTED_PER_CLIENT 。大まかに言うと、1ラウンドのトレーニングには次の手順が含まれます。

  1. 参加している各クライアントは、ローカルデータセットをスキャンし、入力文字列を解析して、正しいトークン(intインデックス)にマッピングします。これは、グローバル(大)の辞書へのアクセスを(これは潜在的に使用して回避することができる必要がフィーチャのハッシュ技術を)。次に、各トークンが発生する回数をまばらにカウントします。場合はUユニークなトークンは、デバイス上で発生する、我々は選択しnum_actual_tokens = min(U, M)電車に最も頻繁にトークンを。

  2. クライアントが使用federated_selectためのモデル係数取得するためにnum_actual_tokensサーバーからトークンを選択します。各モデルのスライス形状のテンソルである(TAG_VOCAB_SIZE, )クライアントへ送信される全データサイズの最大であるので、 TAG_VOCAB_SIZE * M (下記の注を参照)。

  3. クライアントがマッピング構築global_token -> local_tokenローカルトークン(int型のインデックス)は、選択したトークンのリストでグローバルトークンの指数です。

  4. クライアントは、せいぜいの係数を有し、グローバルモデルの「小さい」バージョンを使用M範囲から、トークンを[0, num_actual_tokens) global -> localマッピングは、選択したモデルスライスからこのモデルの密なパラメータを初期化するために使用されます。

  5. クライアントは、前処理とデータにSGDを使用してローカルモデルを訓練global -> localマッピング。

  6. クライアントは、ローカルに彼らのモデルのパラメータを回しIndexedSlices使用してアップデートlocal -> globalインデックスに行をマッピングします。サーバーは、スパースサム集計を使用してこれらの更新を集計します。

  7. サーバーは、上記の集計の(密な)結果を取得し、それを参加しているクライアントの数で除算し、結果の平均更新をグローバルモデルに適用します。

このセクションでは、最終で結合される、これらのステップのためのビルディングブロックを構築federated_computation 1つのトレーニングラウンドの完全なロジックをキャプチャします。

クライアントのトークンをカウントし、モデルにスライスを決めるfederated_select

各デバイスは、モデルのどの「スライス」がローカルトレーニングデータセットに関連しているかを判断する必要があります。私たちの問題では、クライアントトレーニングデータセットの各トークンを含む例の数を(まばらに!)数えることでこれを行います。

@tf.function
def token_count_fn(token_counts, batch):
  """Adds counts from `batch` to the running `token_counts` sum."""
  # Sum across the batch dimension.
  flat_tokens = tf.sparse.reduce_sum(
      batch.tokens, axis=0, output_is_sparse=True)
  flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE)
  return tf.sparse.add(token_counts, flat_tokens)
# Simple tests
# Create the initial zero token counts using empty tensors.
initial_token_counts = tf.SparseTensor(
    indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE),
    values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE),
    dense_shape=(WORD_VOCAB_SIZE,))

client_token_counts = batched_dataset1.reduce(initial_token_counts,
                                              token_count_fn)
tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy()
print('tokens:', tokens)
np.testing.assert_array_equal(tokens, [0, 1, 4, 8])
# The count is the number of *examples* in which the token/word
# occurs, not the total number of occurences, since we still featurize
# multiple occurences in the same example as a "1".
counts = client_token_counts.values.numpy()
print('counts:', counts)
np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8]
counts: [2 3 1 1]

私たちは、に対応するモデルパラメータを選択しますMAX_TOKENS_SELECTED_PER_CLIENT最も頻繁にデバイス上のトークンを破壊に対する。この多くのトークンよりも少ないが、デバイス上で発生した場合、我々はパッドリストが使用可能にしfederated_select

トークンをランダムに選択するなど、他の戦略の方がおそらく優れていることに注意してください(おそらくそれらの発生確率に基づいて)。これにより、モデルのすべてのスライス(クライアントがデータを持っている)が更新される可能性があります。

@tf.function
def keys_for_client(client_dataset, max_tokens_per_client):
  """Computes a set of max_tokens_per_client keys."""
  initial_token_counts = tf.SparseTensor(
      indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE),
      values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE),
      dense_shape=(WORD_VOCAB_SIZE,))
  client_token_counts = client_dataset.reduce(initial_token_counts,
                                              token_count_fn)
  # Find the most-frequently occuring tokens
  tokens = tf.reshape(client_token_counts.indices, shape=(-1,))
  counts = client_token_counts.values
  perm = tf.argsort(counts, direction='DESCENDING')
  tokens = tf.gather(tokens, perm)
  counts = tf.gather(counts, perm)
  num_raw_tokens = tf.shape(tokens)[0]
  actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens)
  selected_tokens = tokens[:actual_num_tokens]
  paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]]
  padded_tokens = tf.pad(selected_tokens, paddings=paddings)
  # Make sure the type is statically determined
  padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,))

  # We will pass these tokens as keys into `federated_select`, which
  # requires SELECT_KEY_DTYPE=tf.int32 keys.
  padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE)
  return padded_tokens, actual_num_tokens
# Simple test

# Case 1: actual_num_tokens > max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3)
assert tf.size(selected_tokens) == 3
assert actual_num_tokens == 3

# Case 2: actual_num_tokens < max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10)
assert tf.size(selected_tokens) == 10
assert actual_num_tokens == 4

グローバルトークンをローカルトークンにマップする

上記の選択は私たちの範囲内のトークンの稠密集合与え[0, actual_num_tokens)我々は、オンデバイスモデルに使用されます。しかし、私たちが読んでデータセットがはるかに大きいグローバルな語彙の範囲からトークンを持っている[0, WORD_VOCAB_SIZE)

したがって、グローバルトークンを対応するローカルトークンにマップする必要があります。ローカルトークンIDは、単純にインデックスによって与えられるselected_tokens前のステップで計算されたテンソル。

@tf.function
def map_to_local_token_ids(client_data, client_keys):
  global_to_local = tf.lookup.StaticHashTable(
      # Note int32 -> int64 maps are not supported
      tf.lookup.KeyValueTensorInitializer(
          keys=tf.cast(client_keys, dtype=TOKEN_DTYPE),
          # Note we need to use tf.shape, not the static 
          # shape client_keys.shape[0]
          values=tf.range(0, limit=tf.shape(client_keys)[0],
                          dtype=TOKEN_DTYPE)),
      # We use -1 for tokens that were not selected, which can occur for clients
      # with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
      # We will simply remove these invalid indices from the batch below.
      default_value=-1)

  def to_local_ids(sparse_tokens):
    indices_t = tf.transpose(sparse_tokens.indices)
    batch_indices = indices_t[0]  # First column
    tokens = indices_t[1]  # Second column
    tokens = tf.map_fn(
        lambda global_token_id: global_to_local.lookup(global_token_id), tokens)
    # Remove tokens that aren't actually available (looked up as -1):
    available_tokens = tokens >= 0
    tokens = tokens[available_tokens]
    batch_indices = batch_indices[available_tokens]

    updated_indices = tf.transpose(
        tf.concat([[batch_indices], [tokens]], axis=0))
    st = tf.sparse.SparseTensor(
        updated_indices,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=sparse_tokens.dense_shape)
    st = tf.sparse.reorder(st)
    return st

  return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))
# Simple test
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT)
client_keys = client_keys[:actual_num_tokens]

d = map_to_local_token_ids(batched_dataset3, client_keys)
batch  = next(iter(d))
all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1)
# Confirm we have local indices in the range [0, MAX):
assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT
assert tf.math.reduce_max(all_tokens) >= 0

各クライアントでローカル(サブ)モデルをトレーニングする

federated_select 、選択したスライスを返しますtf.data.Dataset選択キーと同じ順序で。したがって、最初に、そのようなデータセットを取得し、それをクライアントモデルのモデルの重みとして使用できる単一の密なテンソルに変換する効用関数を定義します。

@tf.function
def slices_dataset_to_tensor(slices_dataset):
  """Convert a dataset of slices to a tensor."""
  # Use batching to gather all of the slices into a single tensor.
  d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT,
                           drop_remainder=False)
  iter_d = iter(d)
  tensor = next(iter_d)
  # Make sure we have consumed everything
  opt = iter_d.get_next_as_optional()
  tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY')
  return tensor
# Simple test
weights = np.random.random(
    size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32)
model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights)
weights2 = slices_dataset_to_tensor(model_slices_as_dataset)
np.testing.assert_array_equal(weights, weights2)

これで、各クライアントで実行される単純なローカルトレーニングループを定義するために必要なすべてのコンポーネントができました。

@tf.function
def client_train_fn(model, client_optimizer,
                    model_slices_as_dataset, client_data,
                    client_keys, actual_num_tokens):

  initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset)
  assert len(model.trainable_variables) == 1
  model.trainable_variables[0].assign(initial_model_weights)

  # Only keep the "real" (unpadded) keys.
  client_keys = client_keys[:actual_num_tokens]

  client_data = map_to_local_token_ids(client_data, client_keys)

  loss_fn = tf.keras.losses.BinaryCrossentropy()
  for features, labels in client_data:
    with tf.GradientTape() as tape:
      predictions = model(features)
      loss = loss_fn(labels, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    client_optimizer.apply_gradients(zip(grads, model.trainable_variables))

  model_weights_delta = model.trainable_weights[0] - initial_model_weights
  model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0], 
                           size=[actual_num_tokens, -1])
  return client_keys, model_weights_delta
# Simple test
# Note if you execute this cell a second time, you need to also re-execute
# the preceeding cell to avoid "tf.function-decorated function tried to 
# create variables on non-first call" errors.
on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                        TAG_VOCAB_SIZE)
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT)

model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(
    np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE),
             dtype=np.float32))

keys, delta = client_train_fn(
    on_device_model,
    client_optimizer,
    model_slices_as_dataset,
    client_data=batched_dataset3,
    client_keys=client_keys,
    actual_num_tokens=actual_num_tokens)

print(delta)

IndexedSlicesの集計

私たちは、使用tff.federated_aggregateのための連合スパース合計を構築するIndexedSlices 。この単純な実装ではないという制約があるdense_shape事前に静的に知られているが。注この合計値は、クライアントがあるという意味で、唯一の半希薄であること- >サーバ通信がまばらであるが、サーバはで和の緻密な表現維持accumulatemerge 、この緻密な表現を出力します。

def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape):
  """
  Sumes IndexedSlices@CLIENTS to a dense @SERVER Tensor.

  Intermediate aggregation is performed by converting to a dense representation,
  which may not be suitable for all applications.

  Args:
    slice_indices: An IndexedSlices.indices tensor @CLIENTS.
    slice_values: An IndexedSlices.values tensor @CLIENTS.
    dense_shape: A statically known dense shape.

  Returns:
    A dense tensor placed @SERVER representing the sum of the client's
    IndexedSclies.
  """
  slices_dtype = slice_values.type_signature.member.dtype
  zero = tff.tf_computation(
      lambda: tf.zeros(dense_shape, dtype=slices_dtype))()

  @tf.function
  def accumulate_slices(dense, client_value):
    indices, slices = client_value
    # There is no built-in way to add `IndexedSlices`, but 
    # tf.convert_to_tensor is a quick way to convert to a dense representation
    # so we can add them.
    return dense + tf.convert_to_tensor(
        tf.IndexedSlices(slices, indices, dense_shape))


  return tff.federated_aggregate(
      (slice_indices, slice_values),
      zero=zero,
      accumulate=tff.tf_computation(accumulate_slices),
      merge=tff.tf_computation(lambda d1, d2: tf.add(d1, d2, name='merge')),
      report=tff.tf_computation(lambda d: d))

最小限の構築federated_computationテストとして

dense_shape = (6, 2)
indices_type = tff.TensorType(tf.int64, (None,))
values_type = tff.TensorType(tf.float32, (None, 2))
client_slice_type = tff.type_at_clients(
    (indices_type, values_type))

@tff.federated_computation(client_slice_type)
def test_sum_indexed_slices(indices_values_at_client):
  indices, values = indices_values_at_client
  return federated_indexed_slices_sum(indices, values, dense_shape)

print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices(
    values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]],
                    dtype=np.float32),
    indices=[2, 0, 1, 5],
    dense_shape=dense_shape)
y = tf.IndexedSlices(
    values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32),
    indices=[1, 3],
    dense_shape=dense_shape)

# Sum one.
result = test_sum_indexed_slices([(x.indices, x.values)])
np.testing.assert_array_equal(tf.convert_to_tensor(x), result)

# Sum two.
expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]]
result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)])
np.testing.assert_array_almost_equal(expected, result)

すべて一緒にそれを置くfederated_computation

私たちは、今に部品を一緒にバインドするためにTFFを使用していますtff.federated_computation

DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
client_data_type = tff.SequenceType(batched_dataset1.element_spec)
model_type = tff.TensorType(tf.float32, shape=DENSE_MODEL_SHAPE)

Federated Averagingに基づく基本的なサーバートレーニング機能を使用し、サーバー学習率1.0で更新を適用します。モデルの特定のスライスが特定のラウンドでどのクライアントによってもトレーニングされていない場合、その係数がゼロになる可能性があるため、クライアント提供のモデルを単純に平均化するのではなく、モデルに更新(デルタ)を適用することが重要です。アウト。

@tff.tf_computation
def server_update(current_model_weights, update_sum, num_clients):
  average_update = update_sum / num_clients
  return current_model_weights + average_update

私たちはカップルより多く必要tff.tf_computationコンポーネントを:

# Function to select slices from the model weights in federated_select:
select_fn = tff.tf_computation(
    lambda model_weights, index: tf.gather(model_weights, index))


# We need to wrap `client_train_fn` as a `tff.tf_computation`, making
# sure we do any operations that might construct `tf.Variable`s outside
# of the `tf.function` we are wrapping.
@tff.tf_computation
def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys,
                        actual_num_tokens):
  # Note this is amaller than the global model, using
  # MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE.
  # W7e would like a model of size `actual_num_tokens`, but we
  # can't build the model dynamically, so we will slice off the padded
  # weights at the end.
  client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                       TAG_VOCAB_SIZE)
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  return client_train_fn(client_model, client_optimizer,
                         model_slices_as_dataset, client_data, client_keys,
                         actual_num_tokens)

@tff.tf_computation
def keys_for_client_tff(client_data):
  return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)

これで、すべてのピースをまとめる準備ができました。

@tff.federated_computation(
    tff.type_at_server(model_type), tff.type_at_clients(client_data_type))
def sparse_model_update(server_model, client_data):
  max_tokens = tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER)
  keys_at_clients, actual_num_tokens = tff.federated_map(
      keys_for_client_tff, client_data)

  model_slices = tff.federated_select(keys_at_clients, max_tokens, server_model,
                                      select_fn)

  update_keys, update_slices = tff.federated_map(
      client_train_fn_tff,
      (model_slices, client_data, keys_at_clients, actual_num_tokens))

  dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices,
                                                  DENSE_MODEL_SHAPE)
  num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS))

  updated_server_model = tff.federated_map(
      server_update, (server_model, dense_update_sum, num_clients))

  return updated_server_model


print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)

モデルを鍛えよう!

トレーニング機能ができたので、試してみましょう。

server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
server_model.compile(  # Compile to make evaluation easy.
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0),  # Unused
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[ 
      tf.keras.metrics.Precision(name='precision'),
      tf.keras.metrics.AUC(name='auc'),
      tf.keras.metrics.Recall(top_k=2, name='recall_at_2'),
  ])

def evaluate(model, dataset, name):
  metrics = model.evaluate(dataset, verbose=0)
  metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in 
                          (zip(server_model.metrics_names, metrics))])
  print(f'{name}: {metrics_str}')
print('Before training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')

model_weights = server_model.trainable_weights[0]

client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3]
for _ in range(10):  # Run 10 rounds of FedAvg
  # We train on 1, 2, or 3 clients per round, selecting
  # randomly.
  cohort_size = np.random.randint(1, 4)
  clients = np.random.choice([0, 1, 2], cohort_size, replace=False)
  print('Training on clients', clients)
  model_weights = sparse_model_update(
      model_weights, [client_datasets[i] for i in clients])
server_model.set_weights([model_weights])

print('After training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')
Before training
Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60
Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50
Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40
Training on clients [0 1]
Training on clients [0 2 1]
Training on clients [2 0]
Training on clients [1 0 2]
Training on clients [2]
Training on clients [2 0]
Training on clients [1 2 0]
Training on clients [0]
Training on clients [2]
Training on clients [1 2]
After training
Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80
Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00
Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80