پشتیبانی آزمایشی از JAX در TFF

مشاهده در TensorFlow.org در Google Colab اجرا کنید مشاهده در GitHub دانلود دفترچه یادداشت

TFF علاوه بر اینکه بخشی از اکوسیستم TensorFlow است ، قابلیت همکاری با سایر چارچوب های ML رو به جلو و پشتیبان را نیز فعال می کند. در حال حاضر ، پشتیبانی از سایر چارچوب های ML هنوز در مرحله جوجه کشی است و API ها و عملکرد پشتیبانی شده ممکن است تغییر کند (عمدتا به عنوان تابعی از تقاضای کاربران TFF). این آموزش نحوه استفاده از TFF با JAX را به عنوان یک جایگزین ML پیش نماینده و کامپایلر XLA به عنوان پشتیبان جایگزین توضیح می دهد. مثالهایی که در اینجا نشان داده شده است بر اساس یک پشته کاملاً بومی JAX/XLA ، از انتها به انتها است. امکان ترکیب کد در چارچوب ها (به عنوان مثال ، JAX با TensorFlow) در یکی از آموزش های آینده مورد بحث قرار می گیرد.

مثل همیشه ، ما از مشارکت شما استقبال می کنیم. اگر پشتیبانی از JAX/XLA یا توانایی همکاری با سایر چارچوب های ML برای شما مهم است ، لطفاً به ما کمک کنید تا این قابلیت ها را در جهت برابری باقیمانده TFF توسعه دهیم.

قبل از اینکه شروع کنیم

لطفاً برای پیکربندی محیط خود به بخش اصلی اسناد TFF مراجعه کنید. بسته به جایی که این آموزش را اجرا می کنید ، ممکن است بخواهید بخشی یا تمام کد زیر را کامنت نگذارید و اجرا کنید.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

این آموزش همچنین فرض می کند که شما آموزشهای اولیه TensorFlow TFF را مرور کرده اید و با مفاهیم اصلی TFF آشنا هستید. اگر هنوز این کار را نکرده اید ، لطفاً حداقل یکی از آنها را مرور کنید.

محاسبات JAX

پشتیبانی از JAX در TFF طوری طراحی شده است که متقارن باشد و نحوه همکاری TFF با TensorFlow ، با واردات شروع می شود:

import jax
import numpy as np
import tensorflow_federated as tff

همچنین ، درست مانند TensorFlow ، اساس بیان هر کد TFF منطقی است که به صورت محلی اجرا می شود. شما می توانید این منطق در JAX بیان، به عنوان زیر نشان داده شده، با استفاده از @tff.experimental.jax_computation لفاف بسته بندی. رفتاری شبیه به @tff.tf_computation که در حال حاضر خود را با آن آشنا هستند. بیایید با یک چیز ساده شروع کنیم ، به عنوان مثال ، محاسبه ای که دو عدد صحیح را اضافه می کند:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

شما می توانید از محاسبات JAX که در بالا تعریف شده استفاده کنید ، درست مانند یک روش محاسبه TFF. به عنوان مثال ، می توانید امضای نوع آن را به شرح زیر بررسی کنید:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

توجه داشته باشید که ما استفاده می شود np.int32 برای تعریف نوع استدلال است. TFF کند بین انواع نامپای (مانند تشخیص نیست np.int32 ) و نوع TensorFlow (مانند tf.int32 ). از دیدگاه TFF ، آنها فقط راه هایی برای اشاره به یک چیز هستند.

اکنون ، به یاد داشته باشید که TFF پایتون نیست (و اگر این زنگ نمی زند ، لطفاً برخی از آموزش های قبلی ما را بررسی کنید ، به عنوان مثال ، در مورد الگوریتم های سفارشی). شما می توانید با استفاده از @tff.experimental.jax_computation لفاف بسته بندی با هر JAX کد است که می توان و سریال، به عنوان مثال، با کد که شما به طور معمول حاشیه نویسی با @jax.jit انتظار می رود به XLA وارد شده باشد (اما شما لازم نیست که در واقع استفاده از @jax.jit حاشیه نویسی به جاسازی کد JAX خود را در TFF).

در واقع ، زیر پوشش ، TFF بلافاصله محاسبات JAX را به XLA کامپایل می کند. شما می توانید این را برای خودتان به صورت دستی استخراج و چاپ کد XLA مرتب از بررسی add_numbers شرح زیر است،:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

فکر می کنم از نمایندگی از محاسبات JAX به عنوان کد XLA به عنوان معادل کارکردی tf.GraphDef برای محاسبات بیان شده در TensorFlow. این قابل حمل و قابل اجرا در محیط های مختلف است که حمایت از XLA، درست مثل این است tf.GraphDef می توان در هر زمان اجرا TensorFlow اجرا می شود.

