مشتری بزرگ کارآمد مدل یادگیری فدرال از طریق federated_select و تجمع پراکنده

مشاهده در TensorFlow.org در Google Colab اجرا کنید مشاهده منبع در GitHub دانلود دفترچه یادداشت

این نشان می دهد آموزش نحوه TFF می توان مورد استفاده برای آموزش یک مدل بسیار بزرگ که در آن هر دستگاه سرویس گیرنده تنها دانلود و به روز رسانی یک بخش کوچک از این مدل، با استفاده از tff.federated_select و تجمع پراکنده. در حالی که این آموزش است که نسبتا خودکفا، در tff.federated_select آموزش و سفارشی FL الگوریتم آموزش ارائه معرفی خوبی برای برخی از تکنیک های مورد استفاده در اینجا.

به طور خاص ، در این آموزش ما رگرسیون لجستیک را برای طبقه بندی چند برچسب در نظر می گیریم و پیش بینی می کنیم که کدام "برچسب ها" با یک رشته متن بر اساس نمایش ویژگی کیسه کلمات مرتبط هستند. نکته مهم، هزینه های ارتباطی و محاسبات سمت سرویس گیرنده توسط یک ثابت ثابت (کنترل MAX_TOKENS_SELECTED_PER_CLIENTو با اندازه واژگان کلی، که می تواند در تنظیمات عملی بسیار بزرگ مقیاس نیست.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import itertools
import numpy as np

from typing import Callable, List, Tuple

import tensorflow as tf
import tensorflow_federated as tff

هر مشتری خواهد federated_select ردیف وزن مدل برای حداکثر این بسیاری از نشانه منحصر به فرد. این بالا مرزهای اندازه مدل محلی مشتری و میزان سرور -> مشتری ( federated_select ) و سرویس گیرنده -> سرور (federated_aggregate ) ارتباطات انجام شده است.

این آموزش هنوز باید به درستی اجرا شود ، حتی اگر این عدد را 1 قرار دهید (اطمینان حاصل کنید که همه نشانه های هر کلاینت انتخاب نشده اند) یا مقدار زیادی باشد ، اگرچه ممکن است همگرایی مدل ایجاد شود.

MAX_TOKENS_SELECTED_PER_CLIENT = 6

ما همچنین چند ثابت برای انواع مختلف تعریف می کنیم. برای این COLAB، یک نشانه یک شناسه عدد صحیح برای یک کلمه خاص پس از تجزیه مجموعه داده است.

# There are some constraints on types
# here that will require some explicit type conversions:
#    - `tff.federated_select` requires int32
#    - `tf.SparseTensor` requires int64 indices.
TOKEN_DTYPE = tf.int64
SELECT_KEY_DTYPE = tf.int32

# Type for counts of token occurences.
TOKEN_COUNT_DTYPE = tf.int32

# A sparse feature vector can be thought of as a map
# from TOKEN_DTYPE to FEATURE_DTYPE. 
# Our features are {0, 1} indicators, so we could potentially
# use tf.int8 as an optimization.
FEATURE_DTYPE = tf.int32

راه اندازی مشکل: مجموعه داده و مدل

ما در این آموزش یک مجموعه داده کوچک اسباب بازی برای آزمایش آسان ایجاد می کنیم. با این حال، فرمت از مجموعه داده های سازگار با فدرال استک اورفلو و پیش پردازش و معماری مدل از استک اورفلو مشکل پیش بینی تگ اتخاذ تطبیقی فدرال بهینه سازی .

تجزیه و تحلیل داده ها و پیش پردازش آنها

NUM_OOV_BUCKETS = 1

BatchType = collections.namedtuple('BatchType', ['tokens', 'tags'])

def build_to_ids_fn(word_vocab: List[str],
                    tag_vocab: List[str]) -> Callable[[tf.Tensor], tf.Tensor]:
  """Constructs a function mapping examples to sequences of token indices."""
  word_table_values = np.arange(len(word_vocab), dtype=np.int64)
  word_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  tag_table_values = np.arange(len(tag_vocab), dtype=np.int64)
  tag_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  def to_ids(example):
    """Converts a Stack Overflow example to a bag-of-words/tags format."""
    sentence = tf.strings.join([example['tokens'], example['title']],
                               separator=' ')

    # We represent that label (output tags) densely.
    raw_tags = example['tags']
    tags = tf.strings.split(raw_tags, sep='|')
    tags = tag_table.lookup(tags)
    tags, _ = tf.unique(tags)
    tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS)
    tags = tf.reduce_max(tags, axis=0)

    # We represent the features as a SparseTensor of {0, 1}s.
    words = tf.strings.split(sentence)
    tokens = word_table.lookup(words)
    tokens, _ = tf.unique(tokens)
    # Note:  We could choose to use the word counts as the feature vector
    # instead of just {0, 1} values (see tf.unique_with_counts).
    tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1))
    tokens_st = tf.SparseTensor(
        tokens,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,))
    tokens_st = tf.sparse.reorder(tokens_st)

    return BatchType(tokens_st, tags)

  return to_ids
