TFF'de JAX için deneysel destek

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın GitHub'da görüntüle Not defterini indir

TFF, TensorFlow ekosisteminin bir parçası olmanın yanı sıra, diğer ön uç ve arka uç ML çerçeveleriyle birlikte çalışabilirliği sağlamayı amaçlamaktadır. Şu anda, diğer ML çerçeveleri için destek hala kuluçka aşamasındadır ve API'ler ve desteklenen işlevler değişebilir (büyük ölçüde TFF kullanıcılarından gelen talebin bir işlevi olarak). Bu öğretici, alternatif bir ML ön ucu olarak JAX ile TFF'nin ve alternatif bir arka uç olarak XLA derleyicisinin nasıl kullanılacağını açıklar. Burada gösterilen örnekler, uçtan uca tamamen yerel bir JAX/XLA yığınına dayanmaktadır. Çerçeveler arasında kodu karıştırma olasılığı (örneğin, JAX ile TensorFlow) gelecekteki eğitimlerden birinde tartışılacaktır.

Her zaman olduğu gibi katkılarınızı bekliyoruz. JAX/XLA desteği veya diğer ML çerçeveleriyle birlikte çalışma yeteneği sizin için önemliyse, lütfen bu yetenekleri TFF'nin geri kalanıyla eşit olacak şekilde geliştirmemize yardımcı olmayı düşünün.

Başlamadan Önce

Lütfen ortamınızı nasıl yapılandıracağınız konusunda TFF belgelerinin ana gövdesine bakın. Bu öğreticiyi nerede çalıştırdığınıza bağlı olarak, açıklamayı kaldırmak ve aşağıdaki kodun bir kısmını veya tamamını çalıştırmak isteyebilirsiniz.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

Bu eğitim ayrıca TFF'nin birincil TensorFlow eğitimlerini incelediğinizi ve temel TFF kavramlarına aşina olduğunuzu varsayar. Henüz yapmadıysanız, lütfen bunlardan en az birini gözden geçirmeyi düşünün.

JAX hesaplamaları

TFF'de JAX desteği, içe aktarmalardan başlayarak TFF'nin TensorFlow ile birlikte çalışma biçimiyle simetrik olacak şekilde tasarlanmıştır:

import jax
import numpy as np
import tensorflow_federated as tff

Ayrıca, tıpkı TensorFlow'da olduğu gibi, herhangi bir TFF kodunu ifade etmenin temeli, yerel olarak çalışan mantıktır. Kullanarak, aşağıda gösterildiği gibi, JAX 'bu mantığı ifade edebiliriz @tff.experimental.jax_computation sargısı. Bu benzer şekilde davranır @tff.tf_computation artık senin aşina olduğu. Basit bir şeyle başlayalım, örneğin iki tamsayı ekleyen bir hesaplama:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

Yukarıda tanımlanan JAX hesaplamasını tıpkı normalde TFF hesaplamasını kullandığınız gibi kullanabilirsiniz. Örneğin, tür imzasını aşağıdaki gibi kontrol edebilirsiniz:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

Kullandığımız Not np.int32 bağımsız değişken türünü tanımlamak için kullanılır. TFF (örneğin Numpy türleri arasında ayrım yapmaz np.int32 ) ve (örneğin, TensorFlow tip tf.int32 ). TFF'nin bakış açısından, bunlar sadece aynı şeye atıfta bulunmanın yolları.

Şimdi, TFF'nin Python olmadığını unutmayın (ve bu bir zil çalmazsa, lütfen daha önceki öğreticilerimizden bazılarını inceleyin, örneğin özel algoritmalar hakkında). Sen kullanabilirsiniz @tff.experimental.jax_computation koduyla, yani takip ve seri hale getirilebilir sarıcı ile herhangi jax kodu olduğunu yapacağınız ile normalde Annotatesekmesindeki @jax.jit (ama gerekmez XLA içine derlenmiş olması bekleniyor aslında kullanmak @jax.jit ) TFF ek açıklama için embed sizin JAX kodu.

Gerçekten de, kaputun altında TFF, JAX hesaplamalarını anında XLA'ya derler. El ile çıkarma ve gelen tefrika XLA kodu yazdırarak kendiniz için bunu kontrol edebilirsiniz add_numbers aşağıdaki gibi:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

