สหพันธ์การเรียนรู้สำหรับการจำแนกรูปภาพ for

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

ในการกวดวิชานี้เราจะใช้ตัวอย่างการฝึกอบรม MNIST คลาสสิกที่จะแนะนำการเรียนรู้ของสหพันธ์ (FL) ชั้น API ของฉิบหาย tff.learning - ชุดของการเชื่อมต่อในระดับสูงที่สามารถนำมาใช้ในการดำเนินชนิดที่พบบ่อยของงานการเรียนรู้แบบ federated เช่น การฝึกอบรมแบบสหพันธรัฐ เทียบกับโมเดลที่ผู้ใช้จัดหาซึ่งนำไปใช้ใน TensorFlow

บทช่วยสอนนี้และ Federated Learning API มีไว้สำหรับผู้ใช้ที่ต้องการเสียบโมเดล TensorFlow ของตนเองเข้ากับ TFF เป็นหลัก โดยถือว่าส่วนหลังส่วนใหญ่เป็นกล่องดำ เพื่อความเข้าใจที่มากขึ้นในเชิงลึกของฉิบหายและวิธีการดำเนินการตามขั้นตอนวิธีการเรียนรู้แบบรวมของคุณเองให้ดูบทเรียนเกี่ยวกับเอฟซีหลักของ API - กำหนดเองสหพันธ์อัลกอริทึมส่วนที่ 1 และ ส่วนที่ 2

สำหรับข้อมูลเพิ่มเติมเกี่ยว tff.learning , ดำเนินการกับ การเรียนรู้สำหรับสหพันธ์ข้อความ Generation , กวดวิชาซึ่งนอกจากจะครอบคลุมรูปแบบที่เกิดขึ้นอีกนอกจากนี้ยังแสดงให้เห็นถึงการโหลดก่อนการฝึกอบรมต่อเนื่องรุ่น Keras สำหรับการปรับแต่งกับการเรียนรู้แบบ federated รวมกับการประเมินผลการใช้ Keras

ก่อนที่เราจะเริ่มต้น

ก่อนที่เราจะเริ่ม โปรดเรียกใช้สิ่งต่อไปนี้เพื่อให้แน่ใจว่าสภาพแวดล้อมของคุณได้รับการตั้งค่าอย่างถูกต้อง หากคุณไม่เห็นคำทักทายโปรดดูที่ การติดตั้ง คู่มือสำหรับคำแนะนำ

# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
!pip install --quiet --upgrade tb-nightly  # or tensorboard, but not both

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

การเตรียมข้อมูลเข้า

เริ่มจากข้อมูลกันก่อน การเรียนรู้แบบสหพันธรัฐต้องการชุดข้อมูลแบบรวมศูนย์ กล่าวคือ การรวบรวมข้อมูลจากผู้ใช้หลายราย ข้อมูลสหพันธ์เป็นปกติไม่ใช่ IID ซึ่งส่อเค้าชุดที่เป็นเอกลักษณ์ของความท้าทาย

เพื่อความสะดวกในการทดลองเราเมล็ดที่เก็บฉิบหายที่มีไม่กี่ชุดข้อมูลรวมทั้งรุ่น federated ของ MNIST ที่มีรุ่นที่ ชุด NIST เดิม ที่ได้รับอีกครั้งประมวลผลโดยใช้ ใบ เพื่อให้ข้อมูลที่เป็นคีย์โดยนักเขียนเดิมของ ตัวเลข เนื่องจากผู้เขียนแต่ละคนมีลักษณะเฉพาะ ชุดข้อมูลนี้จึงแสดงลักษณะการทำงานที่ไม่ใช่ iid ที่คาดไว้ของชุดข้อมูลรวม

นี่คือวิธีที่เราสามารถโหลดได้

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

ชุดข้อมูลที่ส่งกลับโดย load_data() เป็นกรณีของ tff.simulation.ClientData , อินเตอร์เฟซที่ช่วยให้คุณระบุการตั้งค่าของผู้ใช้ในการสร้าง tf.data.Dataset ที่แสดงถึงข้อมูลของผู้ใช้เฉพาะและการสอบถาม โครงสร้างของแต่ละองค์ประกอบ คุณสามารถใช้อินเทอร์เฟซนี้เพื่อสำรวจเนื้อหาของชุดข้อมูลได้อย่างไร โปรดทราบว่าแม้อินเทอร์เฟซนี้จะอนุญาตให้คุณทำซ้ำผ่านรหัสลูกค้าได้ นี่เป็นเพียงคุณลักษณะของข้อมูลการจำลองเท่านั้น ดังที่คุณจะเห็นในไม่ช้านี้ ข้อมูลประจำตัวของลูกค้าจะไม่ถูกใช้โดยกรอบการเรียนรู้แบบรวมศูนย์ - จุดประสงค์เดียวของพวกเขาคือเพื่อให้คุณสามารถเลือกชุดย่อยของข้อมูลสำหรับการจำลองได้

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

สำรวจความแตกต่างในข้อมูลรวม

