למידה מאוחדת של מודלים גדולים חסכוניים ללקוח באמצעות federated_select וצבירה דלילה

צפה ב- TensorFlow.org הפעל בגוגל קולאב צפה במקור ב- GitHub הורד מחברת

הדרכה זו מראה כיצד ניתן להשתמש ב- TFF לאימון מודל גדול מאוד כאשר כל מכשיר לקוח מוריד ומעדכן רק חלק קטן מהמודל, באמצעות tff.federated_select וצבירה דלילה. למרות tff.federated_select זה הוא עצמאי למדי, ההדרכה tff.federated_select האלגוריתמים המותאמים אישית של FL מספקים הקדמות טובות לחלק מהטכניקות המשמשות כאן.

באופן קונקרטי, במדריך זה אנו רואים רגרסיה לוגיסטית לסיווג רב-תווי, תוך ניבוי אילו "תגים" משויכים למחרוזת טקסט המבוססת על ייצוג תכונות של תיק מילים. חשוב לציין כי עלויות החישוב בצד הלקוח נשלטות על ידי קבוע קבוע ( MAX_TOKENS_SELECTED_PER_CLIENT ), ואינן MAX_TOKENS_SELECTED_PER_CLIENT לפי גודל אוצר המילים הכללי, שיכול להיות גדול במיוחד בהגדרות מעשיות.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import itertools
import numpy as np

from typing import Callable, List, Tuple

import tensorflow as tf
import tensorflow_federated as tff

כל לקוח יהיה federated_select השורות של משקלות המודל לכל היותר זה הרבה אסימונים ייחודיים. זה גבול עליון בגודל המודל המקומי של הלקוח וכמות השרת -> לקוח ( federated_select ) ולקוח -> שרת (federated_aggregate ) שבוצע.

מדריך זה עדיין אמור לפעול כהלכה גם אם תגדיר את זה כ- 1 (כדי לוודא שלא כל האסימונים מכל לקוח נבחרים) או לערך גדול, אם כי ניתן להתכנס למודל.

MAX_TOKENS_SELECTED_PER_CLIENT = 6

אנו מגדירים גם מספר קבועים לסוגים שונים. עבור colab זה, אסימון הוא מזהה שלם עבור מילה מסוימת לאחר ניתוח מערך הנתונים.

# There are some constraints on types
# here that will require some explicit type conversions:
#    - `tff.federated_select` requires int32
#    - `tf.SparseTensor` requires int64 indices.
TOKEN_DTYPE = tf.int64
SELECT_KEY_DTYPE = tf.int32

# Type for counts of token occurences.
TOKEN_COUNT_DTYPE = tf.int32

# A sparse feature vector can be thought of as a map
# from TOKEN_DTYPE to FEATURE_DTYPE. 
# Our features are {0, 1} indicators, so we could potentially
# use tf.int8 as an optimization.
FEATURE_DTYPE = tf.int32

הגדרת הבעיה: מערך נתונים ומודל

אנו בונים מערך צעצועים זעיר לניסויים קלים במדריך זה. עם זאת, הפורמט של מערך הנתונים תואם את Federated StackOverflow , והארכיטקטורה של העיבוד המוקדם והמודל מאומצים מבעיית החיזוי של תג תג StackOverflow של Adaptive Federated Optimization .

ניתוח מערכי נתונים ועיבוד מקדים

NUM_OOV_BUCKETS = 1

BatchType = collections.namedtuple('BatchType', ['tokens', 'tags'])

def build_to_ids_fn(word_vocab: List[str],
                    tag_vocab: List[str]) -> Callable[[tf.Tensor], tf.Tensor]:
  """Constructs a function mapping examples to sequences of token indices."""
  word_table_values = np.arange(len(word_vocab), dtype=np.int64)
  word_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  tag_table_values = np.arange(len(tag_vocab), dtype=np.int64)
  tag_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  def to_ids(example):
    """Converts a Stack Overflow example to a bag-of-words/tags format."""
    sentence = tf.strings.join([example['tokens'], example['title']],
                               separator=' ')

    # We represent that label (output tags) densely.
    raw_tags = example['tags']
    tags = tf.strings.split(raw_tags, sep='|')
    tags = tag_table.lookup(tags)
    tags, _ = tf.unique(tags)
    tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS)
    tags = tf.reduce_max(tags, axis=0)

    # We represent the features as a SparseTensor of {0, 1}s.
    words = tf.strings.split(sentence)
    tokens = word_table.lookup(words)
    tokens, _ = tf.unique(tokens)
    # Note:  We could choose to use the word counts as the feature vector
    # instead of just {0, 1} values (see tf.unique_with_counts).
    tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1))
    tokens_st = tf.SparseTensor(
        tokens,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,))
    tokens_st = tf.sparse.reorder(tokens_st)

    return BatchType(tokens_st, tags)

  return to_ids
