Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

Pembelajaran Federasi untuk Klasifikasi Gambar

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub

Dalam tutorial ini, kami menggunakan contoh pelatihan MNIST klasik untuk memperkenalkan lapisan API Federated Learning (FL) dari TFF, tff.learning - satu set antarmuka tingkat tinggi yang dapat digunakan untuk melakukan jenis umum tugas pembelajaran federasi, seperti pelatihan federasi, terhadap model yang disediakan pengguna yang diterapkan di TensorFlow.

Tutorial ini, dan Federated Learning API, ditujukan terutama bagi pengguna yang ingin menyambungkan model TensorFlow mereka sendiri ke TFF, memperlakukan yang terakhir sebagian besar sebagai kotak hitam. Untuk pemahaman yang lebih mendalam tentang TFF dan cara menerapkan algoritme pembelajaran federasi Anda sendiri, lihat tutorial di FC Core API - Algoritme Federasi Kustom Bagian 1 dan Bagian 2 .

Untuk informasi lebih lanjut tentang tff.learning , lanjutkan dengan Federated Learning for Text Generation , tutorial yang selain mencakup model berulang, juga mendemonstrasikan pemuatan model Keras berseri yang telah dilatih sebelumnya untuk penyempurnaan dengan pembelajaran gabungan yang dikombinasikan dengan evaluasi menggunakan Keras.

Sebelum kita mulai

Sebelum kita mulai, jalankan perintah berikut untuk memastikan bahwa lingkungan Anda sudah diatur dengan benar. Jika Anda tidak melihat salam, lihat panduan Instalasi untuk instruksi.


!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio

import nest_asyncio
nest_asyncio.apply()

%load_ext tensorboard
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

Mempersiapkan data masukan

Mari kita mulai dengan data. Pembelajaran federasi membutuhkan kumpulan data federasi, yaitu kumpulan data dari banyak pengguna. Data gabungan biasanya non- iid , yang menimbulkan serangkaian tantangan unik.

Untuk memfasilitasi eksperimen, kami mengunggulkan repositori TFF dengan beberapa set data, termasuk versi federasi dari MNIST yang berisi versi set data NIST asli yang telah diproses ulang menggunakan Leaf sehingga datanya dikunci oleh penulis asli digitnya. Karena setiap penulis memiliki gaya unik, kumpulan data ini menunjukkan jenis perilaku non-iid yang diharapkan dari kumpulan data federasi.