ข้อมูลสหพันธ์เป็นปกติไม่ใช่ IID ผู้ใช้มักจะมีการกระจายของข้อมูลที่แตกต่างกันขึ้นอยู่กับรูปแบบการใช้ ลูกค้าบางรายอาจมีตัวอย่างการฝึกอบรมบนอุปกรณ์น้อยกว่า เนื่องจากมีปัญหาด้านข้อมูลภายใน ในขณะที่ลูกค้าบางรายอาจมีตัวอย่างการฝึกอบรมมากเกินพอ มาสำรวจแนวคิดนี้เกี่ยวกับความแตกต่างของข้อมูลตามแบบฉบับของระบบรวมที่มีข้อมูล EMNIST ที่เรามีอยู่ สิ่งสำคัญคือต้องทราบว่าการวิเคราะห์ข้อมูลของลูกค้าในเชิงลึกนี้มีให้เฉพาะเราเท่านั้น เนื่องจากเป็นสภาพแวดล้อมการจำลองที่ข้อมูลทั้งหมดมีให้เราในพื้นที่ ในสภาพแวดล้อมที่รวมการผลิตจริง คุณจะไม่สามารถตรวจสอบข้อมูลของลูกค้ารายเดียวได้

อันดับแรก เรามาดูตัวอย่างข้อมูลของลูกค้ารายหนึ่งเพื่อทำความเข้าใจตัวอย่างบนอุปกรณ์จำลองเครื่องหนึ่ง เนื่องจากชุดข้อมูลที่เราใช้ได้รับการคีย์โดยผู้เขียนที่ไม่ซ้ำกัน ข้อมูลของลูกค้ารายหนึ่งแสดงถึงลายมือของคนคนหนึ่งสำหรับตัวอย่างตัวเลข 0 ถึง 9 ซึ่งจำลอง "รูปแบบการใช้งาน" ที่ไม่ซ้ำกันของผู้ใช้รายหนึ่ง

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

ตอนนี้ มาดูจำนวนตัวอย่างในไคลเอนต์แต่ละรายสำหรับป้ายกำกับหลัก MNIST แต่ละรายการ ในสภาพแวดล้อมแบบรวมศูนย์ จำนวนตัวอย่างในแต่ละไคลเอ็นต์อาจแตกต่างกันไปเล็กน้อย ขึ้นอยู่กับพฤติกรรมของผู้ใช้

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

ตอนนี้ มาลองนึกภาพค่าเฉลี่ยของอิมเมจต่อไคลเอนต์สำหรับป้ายกำกับ MNIST แต่ละอัน รหัสนี้จะสร้างค่าเฉลี่ยของแต่ละค่าพิกเซลสำหรับตัวอย่างทั้งหมดของผู้ใช้สำหรับป้ายกำกับเดียว เราจะเห็นว่ารูปภาพเฉลี่ยของลูกค้ารายหนึ่งสำหรับตัวเลขจะดูแตกต่างจากรูปภาพเฉลี่ยของลูกค้ารายอื่นสำหรับตัวเลขเดียวกัน เนื่องจากรูปแบบการเขียนด้วยลายมือที่เป็นเอกลักษณ์ของแต่ละคน เราสามารถรำพึงถึงวิธีที่แต่ละรอบการฝึกในท้องถิ่นจะสะกิดโมเดลไปในทิศทางที่แตกต่างกันของลูกค้าแต่ละราย ในขณะที่เรากำลังเรียนรู้จากข้อมูลเฉพาะของผู้ใช้รายนั้นในรอบท้องถิ่นนั้น ต่อมาในบทช่วยสอน เราจะมาดูกันว่าเราจะนำการอัปเดตแต่ละรายการไปยังโมเดลจากลูกค้าทั้งหมดและรวมเข้าด้วยกันในรูปแบบใหม่ทั่วโลกได้อย่างไร ซึ่งได้เรียนรู้จากข้อมูลเฉพาะของลูกค้าแต่ละราย

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

ข้อมูลผู้ใช้อาจมีสัญญาณรบกวนและติดป้ายกำกับที่ไม่น่าเชื่อถือ ตัวอย่างเช่น เมื่อดูข้อมูลของลูกค้า #2 ด้านบน เราจะเห็นได้ว่าสำหรับป้ายกำกับ 2 เป็นไปได้ว่าอาจมีตัวอย่างที่ติดฉลากไม่ถูกต้องซึ่งสร้างภาพที่มีความหมายที่มีเสียงดังกว่า

กำลังประมวลผลข้อมูลอินพุตล่วงหน้า

เนื่องจากข้อมูลที่มีอยู่แล้ว tf.data.Dataset , preprocessing สามารถทำได้โดยใช้การแปลงชุดข้อมูล ที่นี่เราแผ่ 28x28 ภาพเป็น 784 อาร์เรย์องค์ประกอบสับตัวอย่างแต่ละจัดระเบียบให้เป็นแบทช์และเปลี่ยนชื่อคุณสมบัติจาก pixels และ label เพื่อ x และ y สำหรับใช้กับ Keras นอกจากนี้เรายังโยนใน repeat มากกว่าชุดข้อมูลในการทำงานหลายยุคสมัย

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

มาตรวจสอบว่าใช้งานได้

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [5],
       [7],
       [1],
       [7],
       [7],
       [1],
       [4],
       [7],
       [4],
       [2],
       [2],
       [5],
       [4],
       [1],
       [1],
       [0],
       [0],
       [9]], dtype=int32))])

เรามีหน่วยการสร้างเกือบทั้งหมดเพื่อสร้างชุดข้อมูลแบบรวมศูนย์