def build_preprocess_fn(word_vocab, tag_vocab):

  @tf.function
  def preprocess_fn(dataset):
    to_ids = build_to_ids_fn(word_vocab, tag_vocab)
    # We *don't* shuffle in order to make this colab deterministic for
    # easier testing and reproducibility.
    # But real-world training should use `.shuffle()`.
    return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return preprocess_fn

מערך צעצועים זעיר

אנו בונים מערך צעצועים זעיר עם אוצר מילים עולמי של 12 מילים ושלושה לקוחות. דוגמה זעירה זו שימושית לבדיקת מקרי קצה (לדוגמה, יש לנו שני לקוחות עם פחות מ- MAX_TOKENS_SELECTED_PER_CLIENT = 6 אסימונים נפרדים, ואחד עם יותר) ופיתוח הקוד.

עם זאת, מקרי השימוש בעולם האמיתי של גישה זו יהיו אוצר מילים עולמי של 10 מיליונים ומעלה, ואולי מופיעים אלפי אסימונים מובחנים על כל לקוח. מכיוון שפורמט הנתונים זהה, ההרחבה לבעיות tff.simulation.datasets.stackoverflow.load_data() מציאותי יותר, למשל מערך tff.simulation.datasets.stackoverflow.load_data() , צריכה להיות פשוטה.

ראשית, אנו מגדירים את המילה שלנו ומתייגים אוצר מילים.

# Features
FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi']
VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas']
FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon']
WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS

# Labels
TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']

כעת אנו יוצרים 3 לקוחות עם מערכי נתונים קטנים מקומיים. אם אתה מריץ הדרכה זו ב- colab, יכול להיות שימושי להשתמש בתכונה "תא מראה בכרטיסייה" כדי להצמיד תא זה ואת תפוקתו על מנת לפרש / לבדוק את תפוקת הפונקציות שפותחו להלן.

preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB)


def make_dataset(raw):
  d = tf.data.Dataset.from_tensor_slices(
      # Matches the StackOverflow formatting
      collections.OrderedDict(
          tokens=tf.constant([t[0] for t in raw]),
          tags=tf.constant([t[1] for t in raw]),
          title=['' for _ in raw]))
  d = preprocess_fn(d)
  return d


# 4 distinct tokens
CLIENT1_DATASET = make_dataset([
    ('apple orange apple orange', 'FRUIT'),
    ('carrot trout', 'VEGETABLE|FISH'),
    ('orange apple', 'FRUIT'),
    ('orange', 'ORANGE|CITRUS')  # 2 OOV tag
])

# 6 distinct tokens
CLIENT2_DATASET = make_dataset([
    ('pear cod', 'FRUIT|FISH'),
    ('arugula peas', 'VEGETABLE'),
    ('kiwi pear', 'FRUIT'),
    ('sturgeon', 'FISH'),  # OOV word
    ('sturgeon bass', 'FISH')  # 2 OOV words
])