def build_preprocess_fn(word_vocab, tag_vocab):

  @tf.function
  def preprocess_fn(dataset):
    to_ids = build_to_ids_fn(word_vocab, tag_vocab)
    # We *don't* shuffle in order to make this colab deterministic for
    # easier testing and reproducibility.
    # But real-world training should use `.shuffle()`.
    return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return preprocess_fn

مجموعه داده اسباب بازی کوچک

ما یک مجموعه داده کوچک اسباب بازی با واژگان جهانی 12 کلمه و 3 مشتری ایجاد می کنیم. این مثال کوچک برای تست موارد لبه مفید است (برای مثال، ما دو مشتریان با کمتر از MAX_TOKENS_SELECTED_PER_CLIENT = 6 نشانه مجزا، و یکی با بیشتر) و در حال توسعه کد.

با این حال ، موارد استفاده از این رویکرد در دنیای واقعی واژگان جهانی 10 میلیون یا بیشتر است ، که شاید هزاران نشانه متمایز در هر مشتری ظاهر می شود. از آنجا که فرمت داده همان است، پسوند به مشکلات بستر آزمایشی واقعی تر، به عنوان مثال tff.simulation.datasets.stackoverflow.load_data() مجموعه داده، باید ساده باشد.

ابتدا ، واژگان واژه و برچسب خود را تعریف می کنیم.

# Features
FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi']
VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas']
FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon']
WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS

# Labels
TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']

در حال حاضر ، ما 3 مشتری با مجموعه داده های محلی کوچک ایجاد می کنیم. اگر این آموزش را در colab اجرا می کنید ، ممکن است استفاده از ویژگی "سلول آینه در برگه" برای پین کردن این سلول و خروجی آن به منظور تفسیر/بررسی خروجی توابع ایجاد شده در زیر مفید باشد.

preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB)


def make_dataset(raw):
  d = tf.data.Dataset.from_tensor_slices(
      # Matches the StackOverflow formatting
      collections.OrderedDict(
          tokens=tf.constant([t[0] for t in raw]),
          tags=tf.constant([t[1] for t in raw]),
          title=['' for _ in raw]))
  d = preprocess_fn(d)
  return d


# 4 distinct tokens
CLIENT1_DATASET = make_dataset([
    ('apple orange apple orange', 'FRUIT'),
    ('carrot trout', 'VEGETABLE|FISH'),
    ('orange apple', 'FRUIT'),
    ('orange', 'ORANGE|CITRUS')  # 2 OOV tag
])

# 6 distinct tokens
CLIENT2_DATASET = make_dataset([
    ('pear cod', 'FRUIT|FISH'),
    ('arugula peas', 'VEGETABLE'),
    ('kiwi pear', 'FRUIT'),
    ('sturgeon', 'FISH'),  # OOV word
    ('sturgeon bass', 'FISH')  # 2 OOV words
])

