บันทึกวันที่! Google I / O ส่งคืนวันที่ 18-20 พฤษภาคม ลงทะเบียนตอนนี้
หน้านี้ได้รับการแปลโดย Cloud Translation API
Switch to English

การเรียนรู้แบบรวมศูนย์สำหรับการจำแนกประเภทรูปภาพ

ดูใน TensorFlow.org เรียกใช้ใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดสมุดบันทึก

ในบทช่วยสอนนี้เราใช้ตัวอย่างการฝึกอบรม MNIST แบบคลาสสิกเพื่อแนะนำเลเยอร์ Federated Learning (FL) API ของ TFF, tff.learning ซึ่งเป็นชุดของอินเทอร์เฟซระดับสูงที่สามารถใช้เพื่อดำเนินการประเภททั่วไปของงานการเรียนรู้แบบรวมเช่น การฝึกอบรมแบบรวมศูนย์เทียบกับโมเดลที่ผู้ใช้จัดหามาใช้ใน TensorFlow

บทช่วยสอนนี้และ Federated Learning API มีไว้สำหรับผู้ใช้ที่ต้องการเสียบโมเดล TensorFlow ของตนเองเข้ากับ TFF เป็นหลักโดยถือว่าส่วนใหญ่เป็นกล่องดำ สำหรับความเข้าใจในเชิงลึกมากขึ้นเกี่ยวกับ TFF และวิธีใช้อัลกอริทึมการเรียนรู้แบบรวมศูนย์ของคุณเองโปรดดูบทช่วยสอนเกี่ยวกับ FC Core API - Custom Federated Algorithms ตอนที่ 1 และ ตอนที่ 2

สำหรับข้อมูลเพิ่มเติมเกี่ยวกับ tff.learning ให้ดำเนินการต่อด้วย Federated Learning for Text Generation บทช่วยสอนซึ่งนอกจากจะครอบคลุมโมเดลที่เกิดซ้ำแล้วยังสาธิตการโหลดโมเดล Keras แบบอนุกรมที่ผ่านการฝึกอบรมมาแล้วสำหรับการปรับแต่งด้วยการเรียนรู้แบบรวมศูนย์รวมกับการประเมินโดยใช้ Keras

ก่อนที่เราจะเริ่ม

ก่อนที่เราจะเริ่มโปรดเรียกใช้สิ่งต่อไปนี้เพื่อให้แน่ใจว่าสภาพแวดล้อมของคุณได้รับการตั้งค่าอย่างถูกต้อง หากคุณไม่เห็นคำทักทายโปรดดูคำแนะนำในคู่มือการ ติดตั้ง

# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
!pip install --quiet --upgrade tb-nightly  # or tensorboard, but not both

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
Fetching TensorBoard MPM... done.
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

กำลังเตรียมข้อมูลอินพุต

เริ่มต้นด้วยข้อมูล การเรียนรู้แบบรวมต้องใช้ชุดข้อมูลแบบรวมเช่นการรวบรวมข้อมูลจากผู้ใช้หลายคน ข้อมูลสหพันธ์เป็นปกติไม่ใช่ IID ซึ่งส่อเค้าชุดที่เป็นเอกลักษณ์ของความท้าทาย

เพื่ออำนวยความสะดวกในการทดลองเราเริ่มต้นที่เก็บ TFF ด้วยชุดข้อมูลสองสามชุดรวมถึง MNIST เวอร์ชันรวมที่มีเวอร์ชันของ ชุดข้อมูล NIST ดั้งเดิม ที่ได้รับการประมวลผลใหม่โดยใช้ Leaf เพื่อให้ข้อมูลถูกป้อนโดยผู้เขียนดั้งเดิมของ ตัวเลข เนื่องจากนักเขียนแต่ละคนมีสไตล์ที่ไม่ซ้ำกันชุดข้อมูลนี้จึงแสดงพฤติกรรมที่ไม่ใช่ iid ที่คาดหวังจากชุดข้อมูลแบบรวม

นี่คือวิธีที่เราสามารถโหลดได้

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