# A client with all possible words & tags (13 distinct tokens).
# With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model
# slices for all tokens that occur on this client.
CLIENT3_DATASET = make_dataset([
    (' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)),
    # Mathe the OOV token and 'salmon' occur in the largest number
    # of examples on this client:
    ('salmon oovword', 'FISH|OOVTAG')
])

print('Word vocab')
for i, word in enumerate(WORD_VOCAB):
  print(f'{i:2d} {word}')

print('\nTag vocab')
for i, tag in enumerate(TAG_VOCAB):
  print(f'{i:2d} {tag}')
Word vocab
 0 apple
 1 orange
 2 pear
 3 kiwi
 4 carrot
 5 broccoli
 6 arugula
 7 peas
 8 trout
 9 tuna
10 cod
11 salmon

Tag vocab
 0 FRUIT
 1 VEGETABLE
 2 FISH

הגדר קבועים למספרים הגולמיים של תכונות קלט (אסימונים / מילים) ותוויות (תגיות פוסט). NUM_OOV_BUCKETS = 1 הקלט / הפלט בפועל שלנו הם NUM_OOV_BUCKETS = 1 גדולים יותר מכיוון שאנחנו מוסיפים אסימון / תג OOV.

NUM_WORDS = len(WORD_VOCAB) 
NUM_TAGS = len(TAG_VOCAB)

WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS
TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS

צור גרסאות קבוצות של מערכי הנתונים, וקבוצות בודדות, אשר יהיו שימושיות בבדיקת קוד תוך כדי.

batched_dataset1 = CLIENT1_DATASET.batch(2)
batched_dataset2 = CLIENT2_DATASET.batch(3)
batched_dataset3 = CLIENT3_DATASET.batch(2)

batch1 = next(iter(batched_dataset1))
batch2 = next(iter(batched_dataset2))
batch3 = next(iter(batched_dataset3))

הגדר מודל עם תשומות דלילות

אנו משתמשים במודל רגרסיה לוגיסטית עצמאית פשוטה עבור כל תג.

def create_logistic_model(word_vocab_size: int, vocab_tags_size: int):

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True),
      tf.keras.layers.Dense(
          vocab_tags_size,
          activation='sigmoid',
          kernel_initializer=tf.keras.initializers.zeros,
          # For simplicity, don't use a bias vector; this means the model
          # is a single tensor, and we only need sparse aggregation of
          # the per-token slices of the model. Generalizing to also handle
          # other model weights that are fully updated 
          # (non-dense broadcast and aggregate) would be a good exercise.
          use_bias=False),
  ])

  return model

בואו נוודא שזה עובד, תחילה על ידי ניבוי:

model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
p = model.predict(batch1.tokens)
print(p)
[[0.5 0.5 0.5 0.5]
 [0.5 0.5 0.5 0.5]]

וכמה אימונים מרכזיים פשוטים:

model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001),
              loss=tf.keras.losses.BinaryCrossentropy())
model.train_on_batch(batch1.tokens, batch1.tags)

אבני בניין לחישוב המאוחד

אנו ניישם גרסה פשוטה של ​​האלגוריתם Federated Averaging עם ההבדל העיקרי שכל מכשיר מוריד רק קבוצת משנה רלוונטית של המודל, ותורם רק עדכונים לאותה תת קבוצה.

אנו משתמשים ב- M MAX_TOKENS_SELECTED_PER_CLIENT עבור MAX_TOKENS_SELECTED_PER_CLIENT . ברמה גבוהה, סבב אימונים אחד כולל את השלבים הבאים:

  1. כל לקוח משתתף סורק את מערך הנתונים המקומי שלו, מנתח את מחרוזות הקלט וממפה אותם לאסימונים הנכונים (אינדקסים int). זה דורש גישה למילון הגלובלי (הגדול) (זה יכול להימנע באמצעות טכניקות של גיבוב תכונות ). לאחר מכן אנו סופרים בדלילות כמה פעמים כל אסימון מתרחש. אם U אסימונים ייחודיים U מתרחשים במכשיר, אנו בוחרים num_actual_tokens = min(U, M) הנפוצים ביותר לאימון.

  2. הלקוחות משתמשים ב- federated_select כדי לאחזר את מקדמי המודל עבור האסימונים הנבחרים של num_actual_tokens מהשרת. כל פרוסת דגם היא טנסור של צורה (TAG_VOCAB_SIZE, ) , כך שכל הנתונים המועברים ללקוח הם לכל היותר בגודל TAG_VOCAB_SIZE * M (ראה הערה למטה).

  3. הלקוחות בונים מיפוי global_token -> local_token כאשר האסימון המקומי (אינדקס int) הוא האינדקס של האסימון הגלובלי ברשימת האסימונים שנבחרו.

  4. הלקוחות משתמשים בגרסה "קטנה" של המודל העולמי שיש בו רק מקדמים לכל היותר אסימונים M , מהטווח [0, num_actual_tokens) . המיפוי global -> local משמש לאתחול הפרמטרים הצפופים של מודל זה מפרוסות המודל שנבחרו.

  5. לקוחות מאמנים את המודל המקומי שלהם באמצעות SGD בנתונים שעובדו מראש עם המיפוי global -> local .

  6. לקוחות הופכים את הפרמטרים של המודל המקומי שלהם לעדכוני IndexedSlices באמצעות המיפוי local -> global כדי להוסיף את השורות לאינדקס. השרת מצטבר את העדכונים הללו באמצעות צבירת סכומים דלילה.

  7. השרת לוקח את התוצאה (הצפופה) של הצבירה שלעיל, מחלק אותה למספר הלקוחות המשתתפים, ומחיל את העדכון הממוצע המתקבל על המודל הגלובלי.