หนึ่งในวิธีที่จะเลี้ยงข้อมูล federated จะฉิบหายในการจำลองเป็นเพียงเป็นรายการหลามกับองค์ประกอบของรายการถือข้อมูลของผู้ใช้แต่ละคนแต่ละคนไม่ว่าจะเป็นรายการหรือเป็น tf.data.Dataset เนื่องจากเรามีอินเทอร์เฟสที่ให้มาอยู่แล้ว มาใช้งานกัน

นี่คือฟังก์ชันตัวช่วยง่ายๆ ที่จะสร้างรายการชุดข้อมูลจากกลุ่มผู้ใช้ที่กำหนด เพื่อเป็นข้อมูลป้อนเข้าในการฝึกอบรมหรือการประเมินผล

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

แล้วเราจะเลือกลูกค้าอย่างไร?

ในสถานการณ์การฝึกอบรมแบบรวมศูนย์ทั่วไป เรากำลังเผชิญกับอุปกรณ์ผู้ใช้ที่มีประชากรจำนวนมาก ซึ่งอาจมีเพียงเศษเสี้ยวเดียวเท่านั้นที่พร้อมสำหรับการฝึกอบรมในช่วงเวลาที่กำหนด ในกรณีนี้ ตัวอย่างเช่น เมื่ออุปกรณ์ไคลเอ็นต์เป็นโทรศัพท์มือถือที่เข้าร่วมการฝึกอบรมเฉพาะเมื่อเสียบปลั๊กเข้ากับแหล่งพลังงาน ปิดเครือข่ายแบบมีมิเตอร์ หรือไม่ได้ใช้งาน

แน่นอน เราอยู่ในสภาพแวดล้อมการจำลอง และข้อมูลทั้งหมดมีอยู่ในเครื่อง โดยปกติ เมื่อเรียกใช้การจำลอง เราจะสุ่มตัวอย่างกลุ่มย่อยของลูกค้าที่เข้าร่วมการฝึกอบรมแต่ละรอบ ซึ่งโดยทั่วไปจะแตกต่างกันในแต่ละรอบ

ที่กล่าวว่าในขณะที่คุณสามารถหาได้โดยการศึกษากระดาษบน Averaging สหพันธ์ ขั้นตอนวิธีการบรรลุการบรรจบกันในระบบที่มีการย่อยสุ่มของลูกค้าในแต่ละรอบจะใช้เวลาสักครู่และมันจะทำไม่ได้ที่จะมีการเรียกใช้หลายร้อยรอบ บทช่วยสอนแบบโต้ตอบนี้

สิ่งที่เราจะทำคือสุ่มตัวอย่างกลุ่มลูกค้าหนึ่งครั้ง และใช้ชุดเดิมซ้ำตลอดทั้งรอบเพื่อเร่งการบรรจบกัน (ตั้งใจให้เหมาะสมกับข้อมูลของผู้ใช้สองสามรายนี้) เราปล่อยให้มันเป็นแบบฝึกหัดสำหรับผู้อ่านในการปรับเปลี่ยนบทช่วยสอนนี้เพื่อจำลองการสุ่มตัวอย่าง - มันค่อนข้างง่ายที่จะทำ (เมื่อคุณทำแล้ว โปรดจำไว้ว่าการทำให้โมเดลมาบรรจบกันอาจใช้เวลาสักครู่)

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

การสร้างโมเดลด้วย Keras

หากคุณกำลังใช้ Keras คุณน่าจะมีโค้ดที่สร้างโมเดล Keras อยู่แล้ว ต่อไปนี้คือตัวอย่างโมเดลง่ายๆ ที่เพียงพอสำหรับความต้องการของเรา

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

เพื่อที่จะใช้รูปแบบใด ๆ กับฉิบหายจะต้องมีการห่อในอินสแตนซ์ที่ tff.learning.Model อินเตอร์เฟซซึ่งหมายความว่าวิธีการที่จะประทับตราไปข้างหน้าผ่านรูปแบบของสมบัติเมตาดาต้าและอื่น ๆ ในทำนองเดียวกันกับ Keras แต่ยังแนะนำเพิ่มเติม องค์ประกอบต่างๆ เช่น วิธีในการควบคุมกระบวนการคำนวณเมตริกแบบรวมศูนย์ ไม่ต้องกังวลเกี่ยวกับเรื่องนี้ในตอนนี้ ถ้าคุณมีรูปแบบ Keras อย่างหนึ่งที่เราได้กำหนดไว้ข้างต้นเพียงแค่คุณสามารถมีฉิบหายห่อไว้สำหรับคุณโดยเรียก tff.learning.from_keras_model ผ่านรูปแบบและตัวอย่างที่เป็นชุดข้อมูลที่เป็นข้อโต้แย้งที่แสดงด้านล่าง

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

การฝึกโมเดลเกี่ยวกับข้อมูลรวม

ตอนนี้เรามีรูปแบบการห่อเป็น tff.learning.Model สำหรับใช้กับฉิบหายเราสามารถให้ฉิบหายสร้างอัลกอริทึม Averaging สหพันธ์โดยเรียกฟังก์ชั่นผู้ช่วย tff.learning.build_federated_averaging_process ดังต่อไปนี้

โปรดทราบว่าการโต้แย้งจะต้องมีคอนสตรัค (เช่น model_fn ข้างต้น) ไม่ได้เป็นตัวอย่างแล้วสร้างเพื่อให้การก่อสร้างของรูปแบบของคุณสามารถเกิดขึ้นได้ในบริบทที่ควบคุมโดยฉิบหาย (ถ้าคุณอยากรู้เกี่ยวกับสาเหตุของการ นี้เราขอแนะนำให้คุณอ่านกวดวิชาติดตาม ขั้นตอนวิธีการที่กำหนดเอง )