Begini cara kita memuatnya.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Kumpulan data yang dikembalikan oleh load_data() adalah turunan dari tff.simulation.ClientData , antarmuka yang memungkinkan Anda menghitung kumpulan pengguna, membuat tf.data.Dataset yang mewakili data pengguna tertentu, dan untuk membuat kueri struktur elemen individu. Berikut cara menggunakan antarmuka ini untuk menjelajahi konten kumpulan data. Ingatlah bahwa meskipun antarmuka ini memungkinkan Anda untuk melakukan iterasi pada id klien, ini hanyalah fitur dari data simulasi. Seperti yang akan Anda lihat sebentar lagi, identitas klien tidak digunakan oleh framework pembelajaran federasi - satu-satunya tujuan mereka adalah memungkinkan Anda memilih subset data untuk simulasi.

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None)), ('label', TensorSpec(shape=(), dtype=tf.int32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

Menjelajahi heterogenitas dalam data federasi

Data gabungan biasanya non- iid , pengguna biasanya memiliki distribusi data yang berbeda tergantung pada pola penggunaan. Beberapa klien mungkin memiliki lebih sedikit contoh pelatihan di perangkat, yang mengalami kekurangan data secara lokal, sementara beberapa klien akan memiliki lebih dari cukup contoh pelatihan. Mari jelajahi konsep heterogenitas data yang khas dari sistem federasi dengan data EMNIST yang kami miliki. Penting untuk diperhatikan bahwa analisis mendalam dari data klien ini hanya tersedia bagi kami karena ini adalah lingkungan simulasi tempat semua data tersedia untuk kami secara lokal. Dalam lingkungan federasi produksi nyata, Anda tidak akan dapat memeriksa data klien tunggal.

Pertama, mari kita ambil sampel dari satu data klien untuk merasakan contoh pada satu perangkat yang disimulasikan. Karena kumpulan data yang kami gunakan telah dikunci oleh penulis unik, data dari satu klien mewakili tulisan tangan satu orang untuk sampel dari angka 0 hingga 9, yang mensimulasikan "pola penggunaan" unik dari satu pengguna.

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

Sekarang mari kita visualisasikan jumlah contoh pada setiap klien untuk setiap label digit MNIST. Di lingkungan federasi, jumlah contoh pada setiap klien dapat sedikit berbeda, bergantung pada perilaku pengguna.

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

Sekarang mari kita visualisasikan gambar rata-rata per klien untuk setiap label MNIST. Kode ini akan menghasilkan mean dari setiap nilai piksel untuk semua contoh pengguna untuk satu label. Kita akan melihat bahwa gambar rata-rata satu klien untuk satu digit akan terlihat berbeda dari gambar rata-rata klien lain untuk digit yang sama, karena gaya tulisan tangan masing-masing orang yang unik. Kita dapat merenungkan bagaimana setiap putaran pelatihan lokal akan mendorong model ke arah yang berbeda pada setiap klien, karena kita belajar dari data unik pengguna itu sendiri di babak lokal tersebut. Nanti di tutorial kita akan melihat bagaimana kita dapat mengambil setiap pembaruan ke model dari semua klien dan menggabungkannya menjadi model global baru kita, yang telah belajar dari setiap data unik klien kita sendiri.

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

Data pengguna dapat berisik dan diberi label yang tidak dapat diandalkan. Sebagai contoh, dengan melihat data Klien # 2 di atas, kita dapat melihat bahwa untuk label 2, mungkin saja ada beberapa contoh yang salah label yang menciptakan gambar yang lebih berisik.

Memproses data masukan

Karena datanya sudah menjadi tf.data.Dataset , preprocessing dapat dilakukan menggunakan transformasi Dataset. Di sini, kita meratakan 28x28 gambar ke 784 array -element, mengocok contoh individu, mengatur mereka ke dalam batch, dan mengganti nama fitur dari pixels dan label untuk x dan y untuk digunakan dengan Keras. Kami juga melakukan repeat atas kumpulan data untuk menjalankan beberapa periode.

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER= 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

Mari kita verifikasi bahwa ini berhasil.

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [2],
       [3],
       [6],
       [0],
       [1],
       [4],
       [1],
       [0],
       [6],
       [9],
       [9],
       [3],
       [6],
       [1],
       [4],
       [8],
       [0],
       [2]], dtype=int32))])

Kami memiliki hampir semua blok penyusun untuk membangun kumpulan data federasi.

Salah satu cara untuk memberi makan data federasi ke TFF dalam simulasi adalah sebagai daftar Python, dengan setiap elemen dari daftar tersebut menyimpan data pengguna individu, baik sebagai daftar atau sebagai tf.data.Dataset . Karena kita sudah memiliki antarmuka yang menyediakan yang terakhir, mari kita gunakan.

Berikut adalah fungsi pembantu sederhana yang akan membuat daftar kumpulan data dari kumpulan pengguna tertentu sebagai masukan untuk putaran pelatihan atau evaluasi.

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

Sekarang, bagaimana kita memilih klien?

Dalam skenario pelatihan federasi yang khas, kita menghadapi kemungkinan populasi perangkat pengguna yang sangat besar, hanya sebagian kecil yang mungkin tersedia untuk pelatihan pada titik waktu tertentu. Hal ini terjadi, misalnya, saat perangkat klien adalah ponsel yang berpartisipasi dalam pelatihan hanya saat dicolokkan ke sumber daya, di luar jaringan terukur, dan menganggur.

Tentu saja, kami berada dalam lingkungan simulasi, dan semua data tersedia secara lokal. Biasanya, saat menjalankan simulasi, kami hanya akan mengambil sampel subset acak dari klien untuk dilibatkan dalam setiap putaran pelatihan, umumnya berbeda di setiap putaran.