בחלק זה אנו בונים את אבני הבניין לצעדים אלה, אשר ישולבו לאחר מכן בחישוב federated_computation הסופי federated_computation את הלוגיקה המלאה של סבב אימונים אחד.

ספרו אסימונים של לקוחות והחליטו איזה פרוסות מודל ל- federated_select

כל מכשיר צריך להחליט אילו "פרוסות" של המודל רלוונטיות למערך האימונים המקומי שלו. לבעיה שלנו, אנו עושים זאת על ידי ספירה (בדלילות!) כמה דוגמאות מכילות כל אסימון בערכת נתוני האימון של הלקוח.

@tf.function
def token_count_fn(token_counts, batch):
  """Adds counts from `batch` to the running `token_counts` sum."""
  # Sum across the batch dimension.
  flat_tokens = tf.sparse.reduce_sum(
      batch.tokens, axis=0, output_is_sparse=True)
  flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE)
  return tf.sparse.add(token_counts, flat_tokens)
# Simple tests
# Create the initial zero token counts using empty tensors.
initial_token_counts = tf.SparseTensor(
    indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE),
    values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE),
    dense_shape=(WORD_VOCAB_SIZE,))

client_token_counts = batched_dataset1.reduce(initial_token_counts,
                                              token_count_fn)
tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy()
print('tokens:', tokens)
np.testing.assert_array_equal(tokens, [0, 1, 4, 8])
# The count is the number of *examples* in which the token/word
# occurs, not the total number of occurences, since we still featurize
# multiple occurences in the same example as a "1".
counts = client_token_counts.values.numpy()
print('counts:', counts)
np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8]
counts: [2 3 1 1]

אנו נבחר את פרמטרי הדגם המתאימים ל- MAX_TOKENS_SELECTED_PER_CLIENT האסימונים המופיעים בתדירות הגבוהה ביותר במכשיר. אם פחות מכמות אסימונים רבים מתרחשים במכשיר, אנו מרפדים את הרשימה כדי לאפשר את השימוש ב- federated_select .

שימו לב שאולי אסטרטגיות אחרות טובות יותר, למשל, בבחירה אקראית של אסימונים (אולי על סמך הסבירות להתרחשותם). זה יבטיח שלכל פרוסות המודל (שלגבי הלקוח יש נתונים) יש סיכוי כלשהו להתעדכן.

@tf.function
def keys_for_client(client_dataset, max_tokens_per_client):
  """Computes a set of max_tokens_per_client keys."""
  initial_token_counts = tf.SparseTensor(
      indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE),
      values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE),
      dense_shape=(WORD_VOCAB_SIZE,))
  client_token_counts = client_dataset.reduce(initial_token_counts,
                                              token_count_fn)
  # Find the most-frequently occuring tokens
  tokens = tf.reshape(client_token_counts.indices, shape=(-1,))
  counts = client_token_counts.values
  perm = tf.argsort(counts, direction='DESCENDING')
  tokens = tf.gather(tokens, perm)
  counts = tf.gather(counts, perm)
  num_raw_tokens = tf.shape(tokens)[0]
  actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens)
  selected_tokens = tokens[:actual_num_tokens]
  paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]]
  padded_tokens = tf.pad(selected_tokens, paddings=paddings)
  # Make sure the type is statically determined
  padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,))

  # We will pass these tokens as keys into `federated_select`, which
  # requires SELECT_KEY_DTYPE=tf.int32 keys.
  padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE)
  return padded_tokens, actual_num_tokens