TFF یک پشته زمان اجرا را بر اساس کامپایلر XLA به عنوان پشتیبان ارائه می دهد. می توانید آن را به صورت زیر فعال کنید:

tff.backends.xla.set_local_python_execution_context()

اکنون می توانید محاسبه ای را که در بالا تعریف کردیم اجرا کنید:

add_numbers(2, 3)
5

به حد کافی ساده است. بیایید با ضربه وارد عمل شویم و کاری پیچیده تر انجام دهیم ، مانند MNIST.

نمونه ای از آموزش MNIST با API کنسرو شده

طبق معمول ، ما دسته ای از انواع TFF را برای دسته های داده و برای مدل تعریف می کنیم (به یاد داشته باشید ، TFF یک چارچوب قوی تایپ شده است).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

حال ، بیایید یک تابع ضرر برای مدل در JAX تعریف کنیم و مدل و یک دسته از داده ها را به عنوان پارامتر در نظر بگیریم:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

در حال حاضر ، یکی از راه های استفاده از API کنسرو شده است. در اینجا یک مثال وجود دارد که چگونه می توانید از API ما برای ایجاد یک فرآیند آموزشی بر اساس تابع ضرر که به تازگی تعریف شده است ، استفاده کنید.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

شما می توانید در بالا همانطور که شما ساخت از یک مربی استفاده کنید tf.Keras مدل در TensorFlow. به عنوان مثال ، در اینجا نحوه ایجاد مدل اولیه برای آموزش آورده شده است:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

برای انجام آموزش واقعی ، ما به برخی داده ها نیاز داریم. اجازه دهید داده های تصادفی بسازیم تا ساده باشد. از آنجا که داده ها تصادفی هستند ، ما قصد داریم داده های آموزشی را ارزیابی کنیم ، زیرا در غیر این صورت ، با داده های ارزیابی تصادفی ، انتظار عملکرد مدل دشوار است. همچنین ، برای این نسخه ی نمایشی در مقیاس کوچک ، ما نگران نمونه گیری تصادفی از مشتریان نخواهیم بود (ما این کار را به عنوان یک تمرین به کاربر واگذار می کنیم تا با دنبال کردن الگوهای سایر آموزش ها ، این تغییرات را بررسی کند):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

با این کار ، ما می توانیم یک مرحله آموزشی را به شرح زیر انجام دهیم:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

بیایید نتیجه مرحله آموزش را ارزیابی کنیم. برای سهولت کار ، می توانیم آن را به صورت متمرکز ارزیابی کنیم:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

ضرر در حال کاهش است. عالی! حالا ، بیایید این را در چند دور اجرا کنیم:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

همانطور که مشاهده می کنید ، استفاده از JAX با TFF تفاوت چندانی ندارد ، اگرچه API های آزمایشی هنوز از نظر عملکرد با TIsorFlow API ها همسان نیستند.

در زیر کاپوت

اگر ترجیح می دهید از API کنسروی ما استفاده نکنید ، می توانید محاسبات سفارشی خود را پیاده سازی کنید ، بسیار شبیه به آنچه که در آموزش الگوریتم های سفارشی TensorFlow مشاهده کرده اید ، با این تفاوت که از مکانیزم JAX برای شیب نزولی استفاده خواهید کرد. به عنوان مثال ، در زیر نحوه تعیین محاسبه JAX است که مدل را در یک مینی دسته واحد به روز می کند:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

در اینجا نحوه بررسی عملکرد آن آمده است:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

یک نصیحت از کار با JAX است که آن را معادل ارائه نمی tf.data.Dataset . بنابراین ، برای تکرار مجموعه داده ها ، باید از دستورات اعلان TFF برای عملیات در دنباله ها استفاده کنید ، مانند آنچه در زیر نشان داده شده است:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

بیایید ببینیم که کار می کند:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

محاسبه ای که یک دور آموزش را انجام می دهد دقیقاً شبیه آنچه در آموزش های TensorFlow دیده اید است:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

بیایید ببینیم که کار می کند:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

همانطور که مشاهده می کنید ، استفاده از JAX در TFF ، چه از طریق API های کنسرو شده ، چه مستقیم از ساختارهای TFF سطح پایین ، مشابه استفاده از TFF با TensorFlow است. منتظر به روزرسانی های آینده باشید و اگر می خواهید پشتیبانی بهتری از قابلیت همکاری در چارچوب های ML مشاهده کنید ، با خیال راحت یک درخواست کشش برای ما ارسال کنید!