ชุดข้อมูลที่ส่งคืนโดย load_data() เป็นอินสแตนซ์ของ tff.simulation.ClientData ซึ่งเป็นอินเทอร์เฟซที่อนุญาตให้คุณระบุชุดของผู้ใช้เพื่อสร้างtf.data.Dataset ชุดข้อมูลที่แสดงถึงข้อมูลของผู้ใช้เฉพาะและเพื่อค้นหา โครงสร้างของแต่ละองค์ประกอบ นี่คือวิธีที่คุณสามารถใช้อินเทอร์เฟซนี้เพื่อสำรวจเนื้อหาของชุดข้อมูล โปรดทราบว่าแม้ว่าอินเทอร์เฟซนี้จะช่วยให้คุณสามารถทำซ้ำรหัสไคลเอ็นต์ได้ แต่นี่เป็นเพียงคุณลักษณะของข้อมูลจำลองเท่านั้น ดังที่คุณจะเห็นในไม่ช้านี้ไม่มีการใช้ข้อมูลประจำตัวของไคลเอ็นต์โดยกรอบการเรียนรู้แบบรวมจุดประสงค์เพียงอย่างเดียวคืออนุญาตให้คุณเลือกชุดย่อยของข้อมูลสำหรับการจำลอง

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None)), ('label', TensorSpec(shape=(), dtype=tf.int32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

การสำรวจความแตกต่างในข้อมูลแบบรวม

โดยทั่วไปแล้วข้อมูลแบบรวมจะไม่ใช่ iid ซึ่งโดยทั่วไปแล้วผู้ใช้จะมีการกระจายข้อมูลที่แตกต่างกันขึ้นอยู่กับรูปแบบการใช้งาน ลูกค้าบางรายอาจมีตัวอย่างการฝึกอบรมบนอุปกรณ์น้อยลงเนื่องจากมีปัญหาข้อมูลไม่เพียงพอในพื้นที่ในขณะที่ลูกค้าบางรายจะมีตัวอย่างการฝึกอบรมมากเกินพอ มาสำรวจแนวคิดเกี่ยวกับความแตกต่างของข้อมูลตามแบบฉบับของระบบสหพันธรัฐด้วยข้อมูล EMNIST ที่เรามีอยู่ สิ่งสำคัญคือต้องทราบว่าการวิเคราะห์ข้อมูลของลูกค้าในเชิงลึกนี้มีให้สำหรับเราเท่านั้นเพราะนี่คือสภาพแวดล้อมการจำลองที่ข้อมูลทั้งหมดพร้อมใช้งานสำหรับเราในพื้นที่ ในสภาพแวดล้อมแบบรวมการผลิตจริงคุณจะไม่สามารถตรวจสอบข้อมูลของไคลเอ็นต์เดียวได้

ก่อนอื่นมาสุ่มตัวอย่างข้อมูลของลูกค้ารายหนึ่งเพื่อทำความเข้าใจกับตัวอย่างในอุปกรณ์จำลองเครื่องเดียว เนื่องจากชุดข้อมูลที่เราใช้นั้นได้รับการคีย์โดยนักเขียนที่ไม่ซ้ำกันข้อมูลของลูกค้าหนึ่งรายจึงแสดงด้วยลายมือของบุคคลหนึ่งสำหรับตัวอย่างตัวเลข 0 ถึง 9 โดยจำลอง "รูปแบบการใช้งาน" ที่ไม่ซ้ำกันของผู้ใช้รายหนึ่ง

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

ตอนนี้เรามาดูจำนวนตัวอย่างในไคลเอนต์แต่ละตัวสำหรับป้ายกำกับหลัก MNIST แต่ละรายการ ในสภาพแวดล้อมแบบรวมจำนวนตัวอย่างในแต่ละไคลเอ็นต์อาจแตกต่างกันเล็กน้อยขึ้นอยู่กับพฤติกรรมของผู้ใช้

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

ตอนนี้เรามาดูภาพเฉลี่ยต่อไคลเอ็นต์สำหรับป้ายกำกับ MNIST แต่ละรายการ รหัสนี้จะสร้างค่าเฉลี่ยของค่าพิกเซลแต่ละค่าสำหรับตัวอย่างของผู้ใช้ทั้งหมดสำหรับป้ายกำกับเดียว เราจะเห็นว่ารูปภาพเฉลี่ยของลูกค้ารายหนึ่งสำหรับตัวเลขจะดูแตกต่างจากรูปภาพเฉลี่ยของลูกค้ารายอื่นสำหรับตัวเลขเดียวกันเนื่องจากลักษณะการเขียนด้วยลายมือที่เป็นเอกลักษณ์ของแต่ละคน เราสามารถรำพึงได้ว่าแต่ละรอบการฝึกอบรมในพื้นที่จะผลักดันโมเดลไปในทิศทางที่แตกต่างกันอย่างไรในไคลเอนต์แต่ละรายเนื่องจากเรากำลังเรียนรู้จากข้อมูลเฉพาะของผู้ใช้รายนั้นในรอบท้องถิ่นนั้น ๆ ต่อไปในบทช่วยสอนเราจะดูว่าเราสามารถนำการอัปเดตแต่ละครั้งไปยังโมเดลจากไคลเอนต์ทั้งหมดและรวมเข้าด้วยกันเป็นโมเดลใหม่ระดับโลกของเราซึ่งได้เรียนรู้จากข้อมูลเฉพาะของลูกค้าแต่ละราย

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

ข้อมูลผู้ใช้อาจมีเสียงดังและมีป้ายกำกับไม่น่าเชื่อถือ ตัวอย่างเช่นเมื่อดูข้อมูลของ Client # 2 ด้านบนเราจะเห็นว่าสำหรับป้ายกำกับ 2 เป็นไปได้ว่าอาจมีตัวอย่างที่ติดป้ายกำกับไม่ถูกต้องซึ่งทำให้เกิดภาพที่มีค่าเฉลี่ยที่ดังกว่า

การประมวลผลข้อมูลอินพุตล่วงหน้า

เนื่องจากข้อมูลเป็นtf.data.Dataset อยู่แล้วการประมวลผลล่วงหน้าสามารถทำได้โดยใช้การแปลงชุดข้อมูล ที่นี่เราทำให้ภาพ 28x28 แบนเป็น 784 - องค์ประกอบอาร์เรย์สับเปลี่ยนแต่ละตัวอย่างจัดระเบียบเป็นชุดและเปลี่ยนชื่อคุณลักษณะจาก pixels และ label เป็น x และ y เพื่อใช้กับ Keras นอกจากนี้เรายัง repeat ในชุดข้อมูลเพื่อเรียกใช้ยุคต่างๆ

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

มาตรวจสอบว่าได้ผล

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[0],
       [5],
       [0],
       [1],
       [3],
       [0],
       [5],
       [4],
       [1],
       [7],
       [0],
       [4],
       [0],
       [1],
       [7],
       [2],
       [2],
       [0],
       [7],
       [1]], dtype=int32))])