Meskipun demikian, seperti yang dapat Anda ketahui dengan mempelajari makalah tentang algoritme Federated Averaging , mencapai konvergensi dalam sistem dengan subkumpulan klien yang diambil sampelnya secara acak di setiap putaran dapat memakan waktu cukup lama, dan tidak praktis untuk menjalankan ratusan putaran dalam tutorial interaktif ini.

Yang akan kita lakukan adalah mengambil sampel kumpulan klien satu kali, dan menggunakan kembali kumpulan yang sama di seluruh putaran untuk mempercepat konvergensi (sengaja terlalu pas untuk beberapa data pengguna ini). Kami membiarkannya sebagai latihan bagi pembaca untuk memodifikasi tutorial ini untuk mensimulasikan pengambilan sampel acak - ini cukup mudah dilakukan (setelah Anda melakukannya, perlu diingat bahwa untuk menyatukan model mungkin memerlukan beberapa saat).

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

Membuat model dengan Keras

Jika Anda menggunakan Keras, Anda mungkin sudah memiliki kode yang menyusun model Keras. Berikut adalah contoh model sederhana yang akan mencukupi kebutuhan kita.

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

Untuk menggunakan model apa pun dengan TFF, model tersebut perlu dibungkus dalam sebuah instance dari antarmuka tff.learning.Model , yang mengekspos metode untuk tff.learning.Model model, properti metadata, dll., Mirip dengan Keras, tetapi juga memperkenalkan tambahan elemen, seperti cara untuk mengontrol proses komputasi metrik federasi. Jangan khawatir tentang ini untuk saat ini; jika Anda memiliki model Keras seperti yang baru saja kami definisikan di atas, Anda dapat meminta TFF membungkusnya untuk Anda dengan menjalankan tff.learning.from_keras_model , meneruskan model dan sampel kumpulan data sebagai argumen, seperti yang ditunjukkan di bawah ini.

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Melatih model pada data federasi

Sekarang kita memiliki model yang dibungkus sebagai tff.learning.Model untuk digunakan dengan TFF, kita dapat membiarkan TFF membangun algoritma Federated Averaging dengan menjalankan fungsi helper tff.learning.build_federated_averaging_process , sebagai berikut.

Perlu diingat bahwa argumen harus berupa konstruktor (seperti model_fn atas), bukan instance yang sudah dibuat, sehingga konstruksi model Anda dapat terjadi dalam konteks yang dikontrol oleh TFF (jika Anda penasaran tentang alasannya ini, kami mendorong Anda untuk membaca tutorial tindak lanjut tentang algoritme khusus ).

Satu catatan penting tentang algoritme Federated Averaging di bawah ini, ada 2 pengoptimal: pengoptimal _client dan pengoptimal _server. Pengoptimal _client hanya digunakan untuk menghitung pembaruan model lokal pada setiap klien. Pengoptimal _server menerapkan pembaruan rata-rata ke model global di server. Secara khusus, ini berarti bahwa pilihan pengoptimal dan kecepatan pembelajaran yang digunakan mungkin harus berbeda dari yang Anda gunakan untuk melatih model pada set data iid standar. Sebaiknya mulai dengan SGD biasa, mungkin dengan kecepatan pembelajaran yang lebih kecil dari biasanya. Kecepatan pembelajaran yang kami gunakan belum disetel dengan cermat, silakan bereksperimen.

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Apa yang baru saja terjadi? TFF telah membuat sepasang komputasi federasi dan memaketkannya menjadi tff.templates.IterativeProcess di mana komputasi ini tersedia sebagai sepasang properti yang initialize dan next .

Singkatnya, komputasi federasi adalah program dalam bahasa internal TFF yang dapat mengekspresikan berbagai algoritme gabungan (Anda dapat menemukan lebih banyak tentang ini di tutorial algoritme kustom ). Dalam kasus ini, dua komputasi yang dihasilkan dan dikemas ke dalam iterative_process mengimplementasikan Federated Averaging .

