Bekerja dengan ClientData tff.

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Gagasan tentang kumpulan data yang dikunci oleh klien (misalnya pengguna) sangat penting untuk komputasi gabungan seperti yang dimodelkan dalam TFF. TFF menyediakan antarmuka tff.simulation.datasets.ClientData untuk abstrak lebih konsep ini, dan yang TFF host (dataset stackoverflow , shakespeare , emnist , cifar100 , dan gldv2 ) semua mengimplementasikan interface ini.

Jika Anda bekerja pada pembelajaran Federasi dengan dataset Anda sendiri, TFF sangat mendorong Anda untuk baik melaksanakan ClientData antarmuka atau menggunakan salah satu dari fungsi pembantu TFF untuk menghasilkan ClientData yang mewakili data Anda pada disk, misalnya tff.simulation.datasets.ClientData.from_clients_and_fn .

Karena kebanyakan dari TFF contoh end-to-end mulai dengan ClientData objek, menerapkan ClientData antarmuka dengan dataset kustom Anda akan membuat lebih mudah untuk spelunk melalui kode yang ada ditulis dengan TFF. Selanjutnya, tf.data.Datasets yang ClientData konstruksi dapat mengulangi lebih langsung untuk menghasilkan struktur numpy array, sehingga ClientData benda dapat digunakan dengan kerangka ML berbasis Python sebelum pindah ke TFF.

Ada beberapa pola yang dapat Anda gunakan untuk membuat hidup Anda lebih mudah jika Anda berniat untuk meningkatkan simulasi Anda ke banyak mesin atau menerapkannya. Di bawah ini kami akan berjalan melalui beberapa cara kita dapat menggunakan ClientData dan TFF untuk membuat skala kecil iterasi-to skala besar eksperimen-produksi pengalaman penyebaran kami sebagai halus mungkin.

Pola mana yang harus saya gunakan untuk meneruskan ClientData ke TFF?

Kita akan membahas dua penggunaan dari TFF ClientData secara mendalam; jika Anda termasuk dalam salah satu dari dua kategori di bawah ini, Anda jelas akan lebih memilih satu dari yang lain. Jika tidak, Anda mungkin memerlukan pemahaman yang lebih rinci tentang pro dan kontra dari masing-masing untuk membuat pilihan yang lebih bernuansa.

  • Saya ingin mengulangi secepat mungkin di mesin lokal; Saya tidak perlu dapat dengan mudah memanfaatkan runtime terdistribusi TFF.

    • Anda ingin lulus tf.data.Datasets ke TFF langsung.
    • Hal ini memungkinkan Anda untuk program imperatif dengan tf.data.Dataset benda, dan proses mereka sewenang-wenang.
    • Ini memberikan lebih banyak fleksibilitas daripada opsi di bawah ini; mendorong logika ke klien mengharuskan logika ini dapat serial.
  • Saya ingin menjalankan komputasi gabungan saya di runtime jarak jauh TFF, atau saya berencana untuk melakukannya segera.

    • Dalam hal ini Anda ingin memetakan konstruksi set data dan prapemrosesan ke klien.
    • Hasil dalam Anda ini melewati hanya daftar client_ids langsung ke perhitungan federasi Anda.
    • Mendorong konstruksi set data dan prapemrosesan ke klien menghindari kemacetan dalam serialisasi, dan secara signifikan meningkatkan kinerja dengan ratusan hingga ribuan klien.

Siapkan lingkungan sumber terbuka

paket impor

Memanipulasi objek ClientData

Mari kita mulai dengan bongkar menjelajahi TFF EMNIST ClientData :

client_data, _ = tff.simulation.datasets.emnist.load_data()
Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:19<00:00, 8831921.67it/s]
2021-10-01 11:17:58.718735: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Memeriksa dataset pertama dapat memberitahu kita apa jenis contoh adalah di ClientData .

first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

Perhatikan bahwa hasil dataset collections.OrderedDict objek yang memiliki pixels dan label kunci, di mana pixel adalah tensor dengan bentuk [28, 28] . Misalkan kita ingin meratakan masukan kami keluar ke bentuk [784] . Salah satu cara yang mungkin bisa kita lakukan ini akan menjadi untuk menerapkan fungsi pre-processing untuk kami ClientData objek.

def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)


