עזרה להגן על שונית המחסום הגדולה עם TensorFlow על Kaggle הצטרפו אתגר

תמיכה ניסיונית ב- JAX ב- TFF

הצג באתר TensorFlow.org הפעל בגוגל קולאב הצג ב-GitHub הורד מחברת

בנוסף להיותו חלק מהמערכת האקולוגית של TensorFlow, TFF שואפת לאפשר יכולת פעולה הדדית עם מסגרות ML חזיתיות ו-backend אחרות. נכון לעכשיו, התמיכה במסגרות ML אחרות נמצאת עדיין בשלב הדגירה, וה-APIs והפונקציונליות הנתמכת עשויים להשתנות (במידה רבה כפונקציה של ביקוש ממשתמשי TFF). מדריך זה מתאר כיצד להשתמש ב-TFF עם JAX כחזית ML חלופית, ובקומפיילר XLA כחלק אחורי חלופי. הדוגמאות המוצגות כאן מבוססות על ערימת JAX/XLA מקורית לחלוטין, מקצה לקצה. האפשרות של ערבוב קוד בין מסגרות (למשל, JAX עם TensorFlow) תידון באחד המדריכים העתידיים.

כמו תמיד, נשמח לתרומתך. אם חשובה לך תמיכה ב-JAX/XLA או היכולת לפעול עם מסגרות ML אחרות, אנא שקול לעזור לנו לפתח את היכולות הללו לקראת שוויון עם שאר ה-TFF.

לפני שנתחיל

אנא עיין בגוף הראשי של תיעוד TFF כיצד להגדיר את הסביבה שלך. בהתאם למקום שבו אתה מפעיל את המדריך הזה, ייתכן שתרצה לבטל את ההערות ולהפעיל חלק מהקוד שלמטה או את כולו.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

מדריך זה גם מניח שסקרת את המדריכים העיקריים של TensorFlow של TFF, ושאתה מכיר את מושגי הליבה של TFF. אם עדיין לא עשית זאת, אנא שקול לבדוק לפחות אחד מהם.

חישובי JAX

התמיכה ב-JAX ב-TFF נועדה להיות סימטרית עם האופן שבו TFF פועל עם TensorFlow, החל מייבוא:

import jax
import numpy as np
import tensorflow_federated as tff

כמו כן, בדיוק כמו עם TensorFlow, הבסיס לביטוי כל קוד TFF הוא ההיגיון שפועל באופן מקומי. אתה יכול להביע את ההיגיון הזה JAX, כפי שנראה בהמשך, באמצעות @tff.experimental.jax_computation המעטפת. הוא מתנהג באופן דומה @tff.tf_computation שעד עכשיו שלך מכירים. נתחיל עם משהו פשוט, למשל, חישוב שמוסיף שני מספרים שלמים:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

אתה יכול להשתמש בחישוב JAX שהוגדר לעיל בדיוק כמו שאתה משתמש בדרך כלל בחישוב TFF. לדוגמה, אתה יכול לבדוק את חתימת הסוג שלה, באופן הבא:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

שים לב השתמשנו np.int32 להגדיר את סוג הטיעונים. TFF אינו מבחין בין סוגי numpy (כגון np.int32 ) וסוג TensorFlow (כגון tf.int32 ). מנקודת המבט של TFF, הם רק דרכים להתייחס לאותו דבר.

כעת, זכור ש-TFF אינו Python (ואם זה לא מצלצל, אנא עיין בכמה מהמדריכים הקודמים שלנו, למשל, על אלגוריתמים מותאמים אישית). אתה יכול להשתמש @tff.experimental.jax_computation מעטפת עם כל JAX קוד שניתן לייחס ו בהמשכים, כלומר, עם קוד שאתה עושה בדרך כלל לכתוב הערות עם @jax.jit צפוי להיות מלוקט לתוך XLA (אבל אתה לא צריך למעשה להשתמש @jax.jit ביאור להטביע קוד JAX שלך TFF).

ואכן, מתחת למכסה המנוע, TFF מחבר באופן מיידי חישובי JAX ל-XLA. אתה יכול לבדוק את זה בעצמך על ידי חילוץ ידני והדפסת קוד XLA בהמשכים מ add_numbers , כדלקמן:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