Ini adalah tujuan TFF untuk mendefinisikan komputasi dengan cara yang dapat dieksekusi dalam pengaturan pembelajaran federasi yang sebenarnya, tetapi saat ini hanya runtime simulasi eksekusi lokal yang diterapkan. Untuk menjalankan komputasi dalam simulator, Anda cukup memanggilnya seperti fungsi Python. Lingkungan standar yang ditafsirkan ini tidak dirancang untuk kinerja tinggi, tetapi sudah cukup untuk tutorial ini; kami berharap dapat menyediakan runtime simulasi berkinerja lebih tinggi untuk memfasilitasi penelitian berskala lebih besar dalam rilis mendatang.

Mari kita mulai dengan initialize komputasi. Seperti halnya untuk semua komputasi federasi, Anda dapat menganggapnya sebagai fungsi. Komputasi tidak membutuhkan argumen, dan mengembalikan satu hasil - representasi dari status proses Federated Averaging di server. Meskipun kami tidak ingin menyelami detail TFF, mungkin ada baiknya untuk melihat seperti apa keadaan ini. Anda dapat memvisualisasikannya sebagai berikut.

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<>,model_broadcast_state=<>>@SERVER)'

Meskipun tanda tangan tipe di atas mungkin pada awalnya tampak agak samar, Anda dapat mengenali bahwa status server terdiri dari model (parameter model awal untuk MNIST yang akan didistribusikan ke semua perangkat), dan optimizer_state (informasi tambahan yang dikelola oleh server, seperti jumlah putaran yang akan digunakan untuk jadwal hyperparameter, dll.).

Mari panggil komputasi initialize untuk membangun status server.

state = iterative_process.initialize()

Yang kedua dari pasangan komputasi federasi, next , mewakili satu putaran Federated Averaging, yang terdiri dari mendorong status server (termasuk parameter model) ke klien, pelatihan di perangkat tentang data lokal mereka, mengumpulkan dan rata-rata pembaruan model , dan menghasilkan model baru yang diperbarui di server.

Secara konseptual, Anda dapat memikirkan next sebagai memiliki tanda tangan tipe fungsional yang terlihat sebagai berikut.

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

Secara khusus, orang harus berpikir tentang next() bukan sebagai fungsi yang berjalan di server, melainkan sebagai representasi fungsional deklaratif dari seluruh komputasi terdesentralisasi - beberapa input disediakan oleh server ( SERVER_STATE ), tetapi masing-masing berpartisipasi perangkat menyumbangkan kumpulan data lokalnya sendiri.

Mari kita jalankan satu putaran pelatihan dan visualisasikan hasilnya. Kita dapat menggunakan data federasi yang telah kita buat di atas untuk sampel pengguna.

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.12037037312984467,loss=3.0108425617218018>>

Mari kita jalankan beberapa putaran lagi. Seperti disebutkan sebelumnya, biasanya pada titik ini Anda akan memilih subset dari data simulasi Anda dari sampel pengguna baru yang dipilih secara acak untuk setiap putaran untuk mensimulasikan penerapan realistis di mana pengguna terus menerus datang dan pergi, tetapi di notebook interaktif ini, untuk demi demonstrasi kami hanya akan menggunakan kembali pengguna yang sama, sehingga sistem dapat menyatu dengan cepat.

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.14814814925193787,loss=2.8865506649017334>>
round  3, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.148765429854393,loss=2.9079062938690186>>
round  4, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.17633745074272156,loss=2.724686622619629>>
round  5, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.20226337015628815,loss=2.6334855556488037>>
round  6, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.22427983582019806,loss=2.5482592582702637>>
round  7, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.24094650149345398,loss=2.4472343921661377>>
round  8, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.259876549243927,loss=2.3809611797332764>>
round  9, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.29814815521240234,loss=2.156442403793335>>
round 10, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.31687241792678833,loss=2.122845411300659>>

Kerugian pelatihan menurun setelah setiap putaran pelatihan federasi, yang menunjukkan bahwa model tersebut konvergen. Ada beberapa peringatan penting dengan metrik pelatihan ini, namun, lihat bagian Evaluasi nanti di tutorial ini.

