การสร้างอัลกอริทึมการเรียนรู้แบบสหพันธรัฐของคุณเอง

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

ก่อนที่เราจะเริ่มต้น

ก่อนที่เราจะเริ่ม โปรดเรียกใช้สิ่งต่อไปนี้เพื่อให้แน่ใจว่าสภาพแวดล้อมของคุณได้รับการตั้งค่าอย่างถูกต้อง หากคุณไม่เห็นคำทักทายโปรดดูที่ การติดตั้ง คู่มือสำหรับคำแนะนำ

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import tensorflow as tf
import tensorflow_federated as tff

ใน การจัดหมวดหมู่ภาพ และ รุ่นข้อความ บทเรียนที่เราได้เรียนรู้วิธีการตั้งค่ารูปแบบและข้อมูลที่เป็นท่อสำหรับสหพันธ์การเรียนรู้ (FL) และดำเนินการฝึกอบรมแบบ federated ผ่าน tff.learning ชั้น API ของฉิบหาย

นี่เป็นเพียงส่วนเล็ก ๆ ของภูเขาน้ำแข็งเมื่อพูดถึงการวิจัยของ FL ในการกวดวิชานี้เราหารือถึงวิธีการดำเนินการตามขั้นตอนวิธีการเรียนรู้แบบ federated โดยไม่ต้องชะลอไป tff.learning API เรามุ่งมั่นที่จะบรรลุสิ่งต่อไปนี้:

เป้าหมาย:

  • ทำความเข้าใจโครงสร้างทั่วไปของอัลกอริธึมการเรียนรู้แบบสหพันธรัฐ
  • สำรวจสหพันธ์หลักของฉิบหาย
  • ใช้ Federated Core เพื่อปรับใช้ Federated Averaging โดยตรง

ในขณะที่การกวดวิชานี้มีข้อมูลในตัวเราขอแนะนำเป็นครั้งแรกที่ได้อ่าน การจัดหมวดหมู่ภาพ และ ข้อความรุ่น บทเรียน

การเตรียมข้อมูลเข้า

ก่อนอื่นเราโหลดและประมวลผลชุดข้อมูล EMNIST ที่รวมอยู่ใน TFF ล่วงหน้า สำหรับรายละเอียดเพิ่มเติมโปรดดูที่ การจัดหมวดหมู่ภาพ กวดวิชา

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

เพื่อที่จะกินอาหารชุดข้อมูลในรูปแบบของเราเราแผ่ข้อมูลและแปลงแต่ละตัวอย่างลงใน tuple ของรูปแบบ (flattened_image_vector, label)

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

ตอนนี้เราเลือกลูกค้าจำนวนเล็กน้อย และใช้การประมวลผลล่วงหน้าด้านบนกับชุดข้อมูล

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]

กำลังเตรียมโมเดล

เราใช้รูปแบบเดียวกับใน การจัดหมวดหมู่ภาพ กวดวิชา รุ่นนี้ (ดำเนินการผ่านทาง tf.keras ) มีชั้นเดียวที่ซ่อนอยู่ตามชั้น softmax

def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

เพื่อที่จะใช้รูปแบบนี้ในฉิบหายเราห่อรุ่น Keras เป็น tff.learning.Model นี้จะช่วยให้เราสามารถดำเนินการรูปแบบของ การส่งผ่านไปข้างหน้า ภายในฉิบหายและ เอาท์พุทสารสกัดจากรูปแบบ สำหรับรายละเอียดเพิ่มเติมยังเห็น การจัดหมวดหมู่ภาพ กวดวิชา

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

ในขณะที่เราใช้ tf.keras เพื่อสร้าง tff.learning.Model , ฉิบหายสนับสนุนรูปแบบมากขึ้นทั่วไป โมเดลเหล่านี้มีคุณลักษณะที่เกี่ยวข้องต่อไปนี้ในการบันทึกน้ำหนักของแบบจำลอง:

  • trainable_variables : การ iterable ของเทนเซอร์ที่สอดคล้องกับชั้นสุวินัย
  • non_trainable_variables : การ iterable ของเทนเซอร์ที่สอดคล้องกับชั้นไม่ใช่สุวินัย

สำหรับวัตถุประสงค์ของเราเราจะใช้ trainable_variables (เนื่องจากโมเดลของเรามีเพียงแค่นั้นเท่านั้น!)