หนึ่งบันทึกที่สำคัญในขั้นตอนวิธีการเฉลี่ยสหพันธ์ด้านล่างมี 2 เพิ่มประสิทธิภาพก _client เพิ่มประสิทธิภาพและเพิ่มประสิทธิภาพ _SERVER เพิ่มประสิทธิภาพ _client จะใช้ในการคำนวณการปรับปรุงรูปแบบท้องถิ่นเกี่ยวกับลูกค้าแต่ละราย _SERVER เพิ่มประสิทธิภาพการใช้การปรับปรุงเฉลี่ยกับรูปแบบระดับโลกที่เซิร์ฟเวอร์ โดยเฉพาะอย่างยิ่ง นี่หมายความว่าตัวเลือกของเครื่องมือเพิ่มประสิทธิภาพและอัตราการเรียนรู้ที่ใช้อาจต้องแตกต่างจากที่คุณใช้เพื่อฝึกโมเดลบนชุดข้อมูล id มาตรฐาน เราขอแนะนำให้เริ่มต้นด้วย SGD ปกติ อาจมีอัตราการเรียนรู้น้อยกว่าปกติ อัตราการเรียนรู้ที่เราใช้ยังไม่ได้รับการปรับอย่างระมัดระวัง โปรดทดลอง

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

เกิดอะไรขึ้น? ฉิบหายได้สร้างคู่ของการคำนวณแบบ federated และบรรจุไว้ใน tff.templates.IterativeProcess ซึ่งคำนวณเหล่านี้มีอยู่เป็นคู่ของคุณสมบัติ initialize และ next

สรุปการคำนวณแบบ federated เป็นโปรแกรมในภาษาภายในฉิบหายที่สามารถแสดงขั้นตอนวิธีแบบ federated ต่างๆ (คุณสามารถหาข้อมูลเพิ่มเติมเกี่ยวกับเรื่องนี้ใน ที่กำหนดเองขั้นตอนวิธีการ กวดวิชา) ในกรณีนี้ทั้งสองคำนวณสร้างและบรรจุลงใน iterative_process ใช้ สหพันธ์ Averaging

เป็นเป้าหมายของ TFF ในการกำหนดการคำนวณในลักษณะที่สามารถดำเนินการได้ในการตั้งค่าการเรียนรู้แบบสหพันธรัฐจริง แต่ปัจจุบันมีการใช้รันไทม์การจำลองการดำเนินการในเครื่องเท่านั้น ในการดำเนินการคำนวณในโปรแกรมจำลอง คุณเพียงแค่เรียกใช้มันเหมือนกับฟังก์ชัน Python สภาพแวดล้อมที่ตีความโดยค่าเริ่มต้นนี้ไม่ได้ออกแบบมาสำหรับประสิทธิภาพสูง แต่จะเพียงพอสำหรับบทช่วยสอนนี้ เราคาดว่าจะให้รันไทม์การจำลองประสิทธิภาพสูงขึ้นเพื่ออำนวยความสะดวกในการวิจัยขนาดใหญ่ขึ้นในรุ่นต่อๆ ไป

เริ่มต้น Let 's กับ initialize ในการคำนวณ เช่นเดียวกับการคำนวณแบบรวมศูนย์ทั้งหมด คุณสามารถคิดว่ามันเป็นฟังก์ชันได้ การคำนวณไม่มีอาร์กิวเมนต์ และส่งคืนผลลัพธ์หนึ่งรายการ - การแสดงสถานะของกระบวนการ Federated Averaging บนเซิร์ฟเวอร์ แม้ว่าเราไม่ต้องการเจาะลึกในรายละเอียดของ TFF แต่อาจเป็นประโยชน์ในการดูว่าสถานะนี้เป็นอย่างไร คุณสามารถเห็นภาพได้ดังนี้

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER)'

ในขณะที่ลายเซ็นประเภทดังกล่าวข้างต้นในตอนแรกอาจดูเหมือนคลุมเครือบิตคุณสามารถรับรู้ว่าสถานะของเซิร์ฟเวอร์ที่ประกอบด้วย model (พารามิเตอร์รุ่นเริ่มต้นสำหรับ MNIST ที่จะแจกจ่ายให้กับอุปกรณ์ทั้งหมด) และ optimizer_state (ข้อมูลเพิ่มเติมการเก็บรักษาโดยเซิร์ฟเวอร์ เช่น จำนวนรอบที่ใช้สำหรับกำหนดการไฮเปอร์พารามิเตอร์ เป็นต้น)

ขอเรียก initialize ในการคำนวณเพื่อสร้างรัฐเซิร์ฟเวอร์

state = iterative_process.initialize()

ที่สองคู่ของการคำนวณแบบ federated ที่ next หมายถึงรอบเดียวของสหพันธ์ Averaging ซึ่งประกอบด้วยการผลักดันสถานะของเซิร์ฟเวอร์ (รวมทั้งพารามิเตอร์แบบ) ให้กับลูกค้า, การฝึกอบรมในอุปกรณ์บนข้อมูลท้องถิ่นของพวกเขาการจัดเก็บภาษีและการปรับปรุงรูปแบบเฉลี่ย , และผลิตโมเดลที่อัปเดตใหม่ที่เซิร์ฟเวอร์