# Simple test

# Case 1: actual_num_tokens > max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3)
assert tf.size(selected_tokens) == 3
assert actual_num_tokens == 3

# Case 2: actual_num_tokens < max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10)
assert tf.size(selected_tokens) == 10
assert actual_num_tokens == 4

מיפוי אסימונים גלובליים לאסימונים מקומיים

הבחירה שלעיל נותנת לנו קבוצה צפופה של אסימונים בטווח [0, actual_num_tokens) בהם נשתמש למודל המכשיר. עם זאת, מערך הנתונים שקראנו מכיל אסימונים מהטווח הגדול הרבה יותר של אוצר המילים [0, WORD_VOCAB_SIZE) .

לפיכך, עלינו למפות את האסימונים הגלובליים לאסימונים המקומיים המתאימים להם. זיהוי האסימון המקומי ניתנים פשוט על ידי האינדקסים לתוך selected_tokens מותח מחושב בשלב הקודם.

@tf.function
def map_to_local_token_ids(client_data, client_keys):
  global_to_local = tf.lookup.StaticHashTable(
      # Note int32 -> int64 maps are not supported
      tf.lookup.KeyValueTensorInitializer(
          keys=tf.cast(client_keys, dtype=TOKEN_DTYPE),
          # Note we need to use tf.shape, not the static 
          # shape client_keys.shape[0]
          values=tf.range(0, limit=tf.shape(client_keys)[0],
                          dtype=TOKEN_DTYPE)),
      # We use -1 for tokens that were not selected, which can occur for clients
      # with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
      # We will simply remove these invalid indices from the batch below.
      default_value=-1)

  def to_local_ids(sparse_tokens):
    indices_t = tf.transpose(sparse_tokens.indices)
    batch_indices = indices_t[0]  # First column
    tokens = indices_t[1]  # Second column
    tokens = tf.map_fn(
        lambda global_token_id: global_to_local.lookup(global_token_id), tokens)
    # Remove tokens that aren't actually available (looked up as -1):
    available_tokens = tokens >= 0
    tokens = tokens[available_tokens]
    batch_indices = batch_indices[available_tokens]

    updated_indices = tf.transpose(
        tf.concat([[batch_indices], [tokens]], axis=0))
    st = tf.sparse.SparseTensor(
        updated_indices,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=sparse_tokens.dense_shape)
    st = tf.sparse.reorder(st)
    return st

  return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))
# Simple test
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT)
client_keys = client_keys[:actual_num_tokens]

d = map_to_local_token_ids(batched_dataset3, client_keys)
batch  = next(iter(d))
all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1)
# Confirm we have local indices in the range [0, MAX):
assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT
assert tf.math.reduce_max(all_tokens) >= 0

הכשר את המודל המקומי (תת) על כל לקוח

הערה federated_select יחזיר את הפרוסות שנבחרו כtf.data.Dataset באותו סדר כמו מקשי הבחירה. לכן, אנו מגדירים תחילה פונקציית שירות שתיקח מערך נתונים כזה ותמיר אותו לטנסור צפוף יחיד שיכול לשמש כמשקולות המודל של מודל הלקוח.

@tf.function
def slices_dataset_to_tensor(slices_dataset):
  """Convert a dataset of slices to a tensor."""
  # Use batching to gather all of the slices into a single tensor.
  d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT,
                           drop_remainder=False)
  iter_d = iter(d)
  tensor = next(iter_d)
  # Make sure we have consumed everything
  opt = iter_d.get_next_as_optional()
  tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY')
  return tensor
# Simple test
weights = np.random.random(
    size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32)
model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights)
weights2 = slices_dataset_to_tensor(model_slices_as_dataset)
np.testing.assert_array_equal(weights, weights2)

כעת יש לנו את כל המרכיבים הדרושים לנו כדי להגדיר לולאת אימונים מקומית פשוטה אשר תפעל על כל לקוח.