สร้างอัลกอริทึมการเรียนรู้แบบสหพันธรัฐของคุณเอง

ในขณะที่ tff.learning API จะช่วยหนึ่งในการสร้างหลายสายพันธุ์ของสหพันธ์ Averaging มีอัลกอริทึมแบบ federated อื่น ๆ ที่ไม่เหมาะสมอย่างเรียบร้อยในกรอบนี้ ตัวอย่างเช่นคุณอาจต้องการเพิ่ม regularization คลิปหรืออัลกอริทึมที่มีความซับซ้อนมากขึ้นเช่น การฝึกอบรม GAN federated คุณก็อาจจะแทนมีความสนใจใน การวิเคราะห์แบบ federated

สำหรับอัลกอริธึมขั้นสูงเหล่านี้ เราจะต้องเขียนอัลกอริทึมของเราเองโดยใช้ TFF ในหลายกรณี อัลกอริธึมแบบรวมศูนย์มี 4 องค์ประกอบหลัก:

  1. ขั้นตอนการออกอากาศแบบเซิร์ฟเวอร์ถึงไคลเอ็นต์
  2. ขั้นตอนการอัปเดตไคลเอ็นต์ท้องถิ่น
  3. ขั้นตอนการอัปโหลดไคลเอ็นต์สู่เซิร์ฟเวอร์
  4. ขั้นตอนการอัปเดตเซิร์ฟเวอร์

ในฉิบหายเราโดยทั่วไปหมายถึงขั้นตอนวิธีแบบ federated เป็น tff.templates.IterativeProcess (ซึ่งเราจะเรียกว่าเป็นเพียง IterativeProcess ตลอด) นี้เป็นชั้นที่มี initialize และ next ฟังก์ชั่น ที่นี่ initialize จะใช้ในการเริ่มต้นเซิร์ฟเวอร์และ next จะดำเนินการอย่างใดอย่างหนึ่งรอบการสื่อสารของอัลกอริทึมแบบ federated มาเขียนโครงร่างของกระบวนการทำซ้ำของเราสำหรับ FedAvg กันดีกว่า

ครั้งแรกที่เรามีฟังก์ชั่นการเริ่มต้นที่เพียงสร้าง tff.learning.Model และผลตอบแทนของน้ำหนักสุวินัย

def initialize_fn():
  model = model_fn()
  return model.trainable_variables

ฟังก์ชันนี้ดูดี แต่อย่างที่เราจะเห็นในภายหลัง เราจำเป็นต้องทำการปรับเปลี่ยนเล็กน้อยเพื่อให้เป็น "การคำนวณ TFF"

นอกจากนี้เรายังต้องการที่จะวาด next_fn

def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights

เราจะเน้นที่การใช้องค์ประกอบทั้งสี่นี้แยกกัน ก่อนอื่นเราเน้นที่ส่วนต่างๆ ที่สามารถนำมาใช้ใน TensorFlow ได้ นั่นคือขั้นตอนการอัปเดตไคลเอ็นต์และเซิร์ฟเวอร์

บล็อกเทนเซอร์โฟลว์

อัพเดทลูกค้า

เราจะใช้ของเรา tff.learning.Model ที่จะดำเนินการฝึกอบรมลูกค้าเป็นหลักเดียวกับที่คุณจะอบรมรุ่น TensorFlow โดยเฉพาะอย่างยิ่งเราจะใช้ tf.GradientTape เพื่อคำนวณการไล่ระดับสีในกระบวนการของข้อมูลแล้วใช้การไล่ระดับสีเหล่านี้โดยใช้ client_optimizer เราเน้นเฉพาะตุ้มน้ำหนักที่สามารถฝึกได้

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights

อัพเดทเซิฟเวอร์

การอัปเดตเซิร์ฟเวอร์สำหรับ FedAvg ง่ายกว่าการอัปเดตไคลเอ็นต์ เราจะใช้ "วานิลลา" เฉลี่ยแบบรวมศูนย์ ซึ่งเราเพียงแค่แทนที่น้ำหนักโมเดลเซิร์ฟเวอร์ด้วยค่าเฉลี่ยของน้ำหนักโมเดลไคลเอ็นต์ ย้ำอีกครั้ง เราเน้นเฉพาะตุ้มน้ำหนักที่ฝึกได้

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.trainable_variables
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights

ข้อมูลอาจจะง่ายโดยเพียงแค่กลับมา mean_client_weights การใช้งานอย่างไรก็ตามที่สูงขึ้นของสหพันธ์ใช้ Averaging mean_client_weights ด้วยเทคนิคที่ซับซ้อนมากขึ้นเช่นโมเมนตัมหรือปรับตัว

ท้าทาย: Implement รุ่นของ server_update ที่ปรับปรุงน้ำหนักเซิร์ฟเวอร์เพื่อเป็นจุดกึ่งกลางของ model_weights และ mean_client_weights (หมายเหตุ: ชนิดของ "จุดกึ่งกลาง" วิธีการนี้จะคล้ายคลึงกับการทำงานที่ผ่านมาเกี่ยวกับการ เพิ่มประสิทธิภาพ Lookahead !)

จนถึงตอนนี้ เราเพิ่งเขียนโค้ด TensorFlow ล้วนๆ เท่านั้น นี่คือการออกแบบ เนื่องจาก TFF อนุญาตให้คุณใช้โค้ด TensorFlow ส่วนใหญ่ที่คุณคุ้นเคยอยู่แล้ว แต่ตอนนี้เราจะต้องระบุตรรกะประสานที่เป็นตรรกะที่สั่งการสิ่งที่ออกอากาศเซิร์ฟเวอร์ไปยังลูกค้าและสิ่งที่เป็นภาพที่ส่งลูกค้าไปยังเซิร์ฟเวอร์

นี้จะต้องสหพันธ์หลักของฉิบหาย

บทนำสู่สหพันธรัฐคอร์

สหพันธ์แกน (เอฟซี) เป็นชุดของการเชื่อมต่อในระดับต่ำกว่าที่ทำหน้าที่เป็นรากฐานสำหรับการ tff.learning API อย่างไรก็ตาม อินเทอร์เฟซเหล่านี้ไม่จำกัดเพียงการเรียนรู้ อันที่จริง พวกเขาสามารถใช้สำหรับการวิเคราะห์และการคำนวณอื่น ๆ มากกว่าข้อมูลแบบกระจาย

ในระดับสูง คอร์รวมเป็นสภาพแวดล้อมการพัฒนาที่ช่วยให้ตรรกะของโปรแกรมแสดงออกอย่างกะทัดรัดเพื่อรวมรหัส TensorFlow กับตัวดำเนินการสื่อสารแบบกระจาย (เช่น ผลรวมแบบกระจายและการออกอากาศ) เป้าหมายคือเพื่อให้นักวิจัยและผู้ปฏิบัติงานสามารถควบคุมการสื่อสารแบบกระจายในระบบของตนได้โดยไม่ต้องใช้รายละเอียดการใช้ระบบ (เช่นการระบุการแลกเปลี่ยนข้อความเครือข่ายแบบจุดต่อจุด)

ประเด็นสำคัญประการหนึ่งคือ TFF ได้รับการออกแบบมาเพื่อรักษาความเป็นส่วนตัว ดังนั้นจึงช่วยให้สามารถควบคุมตำแหน่งข้อมูลได้อย่างชัดเจน เพื่อป้องกันการสะสมข้อมูลที่ไม่ต้องการที่ตำแหน่งเซิร์ฟเวอร์ส่วนกลาง

ข้อมูลส่วนกลาง

แนวคิดหลักใน TFF คือ "ข้อมูลรวม" ซึ่งหมายถึงการรวบรวมรายการข้อมูลที่โฮสต์ข้ามกลุ่มอุปกรณ์ในระบบแบบกระจาย (เช่น ชุดข้อมูลไคลเอ็นต์ หรือน้ำหนักโมเดลเซิร์ฟเวอร์) เรารูปแบบการเก็บรวบรวมทั้งหมดของรายการข้อมูลในอุปกรณ์ทั้งหมดเป็นค่า federated เดียว

ตัวอย่างเช่น สมมติว่าเรามีอุปกรณ์ไคลเอนต์ที่แต่ละอันมีทุ่นแทนอุณหภูมิของเซ็นเซอร์ เราสามารถแสดงเป็นลอย federated โดย

federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)

ประเภทสหพันธ์ถูกกำหนดโดยชนิด T ขององค์ประกอบของสมาชิก (เช่น. tf.float32 ) และกลุ่ม G ของอุปกรณ์ เราจะมุ่งเน้นไปที่กรณีที่ G เป็นทั้ง tff.CLIENTS หรือ tff.SERVER ดังกล่าวเป็นชนิด federated จะแสดงเป็น {T}@G ที่แสดงด้านล่าง

