Google I / O'daki önemli notları, ürün oturumlarını, atölyeleri ve daha fazlasını izleyin Oynatma listesine bakın

TFF'de JAX için deneysel destek

TensorFlow.org'da görüntüleyin Google Colab'de çalıştırın GitHub'da görüntüle Defteri indirin

TFF, TensorFlow ekosisteminin bir parçası olmanın yanı sıra, diğer ön uç ve arka uç ML çerçeveleriyle birlikte çalışabilirliği sağlamayı amaçlamaktadır. Şu anda, diğer makine öğrenimi çerçevelerine yönelik destek hala kuluçka aşamasındadır ve desteklenen API'ler ve işlevler değişebilir (büyük ölçüde TFF kullanıcılarından gelen talebin bir işlevi olarak). Bu öğreticide, alternatif bir ML ön ucu olarak JAX ile TFF'nin ve alternatif bir arka uç olarak XLA derleyicisinin nasıl kullanılacağı açıklanmaktadır. Burada gösterilen örnekler, uçtan uca tamamen yerel bir JAX / XLA yığınına dayanmaktadır. Çerçeveler arasında kod karıştırma olasılığı (örneğin, TensorFlow ile JAX) gelecekteki eğitimlerden birinde tartışılacaktır.

Her zaman olduğu gibi katkılarınızı memnuniyetle karşılıyoruz. JAX / XLA desteği veya diğer makine öğrenimi çerçeveleriyle birlikte çalışma yeteneği sizin için önemliyse, lütfen bu yetenekleri TFF'nin geri kalanıyla eşit olacak şekilde geliştirmemize yardımcı olmayı düşünün.

Başlamadan Önce

Ortamınızı nasıl yapılandıracağınızla ilgili olarak lütfen TFF belgelerinin ana gövdesine bakın. Bu öğreticiyi nerede çalıştırdığınıza bağlı olarak, aşağıdaki kodun bir kısmını veya tamamını çalıştırmak isteyebilirsiniz.

# !pip install --quiet --upgrade tensorflow-federated-nightly
# !pip install --quiet --upgrade nest-asyncio
# import nest_asyncio
# nest_asyncio.apply()

Bu öğretici ayrıca, TFF'nin birincil TensorFlow eğitimlerini gözden geçirdiğinizi ve temel TFF kavramlarına aşina olduğunuzu varsayar. Henüz bunu yapmadıysanız, lütfen bunlardan en az birini gözden geçirmeyi düşünün.

JAX hesaplamaları

TFF'de JAX desteği, içe aktarmalardan başlayarak TFF'nin TensorFlow ile birlikte çalışma biçimiyle simetrik olacak şekilde tasarlanmıştır:

import jax
import numpy as np
import tensorflow_federated as tff

Ayrıca, TensorFlow'da olduğu gibi, herhangi bir TFF kodunu ifade etmenin temeli, yerel olarak çalışan mantıktır. Bu mantığı @tff.experimental.jax_computation aşağıda gösterildiği gibi @tff.experimental.jax_computation sarmalayıcısını kullanarak ifade edebilirsiniz. Şimdiye kadar @tff.tf_computation olduğunuz @tff.tf_computation benzer şekilde davranır. Basit bir şeyle başlayalım, örneğin iki tam sayı ekleyen bir hesaplama:

@tff.experimental.jax_computation(np.int32, np.int32)
def add_numbers(x, y):
  return jax.numpy.add(x, y)

Yukarıda tanımlanan JAX hesaplamasını, normalde bir TFF hesaplaması kullanacağınız gibi kullanabilirsiniz. Örneğin, aşağıdaki gibi tip imzasını kontrol edebilirsiniz:

str(add_numbers.type_signature)
'(<x=int32,y=int32> -> int32)'

Argümanların türünü tanımlamak için np.int32 kullandığımıza np.int32 edin. TFF (örneğin Numpy türleri arasında ayrım yapmaz np.int32 ) ve (örneğin, TensorFlow tip tf.int32 ). TFF'nin bakış açısından, bunlar sadece aynı şeyi ifade etmenin yolları.

