دعم تجريبي لـ JAX في TFF

عرض على TensorFlow.org تشغيل في Google Colab عرض على جيثب تحميل دفتر

بالإضافة إلى كونه جزءًا من نظام TensorFlow البيئي ، يهدف TFF إلى تمكين التشغيل البيني مع أطر ML الأخرى للواجهة الأمامية والخلفية. في الوقت الحالي ، لا يزال دعم أطر ML الأخرى في مرحلة الحضانة ، وقد تتغير واجهات برمجة التطبيقات والوظائف المدعومة (إلى حد كبير كدالة للطلب من مستخدمي TFF). يصف هذا البرنامج التعليمي كيفية استخدام TFF مع JAX كواجهة أمامية بديلة لـ ML ، ومترجم XLA كخلفية بديلة. تستند الأمثلة الموضحة هنا إلى مكدس JAX / XLA أصلي بالكامل ، من طرف إلى طرف. ستتم مناقشة إمكانية خلط الكود عبر الأطر (على سبيل المثال ، JAX مع TensorFlow) في أحد البرامج التعليمية المستقبلية.

كالعادة ، نرحب بمساهماتك إذا كان دعم JAX / XLA أو القدرة على التعامل مع أطر ML الأخرى أمرًا مهمًا بالنسبة لك ، فالرجاء التفكير في مساعدتنا في تطوير هذه الإمكانات نحو التكافؤ مع بقية TFF.

قبل أن نبدأ

يرجى الرجوع إلى الجزء الرئيسي من وثائق TFF لمعرفة كيفية تكوين بيئتك. بناءً على مكان تشغيل هذا البرنامج التعليمي ، قد ترغب في إلغاء التعليق وتشغيل بعض أو كل التعليمات البرمجية أدناه.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

يفترض هذا البرنامج التعليمي أيضًا أنك قمت بمراجعة دروس TensorFlow الأولية لـ TFF ، وأنك على دراية بمفاهيم TFF الأساسية. إذا لم تكن قد قمت بذلك حتى الآن ، فالرجاء النظر في مراجعة واحدة منها على الأقل.

حسابات JAX

تم تصميم دعم JAX في TFF ليكون متماثلًا مع الطريقة التي يعمل بها TFF مع TensorFlow ، بدءًا من الواردات:

import jax
import numpy as np
import tensorflow_federated as tff

أيضًا ، تمامًا كما هو الحال مع TensorFlow ، فإن الأساس للتعبير عن أي كود TFF هو المنطق الذي يتم تشغيله محليًا. يمكنك التعبير عن هذا المنطق في JAX، كما هو مبين أدناه، وذلك باستخدام @tff.experimental.jax_computation المجمع. انها تتصرف على نحو مماثل لل @tff.tf_computation أن من الآن لديك معتادا. لنبدأ بشيء بسيط ، على سبيل المثال ، عملية حسابية تضيف عددين صحيحين:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

يمكنك استخدام حساب JAX المحدد أعلاه تمامًا كما تستخدم عادةً حساب TFF. على سبيل المثال ، يمكنك التحقق من توقيع النوع الخاص به ، على النحو التالي:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

لاحظ أن كنا np.int32 لتحديد نوع من الحجج. TFF لا يميز بين أنواع نمباي (مثل np.int32 ) ونوع TensorFlow (مثل tf.int32 ). من منظور TFF ، إنها مجرد طرق للإشارة إلى نفس الشيء.

الآن ، تذكر أن TFF ليس لغة Python (وإذا لم يقرع هذا الجرس ، فيرجى مراجعة بعض دروسنا السابقة ، على سبيل المثال ، حول الخوارزميات المخصصة). يمكنك استخدام @tff.experimental.jax_computation المجمع مع أي JAX التعليمات البرمجية التي يمكن تتبع وتسلسل، أي مع التعليمات البرمجية التي تفعل عادة علق مع @jax.jit المتوقع أن يتم تجميعها في XLA (ولكن لا تحتاج إلى فعلا استخدام @jax.jit الشرح لتضمين كود JAX بك في TFF).

في الواقع ، تحت الغطاء ، يقوم TFF على الفور بترجمة حسابات JAX إلى XLA. يمكنك التحقق من ذلك لنفسك عن طريق استخراج يدويا وطباعة رمز XLA تسلسل من add_numbers ، على النحو التالي:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

التفكير في تمثيل الحسابات JAX كرمز XLA باعتبارها المعادل الوظيفي tf.GraphDef للحسابات التي أعرب عنها في TensorFlow. هذا هو محمول وقابل للتنفيذ في مجموعة متنوعة من البيئات التي تدعم XLA، تماما مثل tf.GraphDef يمكن تنفيذها على أي وقت TensorFlow.