str(federated_float_on_clients)
'{float32}@CLIENTS'

ทำไมเราถึงสนใจเรื่องตำแหน่งมาก? เป้าหมายหลักของ TFF คือการเปิดใช้งานการเขียนโค้ดที่สามารถนำไปใช้งานบนระบบแบบกระจายจริงได้ ซึ่งหมายความว่าจำเป็นอย่างยิ่งที่จะต้องให้เหตุผลว่าอุปกรณ์ชุดย่อยใดที่เรียกใช้โค้ดใด และข้อมูลส่วนต่างๆ อยู่ที่ไหน

ฉิบหายมุ่งเน้นไปที่สิ่งที่สาม: ข้อมูลที่มีข้อมูลที่มีการวางและวิธีการที่ข้อมูลจะถูกเปลี่ยน สองคนแรกที่ถูกห่อหุ้มในรูปแบบ federated ในขณะที่ที่ผ่านมามีการห่อหุ้มในการคำนวณแบบ federated

การคำนวณแบบสหพันธรัฐ

ฉิบหายเป็นสภาพแวดล้อมการเขียนโปรแกรมอย่างยิ่งพิมพ์การทำงานที่มีหน่วยพื้นฐานคำนวณ federated เหล่านี้เป็นส่วนของตรรกะที่ยอมรับค่ารวมเป็นอินพุต และส่งกลับค่ารวมเป็นผลลัพธ์

ตัวอย่างเช่น สมมติว่าเราต้องการเฉลี่ยอุณหภูมิบนเซ็นเซอร์ลูกค้าของเรา เราสามารถกำหนดสิ่งต่อไปนี้ (โดยใช้ federated float ของเรา):

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

คุณอาจจะถามว่าวิธีการที่แตกต่างกันนี้จาก tf.function มัณฑนากรใน TensorFlow? คำตอบที่สำคัญคือว่ารหัสที่สร้างขึ้นโดย tff.federated_computation จะไม่ TensorFlow มิได้หลามรหัส; มันเป็นคุณสมบัติของระบบกระจายในแพลตฟอร์มภาษากาวภายใน

แม้ว่าสิ่งนี้อาจฟังดูซับซ้อน แต่คุณสามารถคิดได้ว่าการคำนวณ TFF เป็นฟังก์ชันที่มีลายเซ็นประเภทที่กำหนดไว้อย่างดี ลายเซ็นประเภทนี้สามารถสอบถามโดยตรง

str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'

นี้ tff.federated_computation ยอมรับข้อโต้แย้งของประเภท federated {float32}@CLIENTS และค่าผลตอบแทนจากประเภท federated {float32}@SERVER การคำนวณแบบรวมศูนย์อาจเปลี่ยนจากเซิร์ฟเวอร์หนึ่งไปยังอีกเครื่องหนึ่ง จากเครื่องลูกหนึ่งไปยังอีกเครื่องหนึ่ง หรือจากเครื่องหนึ่งไปยังอีกเครื่องหนึ่ง การคำนวณแบบรวมศูนย์สามารถประกอบได้เหมือนกับฟังก์ชันปกติ ตราบใดที่ลายเซ็นประเภทตรงกัน

เพื่อสนับสนุนการพัฒนา, ฉิบหายช่วยให้คุณสามารถเรียก tff.federated_computation เป็นฟังก์ชั่นหลาม ตัวอย่างเช่น เราสามารถเรียก

get_average_temperature([68.5, 70.3, 69.8])
69.53334

การคำนวณแบบไม่กระตือรือร้นและ TensorFlow

มีข้อ จำกัด ที่สำคัญสองประการที่ควรทราบ ครั้งแรกเมื่อล่ามหลามพบ tff.federated_computation มัณฑนากร, ฟังก์ชั่นที่มีการตรวจสอบทันทีและต่อเนื่องสำหรับการใช้งานในอนาคต เนื่องจากลักษณะการกระจายอำนาจของ Federated Learning การใช้งานในอนาคตนี้อาจเกิดขึ้นที่อื่น เช่น สภาพแวดล้อมการดำเนินการระยะไกล ดังนั้นการคำนวณฉิบหายเป็นพื้นฐานที่ไม่กระตือรือร้น ลักษณะการทำงานนี้จะค่อนข้างคล้ายกับที่ของ tf.function มัณฑนากรใน TensorFlow