เรามีหน่วยการสร้างเกือบทั้งหมดเพื่อสร้างชุดข้อมูลแบบรวมศูนย์

วิธีหนึ่งในการป้อนข้อมูลแบบรวมศูนย์ไปยัง TFF ในการจำลองเป็นเพียงรายการ Python โดยแต่ละองค์ประกอบของรายการจะเก็บข้อมูลของผู้ใช้แต่ละรายไม่ว่าจะเป็นรายการหรือเป็นtf.data.Dataset เนื่องจากเรามีอินเทอร์เฟซที่ให้หลังอยู่แล้วมาใช้กันเถอะ

นี่คือฟังก์ชั่นตัวช่วยง่ายๆที่จะสร้างรายการชุดข้อมูลจากกลุ่มผู้ใช้ที่กำหนดเพื่อเป็นข้อมูลเข้าในรอบการฝึกอบรมหรือการประเมินผล

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

ตอนนี้เราจะเลือกลูกค้าอย่างไร?

ในสถานการณ์การฝึกอบรมแบบสหพันธ์ทั่วไปเรากำลังเผชิญกับจำนวนผู้ใช้อุปกรณ์ที่อาจมีจำนวนมากซึ่งอาจมีเพียงบางส่วนเท่านั้นสำหรับการฝึกอบรมในช่วงเวลาที่กำหนด ในกรณีนี้ตัวอย่างเช่นเมื่ออุปกรณ์ไคลเอนต์เป็นโทรศัพท์มือถือที่เข้าร่วมการฝึกอบรมเฉพาะเมื่อเสียบเข้ากับแหล่งจ่ายไฟปิดเครือข่ายแบบมิเตอร์และไม่ได้ใช้งาน

แน่นอนว่าเราอยู่ในสภาพแวดล้อมจำลองและข้อมูลทั้งหมดมีอยู่ในเครื่อง โดยปกติแล้วเมื่อใช้การจำลองสถานการณ์เราจะสุ่มตัวอย่างกลุ่มย่อยของลูกค้าที่จะมีส่วนร่วมในการฝึกอบรมแต่ละรอบโดยทั่วไปจะแตกต่างกันไปในแต่ละรอบ

ดังที่คุณสามารถหาได้จากการศึกษาเอกสารเกี่ยวกับอัลกอริธึม Federated Averaging การบรรลุคอนเวอร์เจนซ์ในระบบที่มีการสุ่มตัวอย่างชุดย่อยของไคลเอ็นต์ในแต่ละรอบอาจใช้เวลาสักครู่และจะเป็นไปไม่ได้ที่จะต้องวิ่งหลายร้อยรอบใน บทช่วยสอนแบบโต้ตอบนี้

สิ่งที่เราจะทำแทนคือการสุ่มตัวอย่างชุดของไคลเอ็นต์เพียงครั้งเดียวและใช้ชุดเดียวกันซ้ำทุกรอบเพื่อเพิ่มความเร็วในการบรรจบกัน (โดยตั้งใจให้เหมาะสมกับข้อมูลของผู้ใช้ไม่กี่รายเหล่านี้) เราปล่อยให้มันเป็นแบบฝึกหัดสำหรับผู้อ่านในการปรับเปลี่ยนบทช่วยสอนนี้เพื่อจำลองการสุ่มตัวอย่างแบบสุ่ม - มันค่อนข้างง่ายที่จะทำ (เมื่อคุณทำแล้วโปรดทราบว่าการทำให้โมเดลมาบรรจบกันอาจใช้เวลาสักครู่)

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

การสร้างโมเดลด้วย Keras

หากคุณใช้ Keras คุณอาจมีโค้ดที่สร้างโมเดล Keras อยู่แล้ว นี่คือตัวอย่างโมเดลง่ายๆที่เพียงพอต่อความต้องการของเรา

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

ในการใช้โมเดลใด ๆ กับ TFF จำเป็นต้องรวมอินสแตนซ์ของอินสแตนซ์ของอินเทอร์เฟซ tff.learning.Model ซึ่งแสดงวิธีการประทับตราฟอร์เวิร์ดพาสคุณสมบัติข้อมูลเมตา ฯลฯ ของโมเดลเช่นเดียวกับ Keras แต่ยังแนะนำเพิ่มเติม องค์ประกอบต่างๆเช่นวิธีควบคุมกระบวนการคำนวณเมตริกแบบรวม ตอนนี้ไม่ต้องกังวลเกี่ยวกับเรื่องนี้ หากคุณมีโมเดล Keras เช่นเดียวกับที่เราได้กำหนดไว้ข้างต้นคุณสามารถให้ TFF รวมไว้ให้คุณได้โดยเรียกใช้ tff.learning.from_keras_model ส่งโมเดลและชุดข้อมูลตัวอย่างเป็นอาร์กิวเมนต์ดังที่แสดงด้านล่าง

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