# A client with all possible words & tags (13 distinct tokens).
# With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model
# slices for all tokens that occur on this client.
CLIENT3_DATASET = make_dataset([
    (' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)),
    # Mathe the OOV token and 'salmon' occur in the largest number
    # of examples on this client:
    ('salmon oovword', 'FISH|OOVTAG')
])

print('Word vocab')
for i, word in enumerate(WORD_VOCAB):
  print(f'{i:2d} {word}')

print('\nTag vocab')
for i, tag in enumerate(TAG_VOCAB):
  print(f'{i:2d} {tag}')
Word vocab
 0 apple
 1 orange
 2 pear
 3 kiwi
 4 carrot
 5 broccoli
 6 arugula
 7 peas
 8 trout
 9 tuna
10 cod
11 salmon

Tag vocab
 0 FRUIT
 1 VEGETABLE
 2 FISH

برای اعداد خام ویژگی های ورودی (نشانه ها/کلمات) و برچسب ها (برچسب های پست) ثابت تعریف کنید. فضاهای ورودی / خروجی واقعی ما هستند NUM_OOV_BUCKETS = 1 بزرگتر از آنجا که ما اضافه کردن یک OOV رمز / برچسب.

NUM_WORDS = len(WORD_VOCAB) 
NUM_TAGS = len(TAG_VOCAB)

WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS
TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS

نسخه های دسته ای مجموعه های داده و دسته های جداگانه ایجاد کنید ، که در آزمایش کد مفید خواهد بود.

batched_dataset1 = CLIENT1_DATASET.batch(2)
batched_dataset2 = CLIENT2_DATASET.batch(3)
batched_dataset3 = CLIENT3_DATASET.batch(2)

batch1 = next(iter(batched_dataset1))
batch2 = next(iter(batched_dataset2))
batch3 = next(iter(batched_dataset3))

یک مدل با ورودی های کم تعریف کنید

ما از یک مدل رگرسیون لجستیک ساده مستقل برای هر برچسب استفاده می کنیم.

def create_logistic_model(word_vocab_size: int, vocab_tags_size: int):

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True),
      tf.keras.layers.Dense(
          vocab_tags_size,
          activation='sigmoid',
          kernel_initializer=tf.keras.initializers.zeros,
          # For simplicity, don't use a bias vector; this means the model
          # is a single tensor, and we only need sparse aggregation of
          # the per-token slices of the model. Generalizing to also handle
          # other model weights that are fully updated 
          # (non-dense broadcast and aggregate) would be a good exercise.
          use_bias=False),
  ])

  return model

اجازه دهید مطمئن شویم که کار می کند ، ابتدا با پیش بینی:

model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
p = model.predict(batch1.tokens)
print(p)
[[0.5 0.5 0.5 0.5]
 [0.5 0.5 0.5 0.5]]

و چند آموزش ساده متمرکز:

model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001),
              loss=tf.keras.losses.BinaryCrossentropy())
model.train_on_batch(batch1.tokens, batch1.tags)

بلوک های سازنده برای محاسبه فدرال

ما یک نسخه ساده از پیاده سازی به طور متوسط فدرال الگوریتم با این تفاوت کلیدی است که هر یک از دستگاه فقط دانلود یک زیر مجموعه مربوط به مدل، و تنها منجر به روز رسانی به که زیر مجموعه.

ما با استفاده از M به عنوان نماد MAX_TOKENS_SELECTED_PER_CLIENT . در سطح بالا ، یک دور آموزش شامل این مراحل است:

  1. هر مشتری شرکت کننده روی مجموعه داده محلی خود اسکن می کند ، رشته های ورودی را تجزیه می کند و آنها را در نشانه های صحیح (نمایه های int) ترسیم می کند. این نیاز به دسترسی به فرهنگ لغت جهانی (بزرگ) (این طور بالقوه می تواند با استفاده از اجتناب شود ویژگی هش کردن تکنیک). سپس تعداد دفعاتی که هر توکن رخ می دهد را پراکنده می شماریم. اگر U نشانه منحصر به فرد بر روی دستگاه رخ می دهد، ما را انتخاب num_actual_tokens = min(U, M) ترین نشانه مکرر به قطار.

  2. مشتریان استفاده federated_select برای بازیابی ضرایب مدل برای num_actual_tokens نشانه انتخاب از سرور. هر تکه مدل یک تانسور شکل است (TAG_VOCAB_SIZE, ) ، به طوری که کل داده های منتقل شده به مشتری است در بسیاری از اندازه TAG_VOCAB_SIZE * M (توجه داشته باشید پایین را ببینید).

  3. مشتریان ساخت یک نقشه برداری global_token -> local_token که در آن رمز محلی (شاخص INT) شاخص از این رمز جهانی در فهرست نشانه انتخاب شده است.

  4. مشتریان استفاده از یک "کوچک" نسخه ای از مدل جهانی است که تنها ضرایب برای حداکثر M نشانه، از محدوده [0, num_actual_tokens) . global -> local نقشه برداری استفاده برای مقدار دهی اولیه پارامترهای متراکم از این مدل از مدل برش انتخاب شده است.

  5. مشتریان آموزش مدل محلی خود را با استفاده SGD بر روی داده های پیش پردازش با global -> local نقشه برداری.

  6. مشتریان تبدیل پارامترهای مدل محلی خود را به IndexedSlices با استفاده از به روز رسانی local -> global نقشه برداری به صفحه اول ردیف. سرور این به روزرسانی ها را با استفاده از جمع جمع پراکنده تجمیع می کند.

  7. سرور نتیجه (متراکم) تجمع فوق را می گیرد ، آن را بر تعداد کلاینت های شرکت کننده تقسیم می کند و میانگین به روز شده حاصله را بر روی مدل جهانی اعمال می کند.