@tf.function
def client_train_fn(model, client_optimizer,
                    model_slices_as_dataset, client_data,
                    client_keys, actual_num_tokens):

  initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset)
  assert len(model.trainable_variables) == 1
  model.trainable_variables[0].assign(initial_model_weights)

  # Only keep the "real" (unpadded) keys.
  client_keys = client_keys[:actual_num_tokens]

  client_data = map_to_local_token_ids(client_data, client_keys)

  loss_fn = tf.keras.losses.BinaryCrossentropy()
  for features, labels in client_data:
    with tf.GradientTape() as tape:
      predictions = model(features)
      loss = loss_fn(labels, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    client_optimizer.apply_gradients(zip(grads, model.trainable_variables))

  model_weights_delta = model.trainable_weights[0] - initial_model_weights
  model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0], 
                           size=[actual_num_tokens, -1])
  return client_keys, model_weights_delta
# Simple test
# Note if you execute this cell a second time, you need to also re-execute
# the preceeding cell to avoid "tf.function-decorated function tried to 
# create variables on non-first call" errors.
on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                        TAG_VOCAB_SIZE)
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT)

model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(
    np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE),
             dtype=np.float32))

keys, delta = client_train_fn(
    on_device_model,
    client_optimizer,
    model_slices_as_dataset,
    client_data=batched_dataset3,
    client_keys=client_keys,
    actual_num_tokens=actual_num_tokens)

print(delta)

פרוסות אינדקס מצטבר

אנו משתמשים ב- tff.federated_aggregate לבניית סכום דליל מאוחד עבור IndexedSlices . ליישום פשוט זה יש אילוץ כי הצורה dense_shape ידועה סטטית מראש. שימו לב גם כי סכום זה הוא רק דליל למחצה , במובן שהלקוח -> תקשורת השרתים היא דלילה, אך השרת שומר על ייצוג צפוף של הסכום accumulate merge , ומפיק ייצוג צפוף זה.

def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape):
  """
  Sumes IndexedSlices@CLIENTS to a dense @SERVER Tensor.

  Intermediate aggregation is performed by converting to a dense representation,
  which may not be suitable for all applications.

  Args:
    slice_indices: An IndexedSlices.indices tensor @CLIENTS.
    slice_values: An IndexedSlices.values tensor @CLIENTS.
    dense_shape: A statically known dense shape.

  Returns:
    A dense tensor placed @SERVER representing the sum of the client's
    IndexedSclies.
  """
  slices_dtype = slice_values.type_signature.member.dtype
  zero = tff.tf_computation(
      lambda: tf.zeros(dense_shape, dtype=slices_dtype))()

  @tf.function
  def accumulate_slices(dense, client_value):
    indices, slices = client_value
    # There is no built-in way to add `IndexedSlices`, but 
    # tf.convert_to_tensor is a quick way to convert to a dense representation
    # so we can add them.
    return dense + tf.convert_to_tensor(
        tf.IndexedSlices(slices, indices, dense_shape))


  return tff.federated_aggregate(
      (slice_indices, slice_values),
      zero=zero,
      accumulate=tff.tf_computation(accumulate_slices),
      merge=tff.tf_computation(lambda d1, d2: tf.add(d1, d2, name='merge')),
      report=tff.tf_computation(lambda d: d))

בנה federated_computation כמבחן

dense_shape = (6, 2)
indices_type = tff.TensorType(tf.int64, (None,))
values_type = tff.TensorType(tf.float32, (None, 2))
client_slice_type = tff.type_at_clients(
    (indices_type, values_type))

@tff.federated_computation(client_slice_type)
def test_sum_indexed_slices(indices_values_at_client):
  indices, values = indices_values_at_client
  return federated_indexed_slices_sum(indices, values, dense_shape)

print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices(
    values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]],
                    dtype=np.float32),
    indices=[2, 0, 1, 5],
    dense_shape=dense_shape)
y = tf.IndexedSlices(
    values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32),
    indices=[1, 3],
    dense_shape=dense_shape)

# Sum one.
result = test_sum_indexed_slices([(x.indices, x.values)])
np.testing.assert_array_equal(tf.convert_to_tensor(x), result)

# Sum two.
expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]]
result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)])
np.testing.assert_array_almost_equal(expected, result)

לשים את הכל יחדיו ב- federated_computation