Menampilkan metrik model di TensorBoard

Selanjutnya, mari kita visualisasikan metrik dari komputasi federasi ini menggunakan Tensorboard.

Mari kita mulai dengan membuat direktori dan penulis ringkasan yang sesuai untuk menulis metrik.


logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

Plotkan metrik skalar yang relevan dengan penulis ringkasan yang sama.


with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics.train._asdict().items():
      tf.summary.scalar(name, value, step=round_num)

Mulai TensorBoard dengan direktori log root yang ditentukan di atas. Perlu waktu beberapa detik untuk memuat data.


%tensorboard --logdir /tmp/logs/scalars/ --port=0

# Run this this cell to clean your directory of old output for future graphs from this directory.
!rm -R /tmp/logs/scalars/*

Untuk melihat metrik evaluasi dengan cara yang sama, Anda dapat membuat folder eval terpisah, seperti "logs / scalars / eval", untuk menulis ke TensorBoard.

Menyesuaikan implementasi model

Keras adalah API model tingkat tinggi yang direkomendasikan untuk TensorFlow , dan kami mendorong penggunaan model Keras (melalui tff.learning.from_keras_model ) di TFF jika memungkinkan.

Namun, tff.learning menyediakan antarmuka model tingkat rendah, tff.learning.Model , yang memperlihatkan fungsionalitas minimal yang diperlukan untuk menggunakan model untuk pembelajaran federasi. Mengimplementasikan antarmuka ini secara langsung (mungkin masih menggunakan blok bangunan seperti tf.keras.layers ) memungkinkan kustomisasi maksimum tanpa memodifikasi internal algoritme pembelajaran tf.keras.layers .

Jadi mari kita lakukan lagi dari awal.

Mendefinisikan variabel model, forward pass, dan metrik

Langkah pertama adalah mengidentifikasi variabel TensorFlow yang akan kita gunakan. Untuk membuat kode berikut lebih terbaca, mari kita tentukan struktur data untuk mewakili keseluruhan set. Ini akan mencakup variabel seperti weights dan bias bahwa kita akan melatih, serta variabel yang akan menggelar berbagai statistik kumulatif dan counter kami akan memperbarui selama pelatihan, seperti loss_sum , accuracy_sum , dan num_examples .

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

Berikut adalah metode yang membuat variabel. Demi kesederhanaan, kami menampilkan semua statistik sebagai tf.float32 , karena itu akan menghilangkan kebutuhan untuk konversi jenis di tahap selanjutnya. Membungkus penginisialisasi variabel sebagai lambda adalah persyaratan yang diberlakukan oleh variabel sumber daya .

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

Dengan adanya variabel untuk parameter model dan statistik kumulatif, sekarang kita dapat menentukan metode penerusan yang menghitung kerugian, memancarkan prediksi, dan memperbarui statistik kumulatif untuk satu batch data masukan, sebagai berikut.

def mnist_forward_pass(variables, batch):
  y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

Selanjutnya, kita mendefinisikan fungsi yang mengembalikan sekumpulan metrik lokal, sekali lagi menggunakan TensorFlow. Ini adalah nilai (selain pembaruan model, yang ditangani secara otomatis) yang memenuhi syarat untuk digabungkan ke server dalam proses pembelajaran atau evaluasi gabungan.

Di sini, kami hanya mengembalikan loss dan accuracy rata-rata, serta num_examples , yang kami perlukan untuk memberi bobot yang tepat pada kontribusi dari pengguna yang berbeda saat menghitung gabungan gabungan.

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

Terakhir, kita perlu menentukan cara menggabungkan metrik lokal yang dipancarkan oleh setiap perangkat melalui get_local_mnist_metrics . Ini adalah satu-satunya bagian kode yang tidak ditulis di TensorFlow - ini adalah komputasi federasi yang diekspresikan dalam TFF. Jika Anda ingin menggali lebih dalam, baca sekilas tutorial algoritme khusus , tetapi di sebagian besar aplikasi, Anda tidak perlu melakukannya; varian dari pola yang ditunjukkan di bawah ini sudah cukup. Berikut tampilannya:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))
  

Argumen metrics input sesuai dengan OrderedDict yang OrderedDict oleh get_local_mnist_metrics atas, tetapi yang terpenting, nilainya tidak lagi tf.Tensors - mereka " tff.Value " sebagai tff.Value s, untuk memperjelas, Anda tidak dapat lagi memanipulasinya menggunakan TensorFlow, tetapi hanya menggunakan operator federasi TFF seperti tff.federated_mean dan tff.federated_sum . Kamus yang dikembalikan dari agregat global mendefinisikan kumpulan metrik yang akan tersedia di server.

tff.learning.Model instance tff.learning.Model

Dengan semua hal di atas, kami siap untuk membuat representasi model untuk digunakan dengan TFF mirip dengan yang dihasilkan untuk Anda saat Anda membiarkan TFF menelan model Keras.

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

Seperti yang Anda lihat, metode dan properti abstrak yang ditentukan oleh tff.learning.Model sesuai dengan cuplikan kode di bagian sebelumnya yang memperkenalkan variabel dan menentukan kerugian dan statistik.

Berikut beberapa poin yang perlu disoroti:

  • Semua status yang akan digunakan model Anda harus ditangkap sebagai variabel TensorFlow, karena TFF tidak menggunakan Python pada waktu proses (ingat kode Anda harus ditulis sedemikian rupa sehingga dapat diterapkan ke perangkat seluler; lihat tutorial algoritme khusus untuk lebih mendalam komentar tentang alasan).
  • Model Anda harus mendeskripsikan bentuk data apa yang diterima ( input_spec ), karena secara umum, TFF adalah lingkungan yang diketik dengan kuat dan ingin menentukan jenis tanda tangan untuk semua komponen. Mendeklarasikan format masukan model Anda adalah bagian penting darinya.
  • Meskipun secara teknis tidak diperlukan, sebaiknya gabungkan semua logika TensorFlow (penerusan maju, penghitungan metrik, dll.) Sebagai tf.function s, karena ini membantu memastikan TensorFlow dapat diserialkan, dan menghilangkan kebutuhan akan dependensi kontrol eksplisit.

Di atas cukup untuk evaluasi dan algoritma seperti Federated SGD. Namun, untuk Federated Averaging, kita perlu menentukan bagaimana model harus dilatih secara lokal pada setiap batch. Kami akan menentukan pengoptimal lokal saat membuat algoritme Federated Averaging.

Mensimulasikan pelatihan federasi dengan model baru

Dengan semua hal di atas, sisa proses terlihat seperti yang telah kita lihat - cukup ganti konstruktor model dengan konstruktor kelas model baru kita, dan gunakan dua komputasi federasi dalam proses iteratif yang Anda buat untuk menggilir putaran pelatihan.

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.9713594913482666,accuracy=0.13518518209457397>>

for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.975412607192993,accuracy=0.14032921195030212>>
round  3, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.9395227432250977,accuracy=0.1594650149345398>>
round  4, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.710164785385132,accuracy=0.17139917612075806>>
round  5, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.5891618728637695,accuracy=0.20267489552497864>>
round  6, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.5148487091064453,accuracy=0.21666666865348816>>
round  7, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.2816808223724365,accuracy=0.2580246925354004>>
round  8, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.3656885623931885,accuracy=0.25884774327278137>>
round  9, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.23549222946167,accuracy=0.28477364778518677>>
round 10, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=1.974222183227539,accuracy=0.35329216718673706>>

Untuk melihat metrik ini dalam TensorBoard, lihat langkah-langkah yang tercantum di atas dalam "Menampilkan metrik model di TensorBoard".

Evaluasi

Semua eksperimen kami sejauh ini hanya menyajikan metrik pelatihan federasi - metrik rata-rata dari semua kumpulan data yang dilatih di semua klien dalam putaran tersebut. Ini memperkenalkan kekhawatiran normal tentang overfitting, terutama karena kami menggunakan kumpulan klien yang sama di setiap putaran untuk kesederhanaan, tetapi ada gagasan tambahan tentang overfitting dalam metrik pelatihan yang khusus untuk algoritme Federated Averaging. Ini paling mudah dilihat jika kita membayangkan setiap klien memiliki satu batch data, dan kita melatih batch tersebut untuk banyak iterasi (epoch). Dalam kasus ini, model lokal akan dengan cepat sesuai dengan satu kelompok tersebut, sehingga metrik akurasi lokal yang kami rata-rata akan mendekati 1,0. Dengan demikian, metrik pelatihan ini dapat dianggap sebagai tanda bahwa pelatihan sedang berlangsung, tetapi tidak lebih.

Untuk melakukan evaluasi pada data federasi, Anda bisa membuat komputasi federasi lain yang dirancang hanya untuk tujuan ini, menggunakan fungsi tff.learning.build_federated_evaluation , dan meneruskan konstruktor model Anda sebagai argumen. Perhatikan bahwa tidak seperti Federated Averaging, di mana kita telah menggunakan MnistTrainableModel , itu sudah cukup untuk meneruskan MnistModel . Evaluasi tidak melakukan penurunan gradien, dan tidak perlu membuat pengoptimal.

Untuk eksperimen dan penelitian, saat set data pengujian terpusat tersedia, Federated Learning for Text Generation mendemonstrasikan opsi evaluasi lain: mengambil bobot yang dilatih dari pembelajaran federasi, menerapkannya ke model Keras standar, lalu memanggil tf.keras.models.Model.evaluate() pada tf.keras.models.Model.evaluate() data terpusat.

evaluation = tff.learning.build_federated_evaluation(MnistModel)

Anda dapat memeriksa tanda tangan tipe abstrak dari fungsi evaluasi sebagai berikut.

str(evaluation.type_signature)
'(<<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,{<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <num_examples=float32@SERVER,loss=float32@SERVER,accuracy=float32@SERVER>)'

Tidak perlu khawatir tentang detailnya pada saat ini, cukup perhatikan bahwa ini mengambil bentuk umum berikut, mirip dengan tff.templates.IterativeProcess.next tetapi dengan dua perbedaan penting. Pertama, kami tidak menampilkan status server, karena evaluasi tidak mengubah model atau aspek status lainnya - Anda dapat menganggapnya sebagai stateless. Kedua, evaluasi hanya membutuhkan model, dan tidak memerlukan bagian lain dari status server yang mungkin terkait dengan pelatihan, seperti variabel pengoptimal.

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

Mari minta evaluasi tentang keadaan terakhir yang kita capai selama pelatihan. Untuk mengekstrak model terlatih terbaru dari status server, Anda cukup mengakses anggota .model , sebagai berikut.

train_metrics = evaluation(state.model, federated_train_data)

Inilah yang kami dapatkan. Perhatikan bahwa angkanya terlihat sedikit lebih baik daripada yang dilaporkan pada putaran terakhir pelatihan di atas. Berdasarkan konvensi, metrik pelatihan yang dilaporkan oleh proses pelatihan berulang umumnya mencerminkan kinerja model di awal babak pelatihan, sehingga metrik evaluasi akan selalu selangkah lebih maju.

str(train_metrics)
'<num_examples=4860.0,loss=1.7142657041549683,accuracy=0.38683128356933594>'

Sekarang, mari kita mengumpulkan sampel uji dari data federasi dan menjalankan kembali evaluasi pada data uji. Data akan berasal dari sampel pengguna nyata yang sama, tetapi dari kumpulan data yang berbeda.

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
'<num_examples=580.0,loss=1.861915111541748,accuracy=0.3362068831920624>'

Ini mengakhiri tutorial. Kami mendorong Anda untuk bermain dengan parameter (mis., Ukuran kelompok, jumlah pengguna, masa, kecepatan pembelajaran, dll.), Untuk memodifikasi kode di atas untuk mensimulasikan pelatihan pada sampel acak pengguna di setiap putaran, dan untuk menjelajahi tutorial lainnya kami telah berkembang.