در این بخش ما بلوک های ساختمان برای این مراحل را، که پس از آن در یک فینال خواهد در ترکیب ساخت federated_computation که قطاری از منطق کامل از یک دور آموزش.

تعداد نشانه مشتری و تصمیم می گیرید که برش مدل به federated_select

هر دستگاه باید تصمیم بگیرد که کدام "برش" از مدل با مجموعه داده آموزشی محلی آن مرتبط است. برای مشکل ما ، ما این کار را (به صورت پراکنده) با شمارش تعداد نمونه های موجود در هر نشانه در مجموعه داده های آموزش مشتری انجام می دهیم.

@tf.function
def token_count_fn(token_counts, batch):
  """Adds counts from `batch` to the running `token_counts` sum."""
  # Sum across the batch dimension.
  flat_tokens = tf.sparse.reduce_sum(
      batch.tokens, axis=0, output_is_sparse=True)
  flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE)
  return tf.sparse.add(token_counts, flat_tokens)
# Simple tests
# Create the initial zero token counts using empty tensors.
initial_token_counts = tf.SparseTensor(
    indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE),
    values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE),
    dense_shape=(WORD_VOCAB_SIZE,))

client_token_counts = batched_dataset1.reduce(initial_token_counts,
                                              token_count_fn)
tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy()
print('tokens:', tokens)
np.testing.assert_array_equal(tokens, [0, 1, 4, 8])
# The count is the number of *examples* in which the token/word
# occurs, not the total number of occurences, since we still featurize
# multiple occurences in the same example as a "1".
counts = client_token_counts.values.numpy()
print('counts:', counts)
np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8]
counts: [2 3 1 1]

ما پارامترهای مدل مربوط به را انتخاب کنید MAX_TOKENS_SELECTED_PER_CLIENT اغلب رخ نشانه بر روی دستگاه. اگر کمتر از این بسیاری از نشانه بر روی دستگاه رخ می دهد، ما پد لیست برای فعال کردن استفاده از federated_select .

توجه داشته باشید که استراتژی های دیگر احتمالاً بهتر هستند ، به عنوان مثال ، انتخاب تصادفی توکن ها (شاید بر اساس احتمال وقوع آنها). این امر اطمینان می دهد که همه برش های مدل (که مشتری برای آنها اطلاعات دارد) دارای شانس به روز رسانی هستند.

@tf.function
def keys_for_client(client_dataset, max_tokens_per_client):
  """Computes a set of max_tokens_per_client keys."""
  initial_token_counts = tf.SparseTensor(
      indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE),
      values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE),
      dense_shape=(WORD_VOCAB_SIZE,))
  client_token_counts = client_dataset.reduce(initial_token_counts,
                                              token_count_fn)
  # Find the most-frequently occuring tokens
  tokens = tf.reshape(client_token_counts.indices, shape=(-1,))
  counts = client_token_counts.values
  perm = tf.argsort(counts, direction='DESCENDING')
  tokens = tf.gather(tokens, perm)
  counts = tf.gather(counts, perm)
  num_raw_tokens = tf.shape(tokens)[0]
  actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens)
  selected_tokens = tokens[:actual_num_tokens]
  paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]]
  padded_tokens = tf.pad(selected_tokens, paddings=paddings)
  # Make sure the type is statically determined
  padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,))

  # We will pass these tokens as keys into `federated_select`, which
  # requires SELECT_KEY_DTYPE=tf.int32 keys.
  padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE)
  return padded_tokens, actual_num_tokens