כעת אנו משתמשים ב- TFF כדי לאגד את הרכיבים tff.federated_computation .

DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
client_data_type = tff.SequenceType(batched_dataset1.element_spec)
model_type = tff.TensorType(tf.float32, shape=DENSE_MODEL_SHAPE)

אנו משתמשים בפונקציית הכשרת שרתים בסיסית המבוססת על Federated Averaging, ומיישמות את העדכון בקצב למידה של שרת 1.0. חשוב שנחיל עדכון (דלתא) על המודל, ולא פשוט ממוצע של מודלים המסופקים על ידי הלקוח, שכן אחרת אם נתח מסוים של המודל לא היה מאומן על ידי לקוח כלשהו בסיבוב נתון, ניתן לאפס את מקדמיו. הַחוּצָה.

@tff.tf_computation
def server_update(current_model_weights, update_sum, num_clients):
  average_update = update_sum / num_clients
  return current_model_weights + average_update

אנו זקוקים לעוד כמה רכיבי tff.tf_computation :

# Function to select slices from the model weights in federated_select:
select_fn = tff.tf_computation(
    lambda model_weights, index: tf.gather(model_weights, index))


# We need to wrap `client_train_fn` as a `tff.tf_computation`, making
# sure we do any operations that might construct `tf.Variable`s outside
# of the `tf.function` we are wrapping.
@tff.tf_computation
def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys,
                        actual_num_tokens):
  # Note this is amaller than the global model, using
  # MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE.
  # W7e would like a model of size `actual_num_tokens`, but we
  # can't build the model dynamically, so we will slice off the padded
  # weights at the end.
  client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                       TAG_VOCAB_SIZE)
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  return client_train_fn(client_model, client_optimizer,
                         model_slices_as_dataset, client_data, client_keys,
                         actual_num_tokens)

@tff.tf_computation
def keys_for_client_tff(client_data):
  return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)

עכשיו אנחנו מוכנים להרכיב את כל החלקים!

@tff.federated_computation(
    tff.type_at_server(model_type), tff.type_at_clients(client_data_type))
def sparse_model_update(server_model, client_data):
  max_tokens = tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER)
  keys_at_clients, actual_num_tokens = tff.federated_map(
      keys_for_client_tff, client_data)

  model_slices = tff.federated_select(keys_at_clients, max_tokens, server_model,
                                      select_fn)

  update_keys, update_slices = tff.federated_map(
      client_train_fn_tff,
      (model_slices, client_data, keys_at_clients, actual_num_tokens))

  dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices,
                                                  DENSE_MODEL_SHAPE)
  num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS))

  updated_server_model = tff.federated_map(
      server_update, (server_model, dense_update_sum, num_clients))

  return updated_server_model


print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)

בואו נכשיר דוגמנית!

עכשיו שיש לנו את פונקציית האימון שלנו, בואו ננסה את זה.

server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
server_model.compile(  # Compile to make evaluation easy.
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0),  # Unused
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[ 
      tf.keras.metrics.Precision(name='precision'),
      tf.keras.metrics.AUC(name='auc'),
      tf.keras.metrics.Recall(top_k=2, name='recall_at_2'),
  ])

def evaluate(model, dataset, name):
  metrics = model.evaluate(dataset, verbose=0)
  metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in 
                          (zip(server_model.metrics_names, metrics))])
  print(f'{name}: {metrics_str}')
print('Before training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')

model_weights = server_model.trainable_weights[0]

client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3]
for _ in range(10):  # Run 10 rounds of FedAvg
  # We train on 1, 2, or 3 clients per round, selecting
  # randomly.
  cohort_size = np.random.randint(1, 4)
  clients = np.random.choice([0, 1, 2], cohort_size, replace=False)
  print('Training on clients', clients)
  model_weights = sparse_model_update(
      model_weights, [client_datasets[i] for i in clients])
server_model.set_weights([model_weights])

print('After training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')
Before training
Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60
Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50
Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40
Training on clients [0 1]
Training on clients [0 2 1]
Training on clients [2 0]
Training on clients [1 0 2]
Training on clients [2]
Training on clients [2 0]
Training on clients [1 2 0]
Training on clients [0]
Training on clients [2]
Training on clients [1 2]
After training
Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80
Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00
Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80