การฝึกอบรมโมเดลกับข้อมูลแบบรวม

ตอนนี้เรามีโมเดลที่รวมเป็น tff.learning.Model สำหรับใช้กับ TFF แล้วเราสามารถให้ TFF สร้างอัลกอริธึม Federated Averaging โดยเรียกใช้ฟังก์ชันตัวช่วย tff.learning.build_federated_averaging_process ดังต่อไปนี้

โปรดทราบว่าอาร์กิวเมนต์จำเป็นต้องเป็นตัวสร้าง (เช่น model_fn ด้านบน) ไม่ใช่อินสแตนซ์ที่สร้างขึ้นแล้วเพื่อให้การสร้างโมเดลของคุณสามารถเกิดขึ้นได้ในบริบทที่ควบคุมโดย TFF (หากคุณอยากรู้เกี่ยวกับเหตุผลของ เราขอแนะนำให้คุณอ่านบทแนะนำการติดตามเกี่ยวกับ อัลกอริทึมที่กำหนดเอง )

ข้อสังเกตที่สำคัญอย่างหนึ่งเกี่ยวกับอัลกอริทึมการเฉลี่ยรวมด้านล่างมีเครื่องมือเพิ่มประสิทธิภาพ 2 ตัว ได้แก่ เครื่องมือเพิ่มประสิทธิภาพ _client และเครื่องมือ เพิ่มประสิทธิภาพ _server _client Optimizer ใช้เพื่อคำนวณการอัปเดตโมเดลภายในของไคลเอ็นต์แต่ละตัวเท่านั้น _server optimizer ใช้การอัปเดตโดยเฉลี่ยกับโมเดลส่วนกลางที่เซิร์ฟเวอร์ โดยเฉพาะอย่างยิ่งนั่นหมายความว่าตัวเลือกของเครื่องมือเพิ่มประสิทธิภาพและอัตราการเรียนรู้ที่ใช้อาจต้องแตกต่างจากที่คุณใช้ในการฝึกโมเดลบนชุดข้อมูล iid มาตรฐาน เราขอแนะนำให้เริ่มต้นด้วย SGD ปกติโดยอาจมีอัตราการเรียนรู้ที่น้อยกว่าปกติ อัตราการเรียนรู้ที่เราใช้ไม่ได้รับการปรับแต่งอย่างรอบคอบอย่าลังเลที่จะทดลอง

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

เกิดอะไรขึ้น? TFF ได้สร้างคู่ของ การคำนวณแบบรวมศูนย์ และรวมไว้ใน tff.templates.IterativeProcess ซึ่งการคำนวณเหล่านี้พร้อมใช้งานเป็นคู่ของคุณสมบัติ initialize และ next

โดยสรุป การคำนวณแบบรวมศูนย์ คือโปรแกรมในภาษาภายในของ TFF ที่สามารถแสดงอัลกอริทึมแบบรวมต่างๆได้ (คุณสามารถดูข้อมูลเพิ่มเติมเกี่ยวกับสิ่งนี้ได้ในบทช่วยสอน อัลกอริทึมที่กำหนดเอง ) ในกรณีนี้การคำนวณสองรายการที่สร้างขึ้นและบรรจุลงใน iterative_process ใช้ Federated Averaging

เป็นเป้าหมายของ TFF ในการกำหนดการคำนวณในลักษณะที่สามารถดำเนินการได้ในการตั้งค่าการเรียนรู้แบบรวมศูนย์จริง แต่ปัจจุบันมีการใช้งานรันไทม์การจำลองการดำเนินการในระบบเท่านั้น ในการดำเนินการคำนวณในโปรแกรมจำลองคุณเพียงแค่เรียกใช้เช่นฟังก์ชัน Python สภาพแวดล้อมที่ตีความเริ่มต้นนี้ไม่ได้ออกแบบมาเพื่อประสิทธิภาพสูง แต่ก็เพียงพอสำหรับบทช่วยสอนนี้ เราคาดว่าจะจัดเตรียมช่วงเวลาการจำลองที่มีประสิทธิภาพสูงขึ้นเพื่ออำนวยความสะดวกในการวิจัยขนาดใหญ่ในรุ่นต่อ ๆ ไป

เริ่มต้นด้วยการ initialize การคำนวณ ในกรณีของการคำนวณแบบรวมศูนย์ทั้งหมดคุณสามารถคิดว่ามันเป็นฟังก์ชัน การคำนวณไม่ใช้อาร์กิวเมนต์และส่งกลับผลลัพธ์หนึ่งรายการ - การแสดงสถานะของกระบวนการรวมค่าเฉลี่ยบนเซิร์ฟเวอร์ ในขณะที่เราไม่ต้องการดำดิ่งลงไปในรายละเอียดของ TFF แต่อาจมีคำแนะนำให้ดูว่าสถานะนี้มีลักษณะอย่างไร คุณสามารถเห็นภาพดังต่อไปนี้

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER)'

แม้ว่าลายเซ็นประเภทข้างต้นในตอนแรกอาจดูเหมือนเป็นความลับเล็กน้อย แต่คุณสามารถรับรู้ได้ว่าสถานะเซิร์ฟเวอร์ประกอบด้วย model (พารามิเตอร์โมเดลเริ่มต้นสำหรับ MNIST ที่จะกระจายไปยังอุปกรณ์ทั้งหมด) และ optimizer_state (ข้อมูลเพิ่มเติมที่เซิร์ฟเวอร์ดูแล เช่นจำนวนรอบที่จะใช้สำหรับตารางเวลาไฮเปอร์พารามิเตอร์เป็นต้น)