# Simple test

# Case 1: actual_num_tokens > max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3)
assert tf.size(selected_tokens) == 3
assert actual_num_tokens == 3

# Case 2: actual_num_tokens < max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10)
assert tf.size(selected_tokens) == 10
assert actual_num_tokens == 4

نشانه های جهانی را به نشانه های محلی ترسیم کنید

انتخاب بالا به ما می دهد مجموعه ای متراکم از نشانه در محدوده [0, actual_num_tokens) که ما برای مدل بر روی دستگاه استفاده خواهد کرد. با این حال، مجموعه داده ما به عنوان خوانده است نشانه از محدوده بسیار بزرگتر جهانی واژگان [0, WORD_VOCAB_SIZE) .

بنابراین ، ما باید نشانه های جهانی را به توکن های محلی مربوطه ترسیم کنیم. شناسه رمز محلی به سادگی توسط شاخص های به داده selected_tokens تانسور در مرحله قبل محاسبه می شود.

@tf.function
def map_to_local_token_ids(client_data, client_keys):
  global_to_local = tf.lookup.StaticHashTable(
      # Note int32 -> int64 maps are not supported
      tf.lookup.KeyValueTensorInitializer(
          keys=tf.cast(client_keys, dtype=TOKEN_DTYPE),
          # Note we need to use tf.shape, not the static 
          # shape client_keys.shape[0]
          values=tf.range(0, limit=tf.shape(client_keys)[0],
                          dtype=TOKEN_DTYPE)),
      # We use -1 for tokens that were not selected, which can occur for clients
      # with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
      # We will simply remove these invalid indices from the batch below.
      default_value=-1)

  def to_local_ids(sparse_tokens):
    indices_t = tf.transpose(sparse_tokens.indices)
    batch_indices = indices_t[0]  # First column
    tokens = indices_t[1]  # Second column
    tokens = tf.map_fn(
        lambda global_token_id: global_to_local.lookup(global_token_id), tokens)
    # Remove tokens that aren't actually available (looked up as -1):
    available_tokens = tokens >= 0
    tokens = tokens[available_tokens]
    batch_indices = batch_indices[available_tokens]

    updated_indices = tf.transpose(
        tf.concat([[batch_indices], [tokens]], axis=0))
    st = tf.sparse.SparseTensor(
        updated_indices,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=sparse_tokens.dense_shape)
    st = tf.sparse.reorder(st)
    return st

  return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))
# Simple test
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT)
client_keys = client_keys[:actual_num_tokens]

d = map_to_local_token_ids(batched_dataset3, client_keys)
batch  = next(iter(d))
all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1)
# Confirm we have local indices in the range [0, MAX):
assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT
assert tf.math.reduce_max(all_tokens) >= 0

مدل محلی (فرعی) را روی هر مشتری آموزش دهید

توجه داشته باشید federated_select خواهد برش انتخاب به عنوان بازگشت tf.data.Dataset در همان جهت به عنوان کلید انتخاب کنید. بنابراین ، ما ابتدا یک تابع سودمند برای گرفتن چنین مجموعه داده ای و تبدیل آن به یک تانسور متراکم تکی تعریف می کنیم که می تواند به عنوان وزن مدل مدل مشتری استفاده شود.

@tf.function
def slices_dataset_to_tensor(slices_dataset):
  """Convert a dataset of slices to a tensor."""
  # Use batching to gather all of the slices into a single tensor.
  d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT,
                           drop_remainder=False)
  iter_d = iter(d)
  tensor = next(iter_d)
  # Make sure we have consumed everything
  opt = iter_d.get_next_as_optional()
  tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY')
  return tensor
# Simple test
weights = np.random.random(
    size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32)
model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights)
weights2 = slices_dataset_to_tensor(model_slices_as_dataset)
np.testing.assert_array_equal(weights, weights2)

ما در حال حاضر تمام اجزای مورد نیاز برای تعریف یک حلقه آموزشی ساده محلی را داریم که برای هر مشتری اجرا می شود.