ประการที่สองการคำนวณแบบ federated สามารถประกอบด้วยผู้ประกอบการ federated (เช่น tff.federated_mean ) พวกเขาไม่สามารถมีการดำเนินงาน TensorFlow รหัส TensorFlow จะต้องถูกคุมขังในบล็อกตกแต่งด้วย tff.tf_computation ส่วนใหญ่รหัส TensorFlow สามัญสามารถตกแต่งโดยตรงเช่นฟังก์ชั่นดังต่อไปนี้ที่ใช้เวลาจำนวนและเพิ่ม 0.5 กับมัน

@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

เหล่านี้ยังมีประเภทลายเซ็น แต่ไม่มีตำแหน่ง ตัวอย่างเช่น เราสามารถเรียก

str(add_half.type_signature)
'(float32 -> float32)'

ที่นี่เราเห็นความแตกต่างที่สำคัญระหว่าง tff.federated_computation และ tff.tf_computation อดีตมีตำแหน่งที่ชัดเจนในขณะที่หลังไม่มี

เราสามารถใช้ tff.tf_computation บล็อกในการคำนวณแบบ federated โดยการระบุตำแหน่ง มาสร้างฟังก์ชันที่บวกครึ่ง แต่เฉพาะกับโฟลตรวมที่ไคลเอนต์ เราสามารถทำได้โดยใช้ tff.federated_map ซึ่งใช้ที่กำหนด tff.tf_computation ขณะที่การรักษาตำแหน่ง

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

ฟังก์ชั่นนี้เป็นเกือบจะเหมือนกับ add_half ยกเว้นว่าจะยอมรับเฉพาะค่ากับตำแหน่งที่ tff.CLIENTS และค่าผลตอบแทนกับตำแหน่งเดียวกัน เราสามารถเห็นสิ่งนี้ในลายเซ็นประเภท:

str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'

สรุป:

  • TFF ทำงานบนค่าส่วนกลาง
  • แต่ละค่า federated มีชนิดแบบรวมที่มีประเภท (เช่น. tf.float32 ) และตำแหน่ง (เช่น. tff.CLIENTS )
  • ค่าสหพันธ์สามารถเปลี่ยนโดยใช้การคำนวณแบบ federated ซึ่งจะต้องได้รับการตกแต่งด้วย tff.federated_computation และลายเซ็นประเภท federated
  • รหัส TensorFlow จะต้องมีอยู่ในบล็อกที่มี tff.tf_computation ตกแต่ง
  • บล็อกเหล่านี้สามารถรวมเข้ากับการคำนวณแบบรวมศูนย์ได้

สร้างอัลกอริธึมการเรียนรู้แบบสหพันธรัฐของคุณเอง กลับมาอีกครั้ง

ตอนนี้เราได้เห็น Federated Core แล้ว เราจึงสามารถสร้างอัลกอริทึมการเรียนรู้แบบรวมศูนย์ของเราเองได้ โปรดจำไว้ว่าข้างต้นเราได้กำหนด initialize_fn และ next_fn สำหรับขั้นตอนวิธีการของเรา next_fn จะทำให้การใช้งานของ client_update และ server_update เรากำหนดโดยใช้รหัส TensorFlow บริสุทธิ์

อย่างไรก็ตามในการที่จะทำให้อัลกอริทึมของเราคำนวณ federated เราจะจำเป็นต้องใช้ทั้ง next_fn และ initialize_fn แต่ละเป็น tff.federated_computation

TensorFlow Federated บล็อก

การสร้างการคำนวณการเริ่มต้น

ฟังก์ชั่นการเริ่มต้นจะค่อนข้างง่าย: เราจะสร้างรูปแบบการใช้ model_fn แต่จำไว้ว่าเราจะต้องแยกออกจากเรารหัส TensorFlow ใช้ tff.tf_computation

@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables

จากนั้นเราจะสามารถผ่านนี้โดยตรงในการคำนวณโดยใช้ federated tff.federated_value

@tff.federated_computation
def initialize_fn():
  return tff.federated_value(server_init(), tff.SERVER)

สร้าง next_fn

ตอนนี้เราใช้รหัสอัปเดตไคลเอ็นต์และเซิร์ฟเวอร์เพื่อเขียนอัลกอริทึมจริง ครั้งแรกที่เราจะเปิดของเรา client_update เป็น tff.tf_computation ที่ยอมรับชุดข้อมูลลูกค้าและน้ำหนักเซิร์ฟเวอร์และผลการปรับปรุงน้ำหนักของลูกค้าเมตริกซ์