Şimdi, TFF'nin Python olmadığını unutmayın (ve eğer bu bir zil çalmazsa, lütfen daha önceki eğitimlerimizden bazılarını inceleyin, örneğin, özel algoritmalarla ilgili). @tff.experimental.jax_computation sarmalayıcısını, izlenebilen ve serileştirilebilen herhangi bir JAX koduyla kullanabilirsiniz, yani normalde @jax.jit ile açıklama @jax.jit ve @jax.jit derlenmesi beklenen @jax.jit (ancak yapmanız gerekmez) JAX kodunuzu TFF'ye yerleştirmek için @jax.jit ek açıklamasını kullanın).

Aslında, başlık altında, TFF, JAX hesaplamalarını anında XLA'ya derler. Aşağıdaki gibi, seri hale getirilmiş XLA kodunu add_numbers manuel olarak çıkararak ve yazdırarak bunu kendiniz kontrol edebilirsiniz:

comp_pb = tff.framework.serialize_computation(add_numbers)
comp_pb.WhichOneof('computation')
'xla'
xla_code = jax.lib.xla_client.XlaComputation(comp_pb.xla.hlo_module.value)
print(xla_code.as_hlo_text())
HloModule xla_computation_add_numbers.7

ENTRY xla_computation_add_numbers.7 {
  constant.4 = pred[] constant(false)
  parameter.1 = (s32[], s32[]) parameter(0)
  get-tuple-element.2 = s32[] get-tuple-element(parameter.1), index=0
  get-tuple-element.3 = s32[] get-tuple-element(parameter.1), index=1
  add.5 = s32[] add(get-tuple-element.2, get-tuple-element.3)
  ROOT tuple.6 = (s32[]) tuple(add.5)
}

JAX hesaplamalarının XLA kodu olarak tf.GraphDef , tf.GraphDef ifade edilen hesaplamalar için tf.GraphDef işlevsel eşdeğeri olarak düşünün. Taşınabilir ve tf.GraphDef destekleyen çeşitli ortamlarda çalıştırılabilir, tıpkı tf.GraphDef gibi herhangi bir TensorFlow çalışma zamanında yürütülebilir.

TFF, arka uç olarak XLA derleyicisine dayalı bir çalışma zamanı yığını sağlar. Aşağıdaki şekilde etkinleştirebilirsiniz:

tff.backends.xla.set_local_execution_context()

Şimdi, yukarıda tanımladığımız hesaplamayı gerçekleştirebilirsiniz:

add_numbers(2, 3)
5

Yeterince kolay. Darbe ile devam edelim ve MNIST gibi daha karmaşık bir şey yapalım.

Hazır API ile MNIST eğitimi örneği

Her zaman olduğu gibi, veri yığınları ve model için bir dizi TFF türü tanımlayarak başlıyoruz (unutmayın, TFF güçlü bir şekilde yazılmış bir çerçevedir).

import collections

BATCH_TYPE = collections.OrderedDict([
    ('pixels', tff.TensorType(np.float32, (50, 784))),
    ('labels', tff.TensorType(np.int32, (50,)))
])

MODEL_TYPE = collections.OrderedDict([
    ('weights', tff.TensorType(np.float32, (784, 10))),
    ('bias', tff.TensorType(np.float32, (10,)))
])

Şimdi, modeli ve tek bir veri kümesini parametre olarak alarak JAX'daki model için bir kayıp işlevi tanımlayalım:

def loss(model, batch):
  y = jax.nn.softmax(
      jax.numpy.add(
          jax.numpy.matmul(batch['pixels'], model['weights']), model['bias']))
  targets = jax.nn.one_hot(jax.numpy.reshape(batch['labels'], -1), 10)
  return -jax.numpy.mean(jax.numpy.sum(targets * jax.numpy.log(y), axis=1))