يوفر TFF مكدس وقت تشغيل يعتمد على مترجم XLA كواجهة خلفية. يمكنك تفعيله كالتالي:

tff.backends.xla.set_local_python_execution_context()

الآن ، يمكنك تنفيذ الحساب الذي حددناه أعلاه:

add_numbers(2, 3)
5

سهل بما فيه الكفاية. دعنا نتعامل مع الضربة ونفعل شيئًا أكثر تعقيدًا ، مثل MNIST.

مثال على تدريب MNIST باستخدام API المعلب

كالعادة ، نبدأ بتحديد مجموعة من أنواع TFF لمجموعات من البيانات ، وللنموذج (تذكر أن TFF عبارة عن إطار عمل مكتوب بشدة).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

الآن ، دعنا نحدد دالة الخسارة للنموذج في JAX ، مع أخذ النموذج ودفعة واحدة من البيانات كمعامل:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

الآن ، طريقة واحدة للذهاب هي استخدام واجهة برمجة التطبيقات المعلبة. فيما يلي مثال على كيفية استخدام واجهة برمجة التطبيقات الخاصة بنا لإنشاء عملية تدريب بناءً على وظيفة الخسارة التي تم تحديدها للتو.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

يمكنك استخدام ما سبق تماما كما كنت تستخدم بناء المدرب من tf.Keras نموذج في TensorFlow. على سبيل المثال ، إليك كيفية إنشاء النموذج الأولي للتدريب:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

من أجل أداء التدريب الفعلي ، نحتاج إلى بعض البيانات. لنجعل البيانات العشوائية بسيطة. نظرًا لأن البيانات عشوائية ، سنقوم بتقييم بيانات التدريب ، لأنه بخلاف ذلك ، مع بيانات التقييم العشوائية ، سيكون من الصعب توقع أداء النموذج. أيضًا ، بالنسبة لهذا العرض التوضيحي صغير الحجم ، لن نقلق بشأن أخذ عينات عشوائية من العملاء (نتركها كتمرين للمستخدم لاستكشاف هذه الأنواع من التغييرات باتباع القوالب من البرامج التعليمية الأخرى):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

وبهذه الطريقة يمكننا القيام بخطوة تدريب واحدة على النحو التالي:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

دعنا نقيم نتيجة خطوة التدريب. لتسهيل الأمر ، يمكننا تقييمه بطريقة مركزية:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

تتناقص الخسارة. رائعة! الآن ، دعنا نجري هذا على عدة جولات:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

كما ترى ، فإن استخدام JAX مع TFF لا يختلف كثيرًا ، على الرغم من أن واجهات برمجة التطبيقات التجريبية لا تتساوى مع وظائف TensorFlow APIs.

تحت الغطاء

إذا كنت تفضل عدم استخدام API المعلب الخاص بنا ، فيمكنك تنفيذ حساباتك المخصصة ، بنفس الطريقة التي رأيتها بها في دروس الخوارزميات المخصصة لـ TensorFlow ، باستثناء أنك ستستخدم آلية JAX للنسب المتدرج. على سبيل المثال ، فيما يلي كيفية تحديد حساب JAX الذي يقوم بتحديث النموذج في دقيقة واحدة:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

إليك كيف يمكنك اختبار مدى نجاحها:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

التحذير واحدة من العمل مع JAX هو أنه لا يقدم أي ما يعادل tf.data.Dataset . وبالتالي ، من أجل التكرار على مجموعات البيانات ، ستحتاج إلى استخدام العقود التصريحية الخاصة بـ TFF للعمليات على التسلسلات ، مثل تلك الموضحة أدناه:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

دعونا نرى أنه يعمل:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

تبدو العملية الحسابية التي تؤدي جولة واحدة من التدريب تمامًا مثل تلك التي ربما تكون قد شاهدتها في دروس TensorFlow:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

دعونا نرى أنه يعمل:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

كما ترى ، فإن استخدام JAX في TFF ، سواء عبر واجهات برمجة التطبيقات المعلبة ، أو مباشرة باستخدام بنيات TFF منخفضة المستوى ، يشبه استخدام TFF مع TensorFlow. ترقبوا التحديثات المستقبلية ، وإذا كنت ترغب في رؤية دعم أفضل لقابلية التشغيل البيني عبر أطر ML ، فلا تتردد في إرسال طلب سحب إلينا!