סימולציות בעלות ביצועים גבוהים עם TFF

מדריך זה יתאר כיצד להגדיר סימולציות בעלות ביצועים גבוהים עם TFF במגוון תרחישים נפוצים.

TODO(b/134543154): אכלס את התוכן, כמה מהדברים שיש לכסות כאן:

  • שימוש במעבדי GPU בהגדרה של מכונה אחת,
  • הגדרת ריבוי מכונות ב-GCP/GKE, עם ובלי TPUs,
  • מתממשקים לממשקים עורפיים דמויי MapReduce,
  • המגבלות הנוכחיות ומתי/איך הן יורפו.
הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת

לפני שנתחיל

ראשית, ודא שהמחברת שלך מחוברת ל-backend שבו הרכיבים הרלוונטיים (כולל תלות gRPC עבור תרחישים מרובי מכונות) מורכבים.

כעת, נתחיל בטעינת הדוגמה של MNIST מאתר TFF, והכרזה על פונקציית Python שתפעיל לולאת ניסוי קטנה על פני קבוצה של 10 לקוחות.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import time

import tensorflow as tf

import tensorflow_federated as tff

source, _ = tff.simulation.datasets.emnist.load_data()


def map_fn(example):
  return collections.OrderedDict(
      x=tf.reshape(example['pixels'], [-1, 784]), y=example['label'])


def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).shuffle(500).batch(20).map(map_fn)


train_data = [client_data(n) for n in range(10)]
element_spec = train_data[0].element_spec


def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  return tff.learning.from_keras_model(
      model,
      input_spec=element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])


trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02))


def evaluate(num_rounds=10):
  state = trainer.initialize()
  for _ in range(num_rounds):
    t1 = time.time()
    state, metrics = trainer.next(state, train_data)
    t2 = time.time()
    print('metrics {m}, round time {t:.2f} seconds'.format(
        m=metrics, t=t2 - t1))

הדמיות של מכונה אחת

כעת פועל כברירת מחדל.

evaluate()
metrics <sparse_categorical_accuracy=0.13858024775981903,loss=3.0073554515838623>, round time 3.59 seconds
metrics <sparse_categorical_accuracy=0.1796296238899231,loss=2.749046802520752>, round time 2.29 seconds
metrics <sparse_categorical_accuracy=0.21656379103660583,loss=2.514779567718506>, round time 2.33 seconds
metrics <sparse_categorical_accuracy=0.2637860178947449,loss=2.312587261199951>, round time 2.06 seconds
metrics <sparse_categorical_accuracy=0.3334362208843231,loss=2.068122386932373>, round time 2.00 seconds
metrics <sparse_categorical_accuracy=0.3737654387950897,loss=1.9268712997436523>, round time 2.42 seconds
metrics <sparse_categorical_accuracy=0.4296296238899231,loss=1.7216310501098633>, round time 2.20 seconds
metrics <sparse_categorical_accuracy=0.4655349850654602,loss=1.6489890813827515>, round time 2.18 seconds
metrics <sparse_categorical_accuracy=0.5048353672027588,loss=1.5485210418701172>, round time 2.16 seconds
metrics <sparse_categorical_accuracy=0.5564814805984497,loss=1.4140453338623047>, round time 2.41 seconds

סימולציות מרובות מכונות על GCP/GKE, GPUs, TPUs ומעבר...

מגיע בקרוב מאוד.