תחשוב על ייצוג של חישובים JAX כקוד XLA כמו להיות שוות ערך מבחינה פונקציונאלית tf.GraphDef עבור חישובים לידי ביטוי TensorFlow. זהו נייד הפעלה במגוון סביבות התומכות XLA, בדיוק כמו tf.GraphDef ניתן לבצע על כל ריצה TensorFlow.

TFF מספק ערימת זמן ריצה המבוססת על המהדר XLA כ-backend. אתה יכול להפעיל אותו באופן הבא:

tff.backends.xla.set_local_python_execution_context()

כעת, אתה יכול לבצע את החישוב שהגדרנו למעלה:

add_numbers(2, 3)
5

קל מספיק. בוא נלך עם המכה ונעשה משהו יותר מסובך, כמו MNIST.

דוגמה לאימון MNIST עם API משומר

כרגיל, אנו מתחילים בהגדרת חבורה של סוגי TFF עבור אצווה של נתונים, ועבור המודל (זכור, TFF הוא מסגרת עם הקלדה חזקה).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

כעת, בואו נגדיר פונקציית אובדן עבור המודל ב-JAX, ניקח את המודל ואצווה בודדת של נתונים כפרמטר:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

כעת, דרך אחת ללכת היא להשתמש ב-API משומר. הנה דוגמה לאופן שבו אתה יכול להשתמש ב-API שלנו כדי ליצור תהליך הדרכה המבוסס על פונקציית ההפסד שהוגדרה זה עתה.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

ניתן להשתמש מעל בדיוק כפי שהיית משתמש לבנות מאמן מנקודת tf.Keras מודל TensorFlow. לדוגמה, כך תוכל ליצור את המודל הראשוני לאימון:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

על מנת לבצע הכשרה בפועל, אנו זקוקים לכמה נתונים. בואו ניצור נתונים אקראיים כדי שיהיה פשוט. מכיוון שהנתונים הם אקראיים, אנו הולכים להעריך על נתוני אימון, שכן אחרת, עם נתוני הערכה אקראיים, יהיה קשה לצפות שהמודל יבצע. כמו כן, עבור הדגמה בקנה מידה קטן זה, לא נדאג לגבי דגימה אקראית של לקוחות (אנו משאירים זאת כתרגיל למשתמש לחקור את סוגי השינויים הללו על ידי ביצוע התבניות ממדריכים אחרים):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

עם זה, אנו יכולים לבצע שלב יחיד של אימון, כדלקמן:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

בואו נעריך את התוצאה של שלב האימון. כדי שיהיה קל, נוכל להעריך את זה בצורה מרוכזת:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

ההפסד הולך ופוחת. גדול! עכשיו, בוא נרוץ את זה על פני מספר סבבים:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

כפי שאתה רואה, השימוש ב-JAX עם TFF אינו שונה כל כך, אם כי ממשקי ה-API הניסיוניים עדיין אינם דומים לפונקציונליות ה-API של TensorFlow.

מתחת למכסת המנוע

אם אתה מעדיף לא להשתמש ב-API המשומר שלנו, אתה יכול ליישם חישובים מותאמים אישית משלך, בדיוק כמו איך ראית את זה נעשה במדריכי האלגוריתמים המותאמים אישית עבור TensorFlow, פרט לכך שתשתמש במנגנון של JAX לירידה בשיפוע. לדוגמה, להלן כיצד ניתן להגדיר חישוב JAX שמעדכן את המודל במיני-אצץ בודד:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

כך תוכל לבדוק שזה עובד:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

אזהרה אחת של עבודה עם JAX הוא שהיא אינה מציעה את המקבילה של tf.data.Dataset . לפיכך, על מנת לחזור על מערכי נתונים, תצטרך להשתמש במבנים ההצהרתיים של TFF עבור פעולות על רצפים, כגון זה המוצג להלן:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

בוא נראה שזה עובד:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

החישוב שמבצע סבב אימון בודד נראה בדיוק כמו זה שאולי ראית במדריכי TensorFlow:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

בוא נראה שזה עובד:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

כפי שאתה רואה, שימוש ב-JAX ב-TFF, בין אם באמצעות ממשקי API משומרים, או באמצעות שימוש ישיר במבני TFF ברמה נמוכה, דומה לשימוש ב-TFF עם TensorFlow. הישאר מעודכן לעדכונים עתידיים, ואם תרצה לראות תמיכה טובה יותר עבור יכולת פעולה הדדית בין מסגרות ML, אל תהסס לשלוח לנו בקשת משיכה!