มาเรียก initialize การคำนวณ initialize เพื่อสร้างสถานะเซิร์ฟเวอร์

state = iterative_process.initialize()

คู่ที่สองของการคำนวณแบบรวมคู่ next เป็นการแสดงรอบเดียวของ Federated Averaging ซึ่งประกอบด้วยการผลักสถานะเซิร์ฟเวอร์ (รวมถึงพารามิเตอร์ของโมเดล) ไปยังไคลเอนต์การฝึกอบรมบนอุปกรณ์เกี่ยวกับข้อมูลในเครื่องการรวบรวมและการอัปเดตโมเดลโดยเฉลี่ย และสร้างโมเดลที่อัปเดตใหม่ที่เซิร์ฟเวอร์

ตามแนวคิดแล้วคุณสามารถคิด next ว่ามีลายเซ็นประเภทการทำงานที่มีลักษณะดังนี้

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

โดยเฉพาะอย่างยิ่งเราควรคิดเกี่ยวกับ next() ไม่ใช่ว่าเป็นฟังก์ชันที่ทำงานบนเซิร์ฟเวอร์ แต่เป็นการแสดงหน้าที่อย่างเปิดเผยของการคำนวณแบบกระจายอำนาจทั้งหมด - อินพุตบางส่วนจัดเตรียมโดยเซิร์ฟเวอร์ ( SERVER_STATE ) แต่แต่ละส่วนมีส่วนร่วม อุปกรณ์สร้างชุดข้อมูลในเครื่องของตัวเอง

ลองฝึกรอบเดียวและดูผลลัพธ์ เราสามารถใช้ข้อมูลรวมที่เราสร้างไว้ข้างต้นสำหรับกลุ่มตัวอย่างผู้ใช้

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11502057), ('loss', 3.244929)]))])

มาวิ่งอีกสองสามรอบ ตามที่ระบุไว้ก่อนหน้านี้โดยทั่วไป ณ จุดนี้คุณจะเลือกชุดย่อยของข้อมูลการจำลองของคุณจากกลุ่มตัวอย่างใหม่ที่สุ่มเลือกของผู้ใช้สำหรับแต่ละรอบเพื่อจำลองการปรับใช้จริงที่ผู้ใช้ไปมาอย่างต่อเนื่อง แต่ในสมุดบันทึกแบบโต้ตอบนี้สำหรับ เพื่อการสาธิตเราจะใช้ผู้ใช้เดิมซ้ำเพื่อให้ระบบมาบรรจบกันอย่างรวดเร็ว

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14609054), ('loss', 2.9141645)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.15205762), ('loss', 2.9237952)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.18600823), ('loss', 2.7629454)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.20884773), ('loss', 2.622908)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.21872428), ('loss', 2.543587)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2372428), ('loss', 2.4210362)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.28209877), ('loss', 2.2297976)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2685185), ('loss', 2.195803)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.33868313), ('loss', 2.0523348)]))])

การสูญเสียการฝึกกำลังลดลงหลังจากการฝึกอบรมแบบสหพันธ์แต่ละรอบซึ่งบ่งชี้ว่าโมเดลกำลังมาบรรจบกัน มีคำเตือนที่สำคัญบางประการเกี่ยวกับเมตริกการฝึกอบรมเหล่านี้อย่างไรก็ตามโปรดดูหัวข้อการ ประเมินผล ในบทช่วยสอนนี้

การแสดงเมตริกโมเดลใน TensorBoard

ต่อไปเรามาดูเมตริกจากการคำนวณแบบรวมศูนย์เหล่านี้โดยใช้ Tensorboard

เริ่มต้นด้วยการสร้างไดเร็กทอรีและตัวเขียนสรุปที่เกี่ยวข้องเพื่อเขียนเมตริก

logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

พล็อตเมตริกสเกลาร์ที่เกี่ยวข้องกับผู้เขียนสรุปคนเดียวกัน

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics['train'].items():
      tf.summary.scalar(name, value, step=round_num)

เริ่ม TensorBoard ด้วยไดเร็กทอรีบันทึกรูทที่ระบุไว้ด้านบน อาจใช้เวลาสองสามวินาทีในการโหลดข้อมูล

!ls {logdir}
%tensorboard --logdir {logdir} --port=0
events.out.tfevents.1604020204.isim77-20020ad609500000b02900f40f27a5f6.prod.google.com.686098.10633.v2
events.out.tfevents.1604020602.isim77-20020ad609500000b02900f40f27a5f6.prod.google.com.794554.10607.v2
Launching TensorBoard...
<IPython.core.display.Javascript at 0x7fc5e8d3c128>
# Uncomment and run this this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if 
# you do a "Runtime > Run all" you don't lose your results.

# !rm -R /tmp/logs/scalars/*

ในการดูเมตริกการประเมินผลในลักษณะเดียวกันคุณสามารถสร้างโฟลเดอร์ eval แยกต่างหากเช่น "logs / scalars / eval" เพื่อเขียนไปยัง TensorBoard