preprocessed_client_data = client_data.preprocess(preprocess_dataset)

# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

Kami mungkin ingin selain melakukan beberapa pemrosesan awal yang lebih kompleks (dan mungkin stateful), misalnya pengocokan.

def preprocess_and_shuffle(dataset):
  """Applies `preprocess_dataset` above and shuffles the result."""
  preprocessed = preprocess_dataset(dataset)
  return preprocessed.shuffle(buffer_size=5)

preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)

# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

Berinteraksi dengan tff.Computation

Sekarang kita dapat melakukan beberapa manipulasi dasar dengan ClientData objek, kami siap untuk data umpan ke tff.Computation . Kami mendefinisikan tff.templates.IterativeProcess yang mengimplementasikan Federasi Averaging , dan mengeksplorasi metode yang berbeda lewat itu data.

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
  ])
  return tff.learning.from_keras_model(
      model,
      # Note: input spec is the _batched_ shape, and includes the 
      # label tensor which will be passed to the loss function. This model is
      # therefore configured to accept data _after_ it has been preprocessed.
      input_spec=collections.OrderedDict(
          x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
          y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

trainer = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))

Sebelum kita mulai bekerja dengan ini IterativeProcess , satu komentar pada semantik ClientData adalah dalam rangka. Sebuah ClientData objek mewakili keseluruhan dari populasi yang tersedia untuk pelatihan federasi, yang pada umumnya tidak tersedia untuk lingkungan eksekusi dari sistem produksi FL dan khusus untuk simulasi. ClientData memang memberikan pengguna kemampuan untuk komputasi federasi memotong sepenuhnya dan hanya melatih model server-side seperti biasa melalui ClientData.create_tf_dataset_from_all_clients .

Lingkungan simulasi TFF menempatkan peneliti dalam kendali penuh atas loop luar. Secara khusus ini menyiratkan pertimbangan ketersediaan klien, klien putus sekolah, dll, harus ditangani oleh pengguna atau skrip driver Python. Satu bisa misalnya model client putus sekolah dengan menyesuaikan distribusi sampling atas Anda ClientData's client_ids sehingga pengguna dengan data yang lebih (dan Sejalan lagi berjalan perhitungan lokal) akan dipilih dengan probabilitas yang lebih rendah.

Namun, dalam sistem federasi nyata, klien tidak dapat dipilih secara eksplisit oleh pelatih model; pemilihan klien didelegasikan ke sistem yang menjalankan komputasi gabungan.

Melewati tf.data.Datasets langsung ke TFF

Salah satu pilihan yang kita miliki untuk interfacing antara ClientData dan IterativeProcess adalah bahwa membangun tf.data.Datasets di Python, dan melewati dataset ini untuk TFF.

Perhatikan bahwa jika kita menggunakan preprocessed kami ClientData dataset kami menghasilkan adalah dari jenis yang sesuai yang diharapkan oleh model kami yang didefinisikan di atas.

selected_client_ids = preprocessed_and_shuffled.client_ids[:10]

preprocessed_data_for_clients = [
    preprocessed_and_shuffled.create_tf_dataset_for_client(
        selected_client_ids[i]) for i in range(10)
]

state = trainer.initialize()
for _ in range(5):
  t1 = time.time()
  state, metrics = trainer.next(state, preprocessed_data_for_clients)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
loss 2.9005744457244873, round time 4.576513767242432
loss 3.113278388977051, round time 0.49641919136047363
loss 2.7581865787506104, round time 0.4904160499572754
loss 2.87259578704834, round time 0.48976993560791016
loss 3.1202380657196045, round time 0.6724586486816406

Jika kita mengambil rute ini, namun, kami tidak akan dapat sepele pindah ke simulasi MULTIMESIN. Dataset kita membangun di runtime TensorFlow lokal dapat menangkap negara dari lingkungan python sekitarnya, dan gagal dalam serialisasi atau deserialization ketika mereka mencoba untuk negara referensi yang tidak lagi tersedia untuk mereka. Hal ini dapat terwujud misalnya dalam kesalahan ajaib dari TensorFlow ini tensor_util.cc :

Check failed: DT_VARIANT == input.dtype() (21 vs. 20)

Pemetaan konstruksi dan preprocessing atas klien