แนวคิดที่คุณสามารถคิด next ว่ามีลายเซ็นประเภทการทำงานที่มีลักษณะดังต่อไปนี้

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

โดยเฉพาะอย่างยิ่งหนึ่งควรคิดเกี่ยวกับการ next() ไม่ได้เป็นฟังก์ชั่นที่รันบนเซิร์ฟเวอร์ แต่เป็นตัวแทนของการทำงานที่เปิดเผยของการคำนวณการกระจายอำนาจทั้งหมด - บางส่วนของปัจจัยการผลิตที่มีให้โดยเซิร์ฟเวอร์ ( SERVER_STATE ) แต่แต่ละคนที่เข้าร่วมโครงการ อุปกรณ์มีส่วนสนับสนุนชุดข้อมูลในเครื่องของตัวเอง

มาฝึกรอบเดียวและเห็นภาพผลลัพธ์ เราสามารถใช้ข้อมูลรวมที่เราได้สร้างไว้ข้างต้นสำหรับกลุ่มตัวอย่างผู้ใช้

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12345679), ('loss', 3.1193738)])), ('stat', OrderedDict([('num_examples', 4860)]))])

มาลุยกันอีกหลายรอบ ดังที่ได้กล่าวไว้ก่อนหน้านี้ โดยทั่วไป ณ จุดนี้ คุณจะเลือกชุดย่อยของข้อมูลการจำลองของคุณจากกลุ่มตัวอย่างที่สุ่มเลือกใหม่ของผู้ใช้สำหรับแต่ละรอบ เพื่อจำลองการใช้งานจริงที่ผู้ใช้เข้าๆ ออกๆ อย่างต่อเนื่อง แต่ในสมุดบันทึกแบบโต้ตอบนี้ เพื่อการสาธิต เราจะใช้ผู้ใช้เดิมซ้ำ เพื่อให้ระบบมาบรรจบกันอย่างรวดเร็ว

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13518518), ('loss', 2.9834728)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14382716), ('loss', 2.861665)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.17407407), ('loss', 2.7957022)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.19917695), ('loss', 2.6146567)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.21975309), ('loss', 2.529761)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2409465), ('loss', 2.4053504)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2611111), ('loss', 2.315389)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.30823046), ('loss', 2.1240263)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.33312756), ('loss', 2.1164262)])), ('stat', OrderedDict([('num_examples', 4860)]))])

การสูญเสียการฝึกจะลดลงหลังจากการฝึกแบบสหพันธรัฐแต่ละรอบ ซึ่งบ่งชี้ว่าโมเดลกำลังบรรจบกัน มีบางประการที่สำคัญกับตัวชี้วัดการฝึกอบรมเหล่านี้มี แต่ให้ดูส่วนที่เกี่ยวกับการประเมินผลต่อไปในการกวดวิชานี้

การแสดงเมทริกโมเดลใน TensorBoard

ต่อไป มาดูเมทริกจากการคำนวณแบบรวมศูนย์เหล่านี้โดยใช้ Tensorboard

เริ่มต้นด้วยการสร้างไดเร็กทอรีและตัวเขียนสรุปที่เกี่ยวข้องเพื่อเขียนเมตริก

logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

พล็อตเมทริกสเกลาร์ที่เกี่ยวข้องด้วยตัวเขียนสรุปเดียวกัน

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics['train'].items():
      tf.summary.scalar(name, value, step=round_num)

เริ่ม TensorBoard ด้วยไดเร็กทอรีบันทึกของรูทที่ระบุด้านบน อาจใช้เวลาสองสามวินาทีในการโหลดข้อมูล

!ls {logdir}
%tensorboard --logdir {logdir} --port=0
events.out.tfevents.1629557449.ebe6e776479e64ea-4903924a278.borgtask.google.com.458912.1.v2
Launching TensorBoard...
Reusing TensorBoard on port 50681 (pid 292785), started 0:30:30 ago. (Use '!kill 292785' to kill it.)
<IPython.core.display.Javascript at 0x7fd6617e02d0>
# Uncomment and run this this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if 
# you do a "Runtime > Run all" you don't lose your results.

# !rm -R /tmp/logs/scalars/*

ในการดูเมตริกการประเมินด้วยวิธีเดียวกัน คุณสามารถสร้างโฟลเดอร์ eval แยกต่างหาก เช่น "logs/scalars/eval" เพื่อเขียนไปยัง TensorBoard

การปรับแต่งการใช้งานโมเดล

Keras เป็น ระดับสูงรุ่น API ที่แนะนำสำหรับ TensorFlow และเราขอแนะนำให้ใช้แบบจำลอง Keras (ผ่าน tff.learning.from_keras_model ) ในฉิบหายเมื่อใดก็ตามที่เป็นไปได้

อย่างไรก็ตาม tff.learning ให้อินเตอร์เฟซรูปแบบในระดับที่ต่ำกว่า tff.learning.Model ที่ exposes ฟังก์ชันการทำงานที่น้อยที่สุดที่จำเป็นสำหรับการใช้รูปแบบการเรียนรู้แบบ federated โดยตรงใช้อินเตอร์เฟซนี้ (อาจจะยังคงใช้การสร้างบล็อกเช่น tf.keras.layers ) ช่วยให้ปรับแต่งได้สูงสุดโดยไม่ต้องแก้ไข internals ของขั้นตอนวิธีการเรียนรู้แบบ federated