İşlevsel eşdeğeri olarak XLA kodu olarak JAX hesaplamalardan temsil düşünün tf.GraphDef TensorFlow ifade hesaplamalar için. Sadece gibi XLA destekleyen ortamlar çeşitli taşınabilir ve çalıştırılabilir tf.GraphDef herhangi TensorFlow çalışma zamanı üzerinde çalıştırılabilir.

TFF, arka uç olarak XLA derleyicisine dayalı bir çalışma zamanı yığını sağlar. Aşağıdaki şekilde etkinleştirebilirsiniz:

tff.backends.xla.set_local_python_execution_context()

Şimdi yukarıda tanımladığımız hesaplamayı yapabilirsiniz:

add_numbers(2, 3)
5

Yeterince kolay. Darbe ile gidelim ve MNIST gibi daha karmaşık bir şey yapalım.

Hazır API ile MNIST eğitimi örneği

Her zamanki gibi, veri yığınları ve model için bir grup TFF türü tanımlayarak başlıyoruz (unutmayın, TFF kesin olarak yazılmış bir çerçevedir).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

Şimdi, modeli ve tek bir veri grubunu parametre olarak alarak JAX'ta model için bir kayıp fonksiyonu tanımlayalım:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

Şimdi, gitmenin bir yolu, bir hazır API kullanmaktır. Az önce tanımlanan kayıp işlevine dayalı bir eğitim süreci oluşturmak için API'mizi nasıl kullanabileceğinize dair bir örnek.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

Bir gelen bir eğitmen yapı kullandığınız gibi sadece yukarıdaki kullanabilirsiniz tf.Keras TensorFlow modele. Örneğin, eğitim için başlangıç ​​modelini şu şekilde oluşturabilirsiniz:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

Gerçek eğitimi gerçekleştirmek için bazı verilere ihtiyacımız var. Basit tutmak için rastgele veriler yapalım. Veriler rastgele olduğu için eğitim verileri üzerinde değerlendirme yapacağız, aksi takdirde rastgele değerlendirme verileriyle modelin çalışmasını beklemek zor olurdu. Ayrıca, bu küçük ölçekli demo için, istemcileri rastgele örnekleme konusunda endişelenmeyeceğiz (diğer öğreticilerdeki şablonları izleyerek bu tür değişiklikleri keşfetmeyi kullanıcıya bir alıştırma olarak bırakıyoruz):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

Bununla, aşağıdaki gibi tek bir eğitim adımını gerçekleştirebiliriz:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

Eğitim adımının sonucunu değerlendirelim. Kolay tutmak için, merkezi bir şekilde değerlendirebiliriz:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

Kayıp azalıyor. Harika! Şimdi bunu birden fazla turda çalıştıralım:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

Gördüğünüz gibi, JAX'ı TFF ile kullanmak o kadar da farklı değil, ancak deneysel API'ler işlevsellik açısından henüz TensorFlow API'leri ile aynı seviyede değil.

kaputun altında

Hazır API'mizi kullanmamayı tercih ederseniz, kendi özel hesaplamalarınızı, TensorFlow için özel algoritma eğitimlerinde gördüğünüz gibi, JAX'ın gradyan iniş mekanizmasını kullanmanız dışında uygulayabilirsiniz. Örneğin, modeli tek bir mini partide güncelleyen bir JAX hesaplamasını nasıl tanımlayabileceğiniz aşağıda açıklanmıştır:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

Çalışıp çalışmadığını şu şekilde test edebilirsiniz:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

JAX ile çalışmanın bir ihtar o eşdeğer sunmuyor olmasıdır tf.data.Dataset . Bu nedenle, veri kümeleri üzerinde yineleme yapmak için, aşağıda gösterilen gibi, dizilerdeki işlemler için TFF'nin bildirimsel sözleşmelerini kullanmanız gerekecektir:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

İşe yaradığını görelim:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

Tek bir eğitim turu gerçekleştiren hesaplama, TensorFlow eğitimlerinde görmüş olabileceğiniz hesaplamaya benziyor:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

İşe yaradığını görelim:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

Gördüğünüz gibi, ister hazır API'ler aracılığıyla ister doğrudan düşük seviyeli TFF yapılarını kullanarak olsun, TFF'de JAX kullanmak, TFF'yi TensorFlow ile kullanmaya benzer. Gelecekteki güncellemeler için bizi izlemeye devam edin ve ML çerçeveleri arasında birlikte çalışabilirlik için daha iyi destek görmek istiyorsanız, bize bir çekme isteği göndermekten çekinmeyin!