Untuk menghindari masalah ini, TFF merekomendasikan penggunanya untuk mempertimbangkan dataset Instansiasi dan preprocessing sebagai sesuatu yang terjadi secara lokal pada setiap klien, dan menggunakan pembantu TFF atau federated_map secara eksplisit menjalankan ini kode preprocessing pada setiap klien.

Secara konseptual, alasan untuk memilih ini jelas: dalam runtime lokal TFF, klien hanya "secara tidak sengaja" memiliki akses ke lingkungan Python global karena fakta bahwa seluruh orkestrasi federasi terjadi pada satu mesin. Perlu dicatat pada titik ini bahwa pemikiran serupa memunculkan filosofi fungsional lintas platform, selalu serialisasi, dan fungsional TFF.

TFF membuat perubahan yang sederhana melalui ClientData's atribut dataset_computation , sebuah tff.Computation yang membutuhkan client_id dan mengembalikan terkait tf.data.Dataset .

Perhatikan bahwa preprocess hanya bekerja dengan dataset_computation ; yang dataset_computation atribut dari preprocessed ClientData menggabungkan seluruh pipa preprocessing kita hanya didefinisikan:

print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing:
(string -> <label=int32,pixels=float32[28,28]>*)


dataset computation with preprocessing:
(string -> <x=float32[?,784],y=int64[?,1]>*)

Kita bisa memanggil dataset_computation dan menerima dataset bersemangat dalam runtime Python, tapi kekuatan nyata dari pendekatan ini dilaksanakan ketika kita menulis dengan proses berulang atau perhitungan lain untuk menghindari mewujudkan dataset ini dalam runtime bersemangat global yang sama sekali. TFF menyediakan fungsi pembantu tff.simulation.compose_dataset_computation_with_iterative_process yang dapat digunakan untuk melakukan hal ini.

trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
    preprocessed_and_shuffled.dataset_computation, trainer)

Kedua ini tff.templates.IterativeProcesses dan satu di atas dijalankan dengan cara yang sama; namun mantan menerima dataset client preprocessed, dan yang terakhir menerima string yang mewakili id klien, penanganan baik konstruksi dataset dan preprocessing di tubuhnya - sebenarnya state dapat dilalui antara keduanya.

for _ in range(5):
  t1 = time.time()
  state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
loss 2.8417396545410156, round time 1.6707067489624023
loss 2.7670371532440186, round time 0.5207102298736572
loss 2.665048122406006, round time 0.5302855968475342
loss 2.7213189601898193, round time 0.5313887596130371
loss 2.580148935317993, round time 0.5283482074737549

Menskalakan ke sejumlah besar klien

trainer_accepting_ids dapat langsung digunakan dalam TFF runtime MULTIMESIN, dan menghindari mewujudkan tf.data.Datasets dan controller (dan karena itu serialisasi mereka dan mengirim mereka keluar untuk para pekerja).

Ini secara signifikan mempercepat simulasi terdistribusi, terutama dengan sejumlah besar klien, dan memungkinkan agregasi menengah untuk menghindari overhead serialisasi/deserialisasi yang serupa.

Deepdive opsional: menyusun logika prapemrosesan secara manual di TFF

TFF dirancang untuk komposisi dari bawah ke atas; jenis komposisi yang baru saja dilakukan oleh helper TFF sepenuhnya berada dalam kendali kami sebagai pengguna. Kita bisa memiliki manual menyusun perhitungan preprocessing kita hanya didefinisikan dengan pelatih sendiri next cukup sederhana:

selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)

@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)
def new_next(server_state, selected_clients):
  preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)
  return trainer.next(server_state, preprocessed_data)

manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)

Faktanya, inilah yang secara efektif dilakukan oleh helper yang kami gunakan di bawah tenda (ditambah melakukan pengecekan dan manipulasi tipe yang sesuai). Kita bahkan bisa menyatakan logika yang sama sedikit berbeda, dengan serialisasi preprocess_and_shuffle menjadi tff.Computation , dan dekomposisi federated_map menjadi satu langkah yang membangun un-preprocessed dataset dan lain yang berjalan preprocess_and_shuffle di setiap klien.

Kami dapat memverifikasi bahwa jalur yang lebih manual ini menghasilkan komputasi dengan tanda tangan tipe yang sama dengan helper TFF (nama parameter modulo):

print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,federated_dataset={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,selected_clients={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)