ลองทำใหม่ทั้งหมดตั้งแต่ต้น

การกำหนดตัวแปรโมเดล การส่งต่อ และตัวชี้วัด

ขั้นตอนแรกคือการระบุตัวแปร TensorFlow ที่เราจะใช้งาน เพื่อให้โค้ดต่อไปนี้อ่านง่ายขึ้น มากำหนดโครงสร้างข้อมูลเพื่อแสดงทั้งชุด ซึ่งจะรวมถึงตัวแปรเช่น weights และ bias ที่เราจะฝึกเช่นเดียวกับตัวแปรที่จะถือสถิติต่างๆที่สะสมและเคาน์เตอร์เราจะปรับปรุงในระหว่างการฝึกเช่น loss_sum , accuracy_sum และ num_examples

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

นี่คือวิธีการที่สร้างตัวแปร เพื่อประโยชน์ของความเรียบง่ายเราเป็นตัวแทนสถิติทั้งหมดเป็น tf.float32 เป็นที่จะขจัดความจำเป็นในการแปลงชนิดในระยะต่อมา ห่อ initializers ตัวแปร lambdas เป็นความต้องการที่กำหนดโดย ตัวแปรทรัพยากร

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

ด้วยตัวแปรสำหรับพารามิเตอร์แบบจำลองและสถิติสะสม ตอนนี้เราสามารถกำหนดวิธีการส่งต่อที่คำนวณการสูญเสีย ปล่อยการคาดการณ์ และอัปเดตสถิติสะสมสำหรับชุดข้อมูลอินพุตชุดเดียวดังนี้

def predict_on_batch(variables, x):
  return tf.nn.softmax(tf.matmul(x, variables.weights) + variables.bias)

def mnist_forward_pass(variables, batch):
  y = predict_on_batch(variables, batch['x'])
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

ต่อไป เรากำหนดฟังก์ชันที่ส่งคืนชุดของตัววัดในเครื่อง อีกครั้งโดยใช้ TensorFlow ค่าเหล่านี้เป็นค่า (นอกเหนือจากการอัพเดตโมเดล ซึ่งได้รับการจัดการโดยอัตโนมัติ) ที่มีสิทธิ์รวมเข้ากับเซิร์ฟเวอร์ในกระบวนการเรียนรู้หรือการประเมินแบบรวมศูนย์

ที่นี่เราก็กลับค่าเฉลี่ยของ loss และ accuracy , เช่นเดียวกับ num_examples ซึ่งเราจะต้องถูกต้องน้ำหนักผลงานจากผู้ใช้งานที่แตกต่างกันเมื่อคำนวณมวล federated

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

สุดท้ายเราจะต้องกำหนดวิธีการที่จะรวมตัวชี้วัดในท้องถิ่นที่ปล่อยออกมาโดยผ่านทางอุปกรณ์แต่ละ get_local_mnist_metrics นี้เป็นเพียงส่วนหนึ่งของรหัสที่ไม่ได้เขียนใน TensorFlow - มันเป็นคำนวณ federated แสดงในฉิบหาย หากคุณต้องการที่จะขุดลึกหางมากกว่า ที่กำหนดเองขั้นตอนวิธีการ กวดวิชา แต่ในการใช้งานมากที่สุดที่คุณจะไม่จำเป็นจริงๆที่จะ; รูปแบบต่างๆ ที่แสดงด้านล่างน่าจะเพียงพอแล้ว นี่คือสิ่งที่ดูเหมือน:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))

การป้อนข้อมูล metrics สอดคล้องอาร์กิวเมนต์ OrderedDict กลับโดย get_local_mnist_metrics ข้างต้น แต่วิกฤตค่าจะไม่ tf.Tensors - พวกเขาจะ "บรรจุกล่อง" เป็น tff.Value s ที่จะทำให้มันชัดเจนคุณจะไม่สามารถจัดการกับพวกเขาโดยใช้ TensorFlow แต่เพียง ผู้ประกอบการใช้ federated ฉิบหายเหมือน tff.federated_mean และ tff.federated_sum พจนานุกรมที่ส่งคืนของการรวมโกลบอลกำหนดชุดของเมทริกซึ่งจะพร้อมใช้งานบนเซิร์ฟเวอร์

สร้างตัวอย่างของ tff.learning.Model

ด้วยทั้งหมดที่กล่าวมานี้ เราพร้อมที่จะสร้างการแสดงโมเดลสำหรับใช้กับ TFF ที่คล้ายกับที่สร้างขึ้นสำหรับคุณเมื่อคุณปล่อยให้ TFF นำเข้าโมเดล Keras

from typing import Callable, List, OrderedDict

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def predict_on_batch(self, x, training=True):
    del training
    return predict_on_batch(self._variables, x)

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

  @tf.function
  def report_local_unfinalized_metrics(
      self) -> OrderedDict[str, List[tf.Tensor]]:
    """Creates an `OrderedDict` of metric names to unfinalized values."""
    return collections.OrderedDict(
        num_examples=[self._variables.num_examples],
        loss=[self._variables.loss_sum, self._variables.num_examples],
        accuracy=[self._variables.accuracy_sum, self._variables.num_examples])

  def metric_finalizers(
      self) -> OrderedDict[str, Callable[[List[tf.Tensor]], tf.Tensor]]:
    """Creates an `OrderedDict` of metric names to finalizers."""
    return collections.OrderedDict(
        num_examples=tf.function(func=lambda x: x[0]),
        loss=tf.function(func=lambda x: x[0] / x[1]),
        accuracy=tf.function(func=lambda x: x[0] / x[1]))