@tf.function
def client_train_fn(model, client_optimizer,
                    model_slices_as_dataset, client_data,
                    client_keys, actual_num_tokens):

  initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset)
  assert len(model.trainable_variables) == 1
  model.trainable_variables[0].assign(initial_model_weights)

  # Only keep the "real" (unpadded) keys.
  client_keys = client_keys[:actual_num_tokens]

  client_data = map_to_local_token_ids(client_data, client_keys)

  loss_fn = tf.keras.losses.BinaryCrossentropy()
  for features, labels in client_data:
    with tf.GradientTape() as tape:
      predictions = model(features)
      loss = loss_fn(labels, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    client_optimizer.apply_gradients(zip(grads, model.trainable_variables))

  model_weights_delta = model.trainable_weights[0] - initial_model_weights
  model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0], 
                           size=[actual_num_tokens, -1])
  return client_keys, model_weights_delta
# Simple test
# Note if you execute this cell a second time, you need to also re-execute
# the preceeding cell to avoid "tf.function-decorated function tried to 
# create variables on non-first call" errors.
on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                        TAG_VOCAB_SIZE)
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT)

model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(
    np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE),
             dtype=np.float32))

keys, delta = client_train_fn(
    on_device_model,
    client_optimizer,
    model_slices_as_dataset,
    client_data=batched_dataset3,
    client_keys=client_keys,
    actual_num_tokens=actual_num_tokens)

print(delta)

برش های فهرست بندی شده

ما با استفاده از tff.federated_aggregate برای ساخت یک جمع پراکنده فدرال برای IndexedSlices . این پیاده سازی ساده است محدودیت که dense_shape آماری در پیش شناخته شده است. توجه داشته باشید که این مبلغ تنها نیمه پراکنده، به این معنا است که مشتری -> ارتباط سرور است پراکنده، اما سرور حفظ یک نمایش متراکم از مبلغ در accumulate و merge ، و خروجی این نمایندگی متراکم.

def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape):
  """
  Sumes IndexedSlices@CLIENTS to a dense @SERVER Tensor.

  Intermediate aggregation is performed by converting to a dense representation,
  which may not be suitable for all applications.

  Args:
    slice_indices: An IndexedSlices.indices tensor @CLIENTS.
    slice_values: An IndexedSlices.values tensor @CLIENTS.
    dense_shape: A statically known dense shape.

  Returns:
    A dense tensor placed @SERVER representing the sum of the client's
    IndexedSclies.
  """
  slices_dtype = slice_values.type_signature.member.dtype
  zero = tff.tf_computation(
      lambda: tf.zeros(dense_shape, dtype=slices_dtype))()

  @tf.function
  def accumulate_slices(dense, client_value):
    indices, slices = client_value
    # There is no built-in way to add `IndexedSlices`, but 
    # tf.convert_to_tensor is a quick way to convert to a dense representation
    # so we can add them.
    return dense + tf.convert_to_tensor(
        tf.IndexedSlices(slices, indices, dense_shape))


  return tff.federated_aggregate(
      (slice_indices, slice_values),
      zero=zero,
      accumulate=tff.tf_computation(accumulate_slices),
      merge=tff.tf_computation(lambda d1, d2: tf.add(d1, d2, name='merge')),
      report=tff.tf_computation(lambda d: d))

ساخت یک حداقل federated_computation به عنوان یک آزمون

dense_shape = (6, 2)
indices_type = tff.TensorType(tf.int64, (None,))
values_type = tff.TensorType(tf.float32, (None, 2))
client_slice_type = tff.type_at_clients(
    (indices_type, values_type))

@tff.federated_computation(client_slice_type)
def test_sum_indexed_slices(indices_values_at_client):
  indices, values = indices_values_at_client
  return federated_indexed_slices_sum(indices, values, dense_shape)

print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices(
    values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]],
                    dtype=np.float32),
    indices=[2, 0, 1, 5],
    dense_shape=dense_shape)
y = tf.IndexedSlices(
    values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32),
    indices=[1, 3],
    dense_shape=dense_shape)

# Sum one.
result = test_sum_indexed_slices([(x.indices, x.values)])
np.testing.assert_array_equal(tf.convert_to_tensor(x), result)

# Sum two.
expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]]
result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)])
np.testing.assert_array_almost_equal(expected, result)

قرار دادن آن همه با هم در یک federated_computation

ما اکنون با استفاده از TFF به هم پیوند می اجزاء به یک tff.federated_computation .

DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
client_data_type = tff.SequenceType(batched_dataset1.element_spec)
model_type = tff.TensorType(tf.float32, shape=DENSE_MODEL_SHAPE)

ما از یک تابع آموزش سرور پایه بر اساس میانگین یابی فدرال استفاده می کنیم و از به روز رسانی با نرخ یادگیری سرور 1.0 استفاده می کنیم. این مهم است که ما به جای مدل متوسط ​​ارائه شده توسط مشتری ، به روزرسانی (دلتا) را بر روی مدل اعمال کنیم ، زیرا در غیر این صورت اگر یک برش مشخص از مدل توسط هیچ مشتری در یک دور معین آموزش داده نشود ، ضرایب آن می تواند صفر شود بیرون

@tff.tf_computation
def server_update(current_model_weights, update_sum, num_clients):
  average_update = update_sum / num_clients
  return current_model_weights + average_update

ما نیاز به یک زن و شوهر بیشتر tff.tf_computation قطعات:

# Function to select slices from the model weights in federated_select:
select_fn = tff.tf_computation(
    lambda model_weights, index: tf.gather(model_weights, index))


# We need to wrap `client_train_fn` as a `tff.tf_computation`, making
# sure we do any operations that might construct `tf.Variable`s outside
# of the `tf.function` we are wrapping.
@tff.tf_computation
def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys,
                        actual_num_tokens):
  # Note this is amaller than the global model, using
  # MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE.
  # W7e would like a model of size `actual_num_tokens`, but we
  # can't build the model dynamically, so we will slice off the padded
  # weights at the end.
  client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                       TAG_VOCAB_SIZE)
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  return client_train_fn(client_model, client_optimizer,
                         model_slices_as_dataset, client_data, client_keys,
                         actual_num_tokens)

@tff.tf_computation
def keys_for_client_tff(client_data):
  return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)

ما اکنون آماده ایم که همه قطعات را کنار هم قرار دهیم!

@tff.federated_computation(
    tff.type_at_server(model_type), tff.type_at_clients(client_data_type))
def sparse_model_update(server_model, client_data):
  max_tokens = tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER)
  keys_at_clients, actual_num_tokens = tff.federated_map(
      keys_for_client_tff, client_data)

  model_slices = tff.federated_select(keys_at_clients, max_tokens, server_model,
                                      select_fn)

  update_keys, update_slices = tff.federated_map(
      client_train_fn_tff,
      (model_slices, client_data, keys_at_clients, actual_num_tokens))

  dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices,
                                                  DENSE_MODEL_SHAPE)
  num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS))

  updated_server_model = tff.federated_map(
      server_update, (server_model, dense_update_sum, num_clients))

  return updated_server_model


print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)

بیایید یک مدل آموزش دهیم!

اکنون که عملکرد آموزشی خود را داریم ، بیایید آن را امتحان کنیم.

server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
server_model.compile(  # Compile to make evaluation easy.
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0),  # Unused
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[ 
      tf.keras.metrics.Precision(name='precision'),
      tf.keras.metrics.AUC(name='auc'),
      tf.keras.metrics.Recall(top_k=2, name='recall_at_2'),
  ])

def evaluate(model, dataset, name):
  metrics = model.evaluate(dataset, verbose=0)
  metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in 
                          (zip(server_model.metrics_names, metrics))])
  print(f'{name}: {metrics_str}')
print('Before training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')

model_weights = server_model.trainable_weights[0]

client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3]
for _ in range(10):  # Run 10 rounds of FedAvg
  # We train on 1, 2, or 3 clients per round, selecting
  # randomly.
  cohort_size = np.random.randint(1, 4)
  clients = np.random.choice([0, 1, 2], cohort_size, replace=False)
  print('Training on clients', clients)
  model_weights = sparse_model_update(
      model_weights, [client_datasets[i] for i in clients])
server_model.set_weights([model_weights])

print('After training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')
Before training
Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60
Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50
Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40
Training on clients [0 1]
Training on clients [0 2 1]
Training on clients [2 0]
Training on clients [1 0 2]
Training on clients [2]
Training on clients [2 0]
Training on clients [1 2 0]
Training on clients [0]
Training on clients [2]
Training on clients [1 2]
After training
Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80
Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00
Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80