การปรับแต่งการใช้งานแบบจำลอง

Keras เป็น API โมเดลระดับสูงที่แนะนำสำหรับ TensorFlow และเราขอแนะนำให้ใช้โมเดล Keras (ผ่าน tff.learning.from_keras_model ) ใน TFF ทุกครั้งที่ทำได้

อย่างไรก็ตาม tff.learning จัดเตรียมอินเทอร์เฟซโมเดลระดับล่าง tff.learning.Model ซึ่งแสดงฟังก์ชันขั้นต่ำที่จำเป็นสำหรับการใช้โมเดลสำหรับการเรียนรู้แบบรวมศูนย์ การนำอินเทอร์เฟซนี้ไปใช้โดยตรง (อาจยังคงใช้หน่วยการสร้างเช่นtf.keras.layers ) ช่วยให้สามารถปรับแต่งได้สูงสุดโดยไม่ต้องแก้ไขภายในของอัลกอริทึมการเรียนรู้แบบรวมศูนย์

ลองทำใหม่ทั้งหมดอีกครั้งตั้งแต่เริ่มต้น

การกำหนดตัวแปรรูปแบบการส่งต่อและเมตริก

ขั้นตอนแรกคือการระบุตัวแปร TensorFlow ที่เรากำลังจะดำเนินการ เพื่อให้โค้ดต่อไปนี้ชัดเจนยิ่งขึ้นเรามากำหนดโครงสร้างข้อมูลเพื่อแสดงทั้งชุด ซึ่งจะรวมถึงตัวแปรเช่น weights และ bias ที่เราจะฝึกเช่นเดียวกับตัวแปรที่จะถือสถิติต่างๆที่สะสมและเคาน์เตอร์เราจะปรับปรุงในระหว่างการฝึกเช่น loss_sum , accuracy_sum และ num_examples

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

นี่คือวิธีการสร้างตัวแปร เพื่อความเรียบง่ายเราแสดงสถิติทั้งหมดเป็น tf.float32 เนื่องจากจะช่วยขจัดความจำเป็นในการแปลงประเภทในระยะต่อไป การห่อตัวแปรเริ่มต้นเป็น lambdas เป็นข้อกำหนดที่กำหนดโดย ตัวแปรทรัพยากร

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

ด้วยตัวแปรสำหรับพารามิเตอร์แบบจำลองและสถิติสะสมในขณะนี้เราสามารถกำหนดวิธีการส่งต่อที่คำนวณการสูญเสียปล่อยการคาดการณ์และอัปเดตสถิติสะสมสำหรับข้อมูลอินพุตชุดเดียวได้ดังนี้

def mnist_forward_pass(variables, batch):
  y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

ต่อไปเราจะกำหนดฟังก์ชันที่ส่งคืนชุดของเมตริกท้องถิ่นอีกครั้งโดยใช้ TensorFlow ค่าเหล่านี้คือค่า (นอกเหนือจากการอัปเดตแบบจำลองซึ่งได้รับการจัดการโดยอัตโนมัติ) ที่มีสิทธิ์รวมเข้ากับเซิร์ฟเวอร์ในกระบวนการเรียนรู้แบบรวมหรือการประเมินผล

ที่นี่เราเพียงแค่ส่งคืนการ loss และ accuracy โดยเฉลี่ยรวมถึงตัวอย่าง num_examples ซึ่งเราจำเป็นต้องให้น้ำหนักการมีส่วนร่วมจากผู้ใช้ที่แตกต่างกันอย่างถูกต้องเมื่อคำนวณผลรวมแบบรวม

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

สุดท้ายเราต้องกำหนดวิธีการรวมเมตริกท้องถิ่นที่ปล่อยออกมาโดยอุปกรณ์แต่ละเครื่องผ่าน get_local_mnist_metrics นี่เป็นส่วนเดียวของโค้ดที่ไม่ได้เขียนใน TensorFlow - เป็นการ คำนวณแบบรวม ศูนย์ ที่ แสดงใน TFF หากคุณต้องการเจาะลึกลงไปให้อ่านบทช่วยสอน อัลกอริทึมที่กำหนดเอง แต่ในแอปพลิเคชันส่วนใหญ่คุณไม่จำเป็นต้องทำจริงๆ รูปแบบต่างๆที่แสดงด้านล่างน่าจะเพียงพอ นี่คือสิ่งที่ดูเหมือน:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))

อาร์กิวเมนต์ของ metrics อินพุตสอดคล้องกับ OrderedDict ส่งคืนโดย get_local_mnist_metrics ด้านบน แต่ค่าวิกฤตไม่ได้เป็น tf.Tensors อีกต่อไป tf.Tensors - มีการ "บรรจุกล่อง" เป็น tff.Value เพื่อให้ชัดเจนว่าคุณไม่สามารถจัดการได้อีกต่อไปโดยใช้ TensorFlow แต่เพียงอย่างเดียว โดยใช้ตัวดำเนินการแบบรวมของ TFF เช่น tff.federated_mean และ tff.federated_sum พจนานุกรมที่ส่งคืนของการรวมทั่วโลกกำหนดชุดของเมตริกซึ่งจะพร้อมใช้งานบนเซิร์ฟเวอร์

การสร้างอินสแตนซ์ของ tff.learning.Model