ในขณะที่คุณสามารถดูวิธีการที่เป็นนามธรรมและคุณสมบัติที่กำหนดโดย tff.learning.Model สอดคล้องกับข้อมูลโค้ดในส่วนก่อนหน้านี้ที่นำตัวแปรและกำหนดความสูญเสียและสถิติ

ต่อไปนี้คือประเด็นสำคัญบางประการที่ควรค่าแก่การเน้นย้ำ:

  • รัฐทั้งหมดว่ารูปแบบของคุณจะใช้จะต้องถูกจับเป็นตัวแปร TensorFlow เป็นฉิบหายไม่ได้ใช้งูหลามที่รันไทม์ (จำรหัสของคุณควรจะเขียนดังกล่าวว่าสามารถนำไปใช้กับอุปกรณ์มือถือดู ที่กำหนดเองขั้นตอนวิธีการ สอนสำหรับข้อมูลเพิ่มเติมในเชิงลึก ชี้แจงเหตุผล)
  • โมเดลของคุณควรจะอธิบายสิ่งที่รูปแบบของข้อมูลจะยอมรับ ( input_spec ) ในขณะที่โดยทั่วไปฉิบหายเป็นสภาพแวดล้อมการพิมพ์อย่างยิ่งและต้องการที่จะตรวจสอบลายเซ็นประเภทสำหรับทุกส่วน การประกาศรูปแบบอินพุตของโมเดลของคุณเป็นส่วนสำคัญ
  • แม้ว่าในทางเทคนิคไม่จำเป็นเราขอแนะนำให้ตัดตรรกะ TensorFlow ทั้งหมด (ไปข้างหน้าผ่านการคำนวณตัวชี้วัดอื่น ๆ ) เป็น tf.function s เช่นนี้จะช่วยให้แน่ใจว่า TensorFlow สามารถต่อเนื่องและลบความจำเป็นในการควบคุมการอ้างอิงที่ชัดเจน

ข้างต้นเพียงพอสำหรับการประเมินและอัลกอริทึมเช่น Federated SGD อย่างไรก็ตาม สำหรับ Federated Averaging เราจำเป็นต้องระบุว่าโมเดลควรฝึกแบบโลคัลในแต่ละแบตช์อย่างไร เราจะระบุเครื่องมือเพิ่มประสิทธิภาพในพื้นที่เมื่อสร้างอัลกอริทึม Federated Averaging

จำลองการฝึกแบบสหพันธรัฐด้วยโมเดลใหม่

จากทั้งหมดที่กล่าวมาข้างต้น กระบวนการที่เหลือจะดูเหมือนสิ่งที่เราได้เห็นแล้ว - เพียงแค่แทนที่ตัวสร้างแบบจำลองด้วยตัวสร้างของคลาสแบบจำลองใหม่ของเรา และใช้การคำนวณแบบรวมสองส่วนในกระบวนการวนซ้ำที่คุณสร้างขึ้นเพื่อหมุนเวียน รอบการฝึก

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.0708053), ('accuracy', 0.12777779)])), ('stat', OrderedDict([('num_examples', 4860)]))])
for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.011699), ('accuracy', 0.13024691)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.7408307), ('accuracy', 0.15576132)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.6761012), ('accuracy', 0.17921811)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.675567), ('accuracy', 0.1855967)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.5664043), ('accuracy', 0.20329218)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4179392), ('accuracy', 0.24382716)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.3237286), ('accuracy', 0.26687244)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1861682), ('accuracy', 0.28209877)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.046388), ('accuracy', 0.32037038)])), ('stat', OrderedDict([('num_examples', 4860)]))])

หากต้องการดูเมตริกเหล่านี้ภายใน TensorBoard ให้อ้างอิงขั้นตอนที่แสดงด้านบนใน "การแสดงเมตริกโมเดลใน TensorBoard"

การประเมิน

การทดลองทั้งหมดของเราจนถึงขณะนี้นำเสนอเฉพาะเมตริกการฝึกอบรมแบบรวมศูนย์เท่านั้น ซึ่งเป็นเมตริกเฉลี่ยของข้อมูลทุกชุดที่ได้รับการฝึกฝนจากไคลเอ็นต์ทั้งหมดในรอบนี้ สิ่งนี้ทำให้เกิดความกังวลตามปกติเกี่ยวกับการใส่มากเกินไป โดยเฉพาะอย่างยิ่งเนื่องจากเราใช้ไคลเอนต์ชุดเดียวกันในแต่ละรอบเพื่อความเรียบง่าย แต่มีแนวคิดเพิ่มเติมเกี่ยวกับการปรับให้เหมาะสมมากเกินไปในเมตริกการฝึกอบรมเฉพาะสำหรับอัลกอริธึม Federated Averaging วิธีนี้ง่ายที่สุดที่จะดูว่าเราคิดว่าลูกค้าแต่ละรายมีชุดข้อมูลชุดเดียวหรือไม่ และเราฝึกชุดนั้นสำหรับการทำซ้ำหลายๆ ครั้ง (ยุค) ในกรณีนี้ โมเดลในพื้นที่จะพอดีกับชุดงานนั้นอย่างรวดเร็ว ดังนั้นตัวชี้วัดความแม่นยำในพื้นที่ที่เราเฉลี่ยจะเข้าใกล้ 1.0 ดังนั้น ตัวชี้วัดการฝึกอบรมเหล่านี้จึงเป็นสัญญาณว่าการฝึกอบรมมีความคืบหน้า แต่ไม่มากไปกว่านั้น