เราจะต้องใช้ประเภทที่เกี่ยวข้องเพื่อตกแต่งฟังก์ชันของเราอย่างเหมาะสม โชคดีที่ประเภทของน้ำหนักเซิร์ฟเวอร์สามารถแยกได้โดยตรงจากแบบจำลองของเรา

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

ลองดูที่ลายเซ็นประเภทชุดข้อมูล จำไว้ว่าเราถ่ายภาพ 28 คูณ 28 (ด้วยป้ายกำกับจำนวนเต็ม) และทำให้แบน

str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'

นอกจากนี้เรายังสามารถแยกประเภทรุ่นน้ำหนักโดยใช้ของเรา server_init ฟังก์ชั่นดังกล่าวข้างต้น

model_weights_type = server_init.type_signature.result

การตรวจสอบประเภทลายเซ็น เราจะสามารถเห็นสถาปัตยกรรมของแบบจำลองของเราได้!

str(model_weights_type)
'<float32[784,10],float32[10]>'

ตอนนี้เราสามารถสร้างของเรา tff.tf_computation สำหรับการปรับปรุงไคลเอ็นต์

@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

tff.tf_computation รุ่นของโปรแกรมปรับปรุงเซิร์ฟเวอร์สามารถกำหนดในลักษณะที่คล้ายกันโดยใช้ชนิดที่เราได้สกัดแล้ว

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

สุดท้าย แต่ไม่น้อยเราต้องสร้าง tff.federated_computation ที่นำมานี้ทั้งหมดเข้าด้วยกัน ฟังก์ชั่นนี้จะยอมรับค่าทั้งสองแบบ federated หนึ่งที่สอดคล้องกับน้ำหนักเซิร์ฟเวอร์ (กับตำแหน่ง tff.SERVER ) และอื่น ๆ ที่สอดคล้องกับชุดข้อมูลลูกค้า (กับตำแหน่ง tff.CLIENTS )

โปรดทราบว่าทั้งสองประเภทนี้ถูกกำหนดไว้ข้างต้น! เราก็ต้องให้พวกเขามีตำแหน่งที่เหมาะสมโดยใช้ tff.FederatedType

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

จำองค์ประกอบ 4 ประการของอัลกอริธึม FL ได้หรือไม่

  1. ขั้นตอนการออกอากาศแบบเซิร์ฟเวอร์ถึงไคลเอ็นต์
  2. ขั้นตอนการอัปเดตไคลเอ็นต์ท้องถิ่น
  3. ขั้นตอนการอัปโหลดไคลเอ็นต์สู่เซิร์ฟเวอร์
  4. ขั้นตอนการอัปเดตเซิร์ฟเวอร์

ตอนนี้เราได้สร้างส่วนข้างต้นแล้ว แต่ละส่วนสามารถแสดงเป็นโค้ด TFF บรรทัดเดียวได้ ความเรียบง่ายนี้เป็นเหตุผลว่าทำไมเราจึงต้องระมัดระวังเป็นพิเศษในการระบุสิ่งต่างๆ เช่น ประเภทรวม!

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

ขณะนี้เรามี tff.federated_computation สำหรับทั้งการเริ่มต้นขั้นตอนวิธีการและสำหรับการทำงานเป็นขั้นตอนหนึ่งของขั้นตอนวิธี จะเสร็จสิ้นขั้นตอนวิธีการของเราเราผ่านเหล่านี้เป็น tff.templates.IterativeProcess

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

ดู Let 's ที่ลายเซ็นประเภทของ initialize และ next การทำงานของกระบวนการซ้ำของเรา

str(federated_algorithm.initialize.type_signature)
'( -> <float32[784,10],float32[10]>@SERVER)'

สะท้อนให้เห็นถึงความจริงที่ว่า federated_algorithm.initialize เป็นฟังก์ชั่นไม่หาเรื่องว่าผลตอบแทนรูปแบบชั้นเดียว (กับ 784 โดยน้ำหนัก 10 เมทริกซ์และ 10 หน่วยอคติ)

str(federated_algorithm.next.type_signature)
'(<server_weights=<float32[784,10],float32[10]>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'