ด้วยสิ่งที่กล่าวมาทั้งหมดเราพร้อมที่จะสร้างการแสดงโมเดลสำหรับใช้กับ TFF ที่คล้ายกับที่สร้างขึ้นสำหรับคุณเมื่อคุณปล่อยให้ TFF นำเข้าโมเดล Keras

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

ดังที่คุณเห็นวิธีนามธรรมและคุณสมบัติที่กำหนดโดย tff.learning.Model สอดคล้องกับข้อมูลโค้ดในส่วนก่อนหน้านี้ที่แนะนำตัวแปรและกำหนดการสูญเสียและสถิติ

จุดไฮไลต์ที่ควรค่าแก่การเน้นมีดังนี้

  • สถานะทั้งหมดที่โมเดลของคุณจะใช้ต้องถูกจับเป็นตัวแปร TensorFlow เนื่องจาก TFF ไม่ได้ใช้ Python ในรันไทม์ (โปรดจำไว้ว่าควรเขียนโค้ดของคุณเพื่อให้สามารถปรับใช้กับอุปกรณ์มือถือได้ดูบทช่วยสอน อัลกอริทึมที่กำหนดเอง สำหรับข้อมูลเชิงลึกเพิ่มเติม ความเห็นเกี่ยวกับเหตุผล)
  • แบบจำลองของคุณควรอธิบายถึงรูปแบบของข้อมูลที่ยอมรับ ( input_spec ) โดยทั่วไป TFF เป็นสภาพแวดล้อมที่พิมพ์มากและต้องการกำหนดลายเซ็นประเภทสำหรับส่วนประกอบทั้งหมด การประกาศรูปแบบข้อมูลเข้าของโมเดลของคุณเป็นส่วนสำคัญ
  • แม้ว่าจะไม่จำเป็นต้องใช้ในทางเทคนิคเราขอแนะนำให้รวมลอจิก TensorFlow ทั้งหมด (การส่งต่อการคำนวณเมตริก ฯลฯ ) เป็น tf.function เนื่องจากจะช่วยให้มั่นใจได้ว่า TensorFlow สามารถต่ออนุกรมกันได้และขจัดความจำเป็นในการพึ่งพาการควบคุมอย่างชัดเจน

ข้างต้นเพียงพอสำหรับการประเมินและอัลกอริทึมเช่น Federated SGD อย่างไรก็ตามสำหรับ Federated Averaging เราจำเป็นต้องระบุว่าโมเดลควรฝึกแบบโลคัลในแต่ละแบทช์อย่างไร เราจะระบุเครื่องมือเพิ่มประสิทธิภาพโลคัลเมื่อสร้างอัลกอริทึม Federated Averaging

จำลองการฝึกอบรมแบบสหพันธรัฐด้วยรูปแบบใหม่

ด้วยสิ่งที่กล่าวมาทั้งหมดส่วนที่เหลือของกระบวนการจะดูเหมือนกับสิ่งที่เราเคยเห็นไปแล้ว - เพียงแค่แทนที่ตัวสร้างโมเดลด้วยตัวสร้างของคลาสโมเดลใหม่ของเราและใช้การคำนวณแบบรวมทั้งสองแบบรวมในกระบวนการวนซ้ำที่คุณสร้างขึ้นเพื่อวนรอบ รอบการฝึกอบรม

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.1527398), ('accuracy', 0.12469136)]))])
for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.941014), ('accuracy', 0.14218107)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.9052832), ('accuracy', 0.14444445)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.7491086), ('accuracy', 0.17962962)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.5129666), ('accuracy', 0.19526748)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4175923), ('accuracy', 0.23600823)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4273515), ('accuracy', 0.24176955)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.2426176), ('accuracy', 0.2802469)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1567981), ('accuracy', 0.295679)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1092515), ('accuracy', 0.30843621)]))])

หากต้องการดูเมตริกเหล่านี้ภายใน TensorBoard โปรดดูขั้นตอนที่ระบุไว้ด้านบนใน "การแสดงเมตริกแบบจำลองใน TensorBoard"

การประเมินผล

จนถึงขณะนี้การทดลองทั้งหมดของเรานำเสนอเฉพาะเมตริกการฝึกอบรมแบบรวมศูนย์ซึ่งเป็นเมตริกเฉลี่ยสำหรับชุดข้อมูลทั้งหมดที่ได้รับการฝึกอบรมจากลูกค้าทั้งหมดในรอบนั้น สิ่งนี้ทำให้เกิดข้อกังวลตามปกติเกี่ยวกับการติดตั้งอุปกรณ์มากเกินไปโดยเฉพาะอย่างยิ่งเมื่อเราใช้ไคลเอนต์ชุดเดียวกันในแต่ละรอบเพื่อความเรียบง่าย แต่มีแนวคิดเพิ่มเติมเกี่ยวกับเมตริกการฝึกอบรมที่มากเกินไปโดยเฉพาะสำหรับอัลกอริทึมการเฉลี่ยรวม นี่เป็นวิธีที่ง่ายที่สุดในการดูว่าเราจินตนาการว่าลูกค้าแต่ละรายมีข้อมูลชุดเดียวหรือไม่และเราฝึกอบรมชุดนั้นสำหรับการทำซ้ำหลาย ๆ ครั้ง (ยุค) ในกรณีนี้โมเดลโลคัลจะพอดีกับชุดนั้นอย่างรวดเร็วดังนั้นเมตริกความแม่นยำในพื้นที่ที่เราเฉลี่ยจะเข้าใกล้ 1.0 ดังนั้นเมตริกการฝึกอบรมเหล่านี้สามารถใช้เป็นสัญญาณบ่งชี้ว่าการฝึกอบรมกำลังก้าวหน้า แต่ไม่มากไปกว่านั้น