Şimdi, gitmenin bir yolu, hazır bir API kullanmaktır. Az önce tanımlanan kayıp işlevine dayalı bir eğitim süreci oluşturmak için API'mizi nasıl kullanabileceğinizin bir örneğini burada bulabilirsiniz.

STEP_SIZE = 0.001

trainer = tff.experimental.learning.build_jax_federated_averaging_process(
    BATCH_TYPE, MODEL_TYPE, loss, STEP_SIZE)

Yukarıdakileri, TensorFlow'da bir tf.Keras modelinden bir eğitmen yapısını kullandığınız gibi kullanabilirsiniz. Örneğin, eğitim için başlangıç ​​modelini şu şekilde oluşturabilirsiniz:

initial_model = trainer.initialize()
initial_model
Struct([('weights', array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)), ('bias', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32))])

Gerçek eğitimi gerçekleştirmek için bazı verilere ihtiyacımız var. Basit tutmak için rastgele veriler yapalım. Veriler rastgele olduğundan, eğitim verilerini değerlendireceğiz, çünkü aksi takdirde rastgele değerlendirme verileriyle modelin çalışmasını beklemek zor olacaktır. Ayrıca, bu küçük ölçekli demo için, istemcileri rastgele örnekleme konusunda endişelenmeyeceğiz (bunu, diğer eğiticilerdeki şablonları izleyerek bu tür değişiklikleri keşfetmesi için kullanıcıya bir alıştırma olarak bırakıyoruz):

def random_batch():
  pixels = np.random.uniform(
      low=0.0, high=1.0, size=(50, 784)).astype(np.float32)
  labels = np.random.randint(low=0, high=9, size=(50,), dtype=np.int32)
  return collections.OrderedDict([('pixels', pixels), ('labels', labels)])

NUM_CLIENTS = 2
NUM_BATCHES = 10

train_data = [
    [random_batch() for _ in range(NUM_BATCHES)]
    for _ in range(NUM_CLIENTS)]

Bununla, aşağıdaki gibi tek bir eğitim aşaması gerçekleştirebiliriz:

trained_model = trainer.next(initial_model, train_data)
trained_model
Struct([('weights', array([[ 1.04456245e-04, -1.53498477e-05,  2.54597180e-05, ...,
         5.61640409e-05, -5.32875274e-05, -4.62881755e-04],
       [ 7.30908650e-05,  4.67643113e-05,  2.03352147e-06, ...,
         3.77510623e-05,  3.52839161e-05, -4.59865667e-04],
       [ 8.14835730e-05,  3.03147244e-05, -1.89143739e-05, ...,
         1.12527239e-04,  4.09212225e-06, -4.59960109e-04],
       ...,
       [ 9.23552434e-05,  2.44302555e-06, -2.20817346e-05, ...,
         7.61375341e-05,  1.76906979e-05, -4.43495519e-04],
       [ 1.17451040e-04,  2.47748958e-05,  1.04728279e-05, ...,
         5.26388249e-07,  7.21131510e-05, -4.67137404e-04],
       [ 3.75041491e-05,  6.58061981e-05,  1.14522081e-05, ...,
         2.52584141e-05,  3.55410739e-05, -4.30888613e-04]], dtype=float32)), ('bias', array([ 1.5096272e-04,  2.6502126e-05, -1.9462314e-05,  8.1269856e-05,
        2.1832302e-04,  1.6636557e-04,  1.2815947e-04,  9.0642272e-05,
        7.7109929e-05, -9.1987278e-04], dtype=float32))])

Eğitim adımının sonucunu değerlendirelim. İşi kolaylaştırmak için merkezi bir şekilde değerlendirebiliriz:

import itertools
eval_data = list(itertools.chain.from_iterable(train_data))

def average_loss(model, data):
  return np.mean([loss(model, batch) for batch in data])

print (average_loss(initial_model, eval_data))
print (average_loss(trained_model, eval_data))
2.3025854
2.282762

Kayıp azalıyor. Harika! Şimdi, bunu birden çok turda çalıştıralım:

NUM_ROUNDS = 20
for _ in range(NUM_ROUNDS):
  trained_model = trainer.next(trained_model, train_data)
  print(average_loss(trained_model, eval_data))
2.2685437
2.257856
2.2495182
2.2428129
2.2372835
2.2326245
2.2286277
2.2251441
2.2220676
2.219318
2.2168345
2.2145717
2.2124937
2.2105706
2.2087805
2.2071042
2.2055268
2.2040353
2.2026198
2.2012706

Gördüğünüz gibi, deneysel API'ler işlevsellik açısından TensorFlow API'leri ile aynı seviyede olmasa da, JAX'i TFF ile kullanmak o kadar da farklı değil.

Kaputun altında

Hazır API'mizi kullanmamayı tercih ederseniz, kendi özel hesaplamalarınızı, TensorFlow için özel algoritma eğitimlerinde gördüğünüz gibi, JAX'ın gradyan iniş mekanizmasını kullanmanın dışında uygulayabilirsiniz. Örneğin, modeli tek bir mini partide güncelleyen bir JAX hesaplamasını nasıl tanımlayabileceğiniz aşağıda açıklanmıştır:

@tff.experimental.jax_computation(MODEL_TYPE, BATCH_TYPE)
def train_on_one_batch(model, batch):
  grads = jax.api.grad(loss)(model, batch)
  return collections.OrderedDict([
      (k, model[k] - STEP_SIZE * grads[k]) for k in ['weights', 'bias']
  ])

Çalıştığını şu şekilde test edebilirsiniz:

sample_batch = random_batch()
trained_model = train_on_one_batch(initial_model, sample_batch)
print(average_loss(initial_model, [sample_batch]))
print(average_loss(trained_model, [sample_batch]))
2.3025854
2.2977567

JAX ile çalışmanın bir uyarısı,tf.data.Dataset eşdeğerinitf.data.Dataset . Bu nedenle, veri kümeleri üzerinde yineleme yapmak için, aşağıda gösterilenler gibi diziler üzerindeki işlemler için TFF'nin bildirim temelli sözleşmelerini kullanmanız gerekecektir:

@tff.federated_computation(MODEL_TYPE, tff.SequenceType(BATCH_TYPE))
def train_on_one_client(model, batches):
  return tff.sequence_reduce(batches, model, train_on_one_batch)

Bakalım işe yarıyor:

sample_dataset = [random_batch() for _ in range(100)]
trained_model = train_on_one_client(initial_model, sample_dataset)
print(average_loss(initial_model, sample_dataset))
print(average_loss(trained_model, sample_dataset))
2.3025854
2.2284968

Tek bir eğitim turu gerçekleştiren hesaplama, TensorFlow eğitimlerinde görmüş olabileceğinize benziyor:

@tff.federated_computation(
    tff.FederatedType(MODEL_TYPE, tff.SERVER),
    tff.FederatedType(tff.SequenceType(BATCH_TYPE), tff.CLIENTS))
def train_one_round(model, federated_data):
  locally_trained_models = tff.federated_map(
      train_on_one_client,
      collections.OrderedDict([
          ('model', tff.federated_broadcast(model)),
          ('batches', federated_data)]))
  return tff.federated_mean(locally_trained_models)

Bakalım işe yarıyor:

trained_model = train_one_round(initial_model, train_data)
print(average_loss(initial_model, eval_data))
print(average_loss(trained_model, eval_data))
2.3025854
2.282762

Gördüğünüz gibi, ister hazır API'ler aracılığıyla ister doğrudan düşük seviyeli TFF yapılarını kullanarak, TFF'de JAX kullanmak, TensorFlow ile TFF kullanmaya benzer. Gelecekteki güncellemeler için bizi izlemeye devam edin ve makine öğrenimi çerçevelerinde birlikte çalışabilirlik için daha iyi destek görmek istiyorsanız bize bir istek göndermekten çekinmeyin!