ที่นี่เราจะเห็นว่า federated_algorithm.next ยอมรับรูปแบบเซิร์ฟเวอร์และไคลเอ็นต์ข้อมูลและผลตอบแทนรูปแบบการปรับปรุงเซิร์ฟเวอร์

การประเมินอัลกอริทึม

ลองวิ่งสักสองสามรอบและดูว่าการสูญเสียจะเปลี่ยนไปอย่างไร ครั้งแรกที่เราจะกำหนดฟังก์ชั่นการประเมินผลการใช้วิธีการรวมศูนย์ที่กล่าวถึงในการกวดวิชาที่สอง

ขั้นแรก เราสร้างชุดข้อมูลการประเมินแบบรวมศูนย์ จากนั้นจึงใช้การประมวลผลล่วงหน้าแบบเดียวกับที่เราใช้สำหรับข้อมูลการฝึกอบรม

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

ต่อไป เราเขียนฟังก์ชันที่ยอมรับสถานะเซิร์ฟเวอร์ และใช้ Keras เพื่อประเมินในชุดข้อมูลทดสอบ ถ้าคุณคุ้นเคยกับ tf.Keras นี้จะมีลักษณะทุกคนคุ้นเคย แต่บันทึกการใช้งานของ set_weights !

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)

ตอนนี้ เรามาเริ่มต้นอัลกอริทึมของเราและประเมินในชุดทดสอบกัน

server_state = federated_algorithm.initialize()
evaluate(server_state)
2042/2042 [==============================] - 2s 767us/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027

มาฝึกกันสักสองสามรอบและดูว่ามีอะไรเปลี่ยนแปลงไหม

for round in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
2042/2042 [==============================] - 2s 738us/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980

เราเห็นการลดลงเล็กน้อยในฟังก์ชันการสูญเสีย แม้ว่าการกระโดดจะเล็ก แต่เราได้ทำการฝึกซ้อมเพียง 15 รอบและในกลุ่มย่อยของลูกค้า เพื่อให้เห็นผลดีขึ้น เราอาจต้องทำหลายร้อยรอบหากไม่ใช่

การปรับเปลี่ยนอัลกอริทึมของเรา

ณ จุดนี้ ให้หยุดคิดเกี่ยวกับสิ่งที่เราทำสำเร็จ เราได้ปรับใช้ Federated Averaging โดยตรงโดยการรวมโค้ด TensorFlow แท้ (สำหรับการอัปเดตไคลเอ็นต์และเซิร์ฟเวอร์) กับการคำนวณแบบรวมศูนย์จาก Federated Core ของ TFF

เพื่อดำเนินการเรียนรู้ที่ซับซ้อนมากขึ้น เราสามารถปรับเปลี่ยนสิ่งที่เรามีข้างต้นได้ โดยเฉพาะอย่างยิ่ง โดยการแก้ไขโค้ด TF ล้วนๆ ด้านบน เราสามารถเปลี่ยนวิธีที่ไคลเอ็นต์ดำเนินการฝึกอบรม หรือวิธีที่เซิร์ฟเวอร์อัปเดตโมเดล

ท้าทาย: เพิ่ม คลิปลาด ไป client_update ฟังก์ชั่น

หากเราต้องการทำการเปลี่ยนแปลงที่ใหญ่ขึ้น เราก็สามารถมีที่เก็บเซิร์ฟเวอร์และกระจายข้อมูลได้มากขึ้น ตัวอย่างเช่น เซิร์ฟเวอร์ยังสามารถเก็บอัตราการเรียนรู้ของไคลเอ็นต์ และทำให้เสื่อมลงเมื่อเวลาผ่านไป! หมายเหตุว่านี้จะต้องมีการเปลี่ยนแปลงลายเซ็นชนิดที่ใช้ใน tff.tf_computation สายดังกล่าวข้างต้น

ท้าทายยาก: ระบบสหพันธ์ Averaging กับการเรียนรู้การสลายตัวของอัตราลูกค้า

ณ จุดนี้ คุณอาจเริ่มตระหนักถึงความยืดหยุ่นในสิ่งที่คุณสามารถนำไปใช้ในกรอบงานนี้ สำหรับความคิด (รวมถึงคำตอบให้กับความท้าทายที่ยากข้างต้น) คุณสามารถดูซอร์สโค้ดสำหรับ tff.learning.build_federated_averaging_process หรือเช็คเอาท์ต่างๆ โครงการวิจัย โดยใช้ฉิบหาย