ในการประเมินข้อมูลแบบรวมคุณสามารถสร้าง การคำนวณแบบรวมศูนย์ อื่น ที่ ออกแบบมาเพื่อจุดประสงค์นี้โดยใช้ฟังก์ชัน tff.learning.build_federated_evaluation และส่งผ่านตัวสร้างโมเดลของคุณเป็นอาร์กิวเมนต์ โปรดทราบว่าไม่เหมือนกับ Federated Averaging ที่เราใช้ MnistTrainableModel ก็เพียงพอที่จะส่งผ่าน MnistModel การประเมินผลไม่ได้ทำการไล่ระดับสีและไม่จำเป็นต้องสร้างเครื่องมือเพิ่มประสิทธิภาพ

สำหรับการทดลองและการวิจัยเมื่อมีชุดข้อมูลการทดสอบแบบรวมศูนย์การ เรียนรู้ แบบรวมศูนย์ สำหรับการสร้างข้อความจะ แสดงตัวเลือกการประเมินอีกวิธีหนึ่งคือการรับน้ำหนักที่ได้รับการฝึกฝนจากการเรียนรู้แบบรวมศูนย์นำไปใช้กับโมเดล Keras มาตรฐานแล้วเรียกใช้ tf.keras.models.Model.evaluate() บนชุดข้อมูลส่วนกลาง

evaluation = tff.learning.build_federated_evaluation(MnistModel)

คุณสามารถตรวจสอบลายเซ็นชนิดนามธรรมของฟังก์ชันการประเมินได้ดังต่อไปนี้

str(evaluation.type_signature)
'(<server_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <num_examples=float32@SERVER,loss=float32@SERVER,accuracy=float32@SERVER>)'

ไม่จำเป็นต้องกังวลเกี่ยวกับรายละเอียดในตอนนี้เพียงแค่ทราบว่าต้องใช้รูปแบบทั่วไปดังต่อไปนี้คล้ายกับ tff.templates.IterativeProcess.next แต่มีความแตกต่างที่สำคัญสองประการ อันดับแรกเราจะไม่ส่งคืนสถานะเซิร์ฟเวอร์เนื่องจากการประเมินไม่ได้ปรับเปลี่ยนโมเดลหรือลักษณะอื่น ๆ ของสถานะคุณอาจคิดว่าสถานะนี้เป็นแบบไร้สัญชาติ ประการที่สองการประเมินต้องใช้โมเดลเท่านั้นและไม่จำเป็นต้องมีส่วนอื่น ๆ ของสถานะเซิร์ฟเวอร์ที่อาจเกี่ยวข้องกับการฝึกอบรมเช่นตัวแปรเครื่องมือเพิ่มประสิทธิภาพ

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

มาประเมินสถานะล่าสุดที่เรามาถึงในระหว่างการฝึกกันเถอะ ในการดึงโมเดลที่ได้รับการฝึกอบรมล่าสุดออกจากสถานะเซิร์ฟเวอร์คุณเพียงแค่เข้าถึงสมาชิก. .model ดังต่อไปนี้

train_metrics = evaluation(state.model, federated_train_data)

นี่คือสิ่งที่เราได้รับ สังเกตว่าตัวเลขดูดีกว่าที่รายงานในรอบสุดท้ายของการฝึกอบรมด้านบนเล็กน้อย ตามแบบแผนเมตริกการฝึกอบรมที่รายงานโดยกระบวนการฝึกอบรมซ้ำโดยทั่วไปจะสะท้อนถึงประสิทธิภาพของแบบจำลองในช่วงเริ่มต้นของรอบการฝึกอบรมดังนั้นเมตริกการประเมินผลจะนำหน้าไปหนึ่งก้าวเสมอ

str(train_metrics)
'<num_examples=4860.0,loss=1.7142657041549683,accuracy=0.38683128356933594>'

ตอนนี้เรามารวบรวมตัวอย่างการทดสอบของข้อมูลแบบรวมและเรียกใช้การประเมินผลข้อมูลทดสอบอีกครั้ง ข้อมูลจะมาจากกลุ่มตัวอย่างเดียวกันของผู้ใช้จริง แต่มาจากชุดข้อมูลที่ระงับไว้ที่แตกต่างกัน

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
'<num_examples=580.0,loss=1.861915111541748,accuracy=0.3362068831920624>'

สรุปบทแนะนำนี้ เราขอแนะนำให้คุณเล่นกับพารามิเตอร์ (เช่นขนาดแบทช์จำนวนผู้ใช้ยุคอัตราการเรียนรู้ ฯลฯ ) เพื่อแก้ไขโค้ดด้านบนเพื่อจำลองการฝึกอบรมเกี่ยวกับตัวอย่างสุ่มของผู้ใช้ในแต่ละรอบและเพื่อสำรวจบทช่วยสอนอื่น ๆ เราได้พัฒนา