เพื่อดำเนินการประเมินผลข้อมูลแบบ federated คุณสามารถสร้างการคำนวณแบบ federated อื่นออกแบบมาเพียงเพื่อการนี้โดยใช้ tff.learning.build_federated_evaluation ฟังก์ชั่นและผ่านในตัวสร้างแบบจำลองของคุณเป็นอาร์กิวเมนต์ โปรดทราบว่าแตกต่างกับสหพันธ์ Averaging ที่เราเคยใช้ MnistTrainableModel ก็พอเพียงที่จะผ่านการ MnistModel การประเมินไม่ได้ดำเนินการแบบไล่ระดับ และไม่จำเป็นต้องสร้างเครื่องมือเพิ่มประสิทธิภาพ

สำหรับการทดลองและการวิจัยเมื่อชุดทดสอบส่วนกลางใช้ได้ สหพันธ์เรียนรู้สำหรับข้อความ Generation แสดงให้เห็นถึงการประเมินผลตัวเลือกอื่น: การยกน้ำหนักได้รับการฝึกฝนจากการเรียนรู้แบบ federated นำมาใช้เป็นรูปแบบ Keras มาตรฐานและจากนั้นก็เรียก tf.keras.models.Model.evaluate() ในชุดข้อมูลที่ส่วนกลาง

evaluation = tff.learning.build_federated_evaluation(MnistModel)

คุณสามารถตรวจสอบลายเซ็นประเภทนามธรรมของฟังก์ชันการประเมินได้ดังนี้

str(evaluation.type_signature)
'(<server_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <eval=<num_examples=float32,loss=float32,accuracy=float32>,stat=<num_examples=int64>>@SERVER)'

ไม่จำเป็นต้องมีความกังวลเกี่ยวกับรายละเอียดในจุดนี้เพียงแค่ทราบว่าจะใช้เวลาในรูปแบบทั่วไปดังต่อไปนี้คล้ายกับ tff.templates.IterativeProcess.next แต่ด้วยความแตกต่างที่สำคัญสอง ประการแรก เราจะไม่ส่งคืนสถานะของเซิร์ฟเวอร์ เนื่องจากการประเมินไม่ได้แก้ไขโมเดลหรือสถานะอื่นใด - คุณสามารถมองว่าเป็นสถานะไร้สัญชาติได้ ประการที่สอง การประเมินต้องการเพียงโมเดล และไม่จำเป็นต้องมีส่วนอื่นๆ ของสถานะเซิร์ฟเวอร์ที่อาจเกี่ยวข้องกับการฝึกอบรม เช่น ตัวแปรตัวเพิ่มประสิทธิภาพ

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

มาเรียกใช้การประเมินเกี่ยวกับสถานะล่าสุดที่เรามาถึงระหว่างการฝึกอบรมกัน เพื่อดึงรูปแบบการฝึกอบรมล่าสุดจากสถานะของเซิร์ฟเวอร์คุณก็เข้าถึง .model สมาชิกดังต่อไปนี้

train_metrics = evaluation(state.model, federated_train_data)

นี่คือสิ่งที่เราได้รับ โปรดทราบว่าตัวเลขดูดีกว่าที่รายงานโดยการฝึกอบรมรอบที่แล้วด้านบนเล็กน้อย ตามแบบแผน เมตริกการฝึกอบรมที่รายงานโดยกระบวนการฝึกอบรมแบบวนซ้ำมักจะสะท้อนประสิทธิภาพของแบบจำลองเมื่อเริ่มต้นรอบการฝึกอบรม ดังนั้น เมตริกการประเมินจึงนำหน้าหนึ่งก้าวเสมอ

str(train_metrics)
"OrderedDict([('eval', OrderedDict([('num_examples', 4860.0), ('loss', 1.7510437), ('accuracy', 0.2788066)])), ('stat', OrderedDict([('num_examples', 4860)]))])"

ตอนนี้ มารวบรวมตัวอย่างทดสอบของข้อมูลรวมและรันการประเมินข้อมูลการทดสอบอีกครั้ง ข้อมูลจะมาจากกลุ่มตัวอย่างผู้ใช้จริงกลุ่มเดียวกัน แต่มาจากชุดข้อมูลที่แยกออกมาต่างหาก

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
"OrderedDict([('eval', OrderedDict([('num_examples', 580.0), ('loss', 1.8361608), ('accuracy', 0.2413793)])), ('stat', OrderedDict([('num_examples', 580)]))])"

นี้สรุปการกวดวิชา เราขอแนะนำให้คุณเล่นกับพารามิเตอร์ (เช่น ขนาดแบทช์ จำนวนผู้ใช้ ยุค อัตราการเรียนรู้ ฯลฯ) เพื่อแก้ไขโค้ดด้านบนเพื่อจำลองการฝึกสุ่มตัวอย่างผู้ใช้ในแต่ละรอบ และเพื่อสำรวจบทช่วยสอนอื่นๆ เราได้พัฒนา