本頁面由 Cloud Translation API 翻譯而成。
Switch to English

聯合學習進行圖像分類

在TensorFlow.org上查看 在Google Colab中運行 在GitHub上查看源代碼

在本教程中,我們使用經典的MNIST培訓示例介紹TFF的聯合學習(FL)API層tff.learning一組高級接口,可用於執行常見類型的聯合學習任務,例如針對TensorFlow中實現的用戶提供的模型進行聯合培訓。

本教程以及Federated Learning API,主要供希望將自己的TensorFlow模型插入TFF的用戶使用,後者主要將其視為黑匣子。要更深入地了解TFF以及如何實現自己的聯合學習算法,請參閱FC Core API的教程- 定制聯合算法第1 部分第2部分

有關tff.learning更多tff.learning ,請繼續閱讀《 聯邦學習的文本生成 》教程,該教程除介紹循環模型外,還演示瞭如何加載預訓練的序列化Keras模型以結合聯邦學習和Keras評估進行完善。

開始之前

在開始之前,請運行以下命令以確保正確設置您的環境。如果看不到問候語,請參閱安裝指南以獲取說明。


!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio

import nest_asyncio
nest_asyncio.apply()

%load_ext tensorboard
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

準備輸入數據

讓我們從數據開始。聯合學習需要聯合數據集,即來自多個用戶的數據集合。聯合數據通常是非獨立的 ,這帶來了一系列獨特的挑戰。

為了便於進行實驗,我們在TFF存儲庫中註入了一些數據集,其中包括MNIST的聯合版本,該版本包含已使用Leaf重新處理過的原始NIST數據集的版本,以便由原始作者編寫密鑰數字。由於每個作者都有獨特的風格,因此該數據集展現了聯合數據集所期望的非同伴行為。

這是我們如何加載它。

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

load_data()返回的數據集是load_data()的實例, tff.simulation.ClientData的接口使您可以枚舉用戶集,構造代表特定用戶數據的tf.data.Dataset並查詢各個元素的結構。這是使用該界面瀏覽數據集內容的方法。請記住,儘管此界面允許您遍歷客戶端ID,但這只是模擬數據的功能。如您將很快看到的,聯合學習框架不使用客戶端身份-它們的唯一目的是允許您選擇數據的子集進行仿真。

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None)), ('label', TensorSpec(shape=(), dtype=tf.int32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

探索聯合數據中的異構性

聯合數據通常是非空閒的 ,根據使用模式,用戶通常具有不同的數據分佈。一些客戶端可能在設備上缺少較少的培訓示例,而本地缺乏數據,而某些客戶端將有足夠的培訓示例。讓我們用可用的EMNIST數據來探索這種聯合系統中典型的數據異構性概念。重要的是要注意,只有對我們才可以對客戶的數據進行深入分析,因為這是一個模擬環境,其中所有數據都可以在本地使用。在實際的生產聯合環境中,您將無法檢查單個客戶的數據。

首先,讓我們獲取一個客戶數據的樣本,以在一台模擬設備上感受示例。因為我們正在使用的數據集已經由唯一作者編寫了密鑰,所以一個客戶的數據代表0到9的數字樣本的一個人的筆跡,模擬了一個用戶的唯一“使用模式”。

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

現在,讓我們直觀地看到每個客戶端上每個MNIST數字標籤的示例數量。在聯合環境中,每個客戶端上的示例數量可能會有所不同,具體取決於用戶的行為。

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

現在,讓我們可視化每個MNIST標籤的每個客戶端的平均圖像。該代碼將為一個標籤的所有用戶示例生成每個像素值的平均值。我們將看到,由於每個人獨特的手寫風格,一位客戶的平均位數圖像將與另一位客戶的同一位數看起來不同。當我們從該本地回合中該用戶自己的唯一數據中學習時,我們可以思考每個本地訓練回合將如何在每個客戶端上以不同的方向推動模型。在本教程的後面,我們將看到如何從所有客戶端獲取對模型的每次更新,並將它們匯總到我們新的全局模型中,該模型從我們每個客戶端自己的唯一數據中學到了。

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

用戶數據可能嘈雜且標籤不可靠。例如,查看上面客戶2的數據,我們可以看到,對於標籤2,可能存在一些標註錯誤的示例,從而產生了較嘈雜的均值圖像。

預處理輸入數據

由於數據已經是tf.data.Dataset ,因此可以使用Dataset轉換完成預處理。在這裡,我們將28x28圖像展平為784元素的數組,將各個示例進行混洗,將它們組織成批,然後將特徵從pixelslabel重命名為xy以用於Keras。我們還對數據集進行repeat以運行多個時期。

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER= 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

讓我們驗證一下是否可行。

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [2],
       [3],
       [6],
       [0],
       [1],
       [4],
       [1],
       [0],
       [6],
       [9],
       [9],
       [3],
       [6],
       [1],
       [4],
       [8],
       [0],
       [2]], dtype=int32))])

我們幾乎已具備構建聯邦數據集的所有構建塊。

在模擬中將聯合數據饋送到TFF的方法之一就是簡單地作為Python列表,列表的每個元素都保存單個用戶的數據,無論是作為列表還是作為tf.data.Dataset 。因為我們已經有一個提供後者的接口,所以讓我們使用它。

這是一個簡單的幫助程序功能,該功能將從給定的用戶集中構造數據集列表,作為一輪培訓或評估的輸入。

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

現在,我們如何選擇客戶?

在典型的聯合培訓場景中,我們正在處理大量潛在的用戶設備,其中只有一小部分可以在給定的時間點進行培訓。例如,當客戶端設備是僅在插入電源,斷開計量網絡連接或處於空閒狀態時才參與培訓的移動電話時,就是這種情況。

當然,我們處於仿真環境中,所有數據都在本地可用。通常,然後,在運行模擬時,我們將簡單地抽樣要參與每一輪培訓的客戶的隨機子集,通常每一輪都不同。

就是說,正如您通過研究有關聯合平均算法的論文所發現的那樣,在每個回合中具有隨機採樣的客戶子集的系統中實現收斂可能需要一段時間,並且在其中進行數百回合是不切實際的。本互動教程。

相反,我們要做的是對一組客戶端採樣一次,並在各回合中重複使用同一組客戶端,以加快收斂速度(有意過分適應這幾位用戶的數據)。我們將其作為練習,供讀者修改本教程以模擬隨機抽樣-這相當容易做到(一旦這樣做,請記住,使模型收斂可能要花一些時間)。

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

用Keras創建模型

如果您正在使用Keras,則您可能已經具有構造Keras模型的代碼。這是一個滿足我們需求的簡單模型的示例。

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

為了將任何模型與TFF一起使用,需要將其包裝在tff.learning.Model接口的實例中,該接口與tff.learning.Model ,公開用於標記模型的前向通過,元數據屬性等的方法,但還引入了其他方法元素,例如控制計算聯合指標的過程的方式。我們暫時不用擔心。如果您具有上面剛剛定義的tff.learning.from_keras_model模型,則可以通過調用tff.learning.from_keras_model ,將模型和示例數據批處理作為參數,來讓TFF為您包裝它,如下所示。

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

在聯合數據上訓練模型

現在,我們有了一個包裝為tff.learning.Model的模型以與TFF tff.learning.Model使用,我們可以讓TFF通過調用輔助函數tff.learning.build_federated_averaging_process構造聯合平均算法,如下所示。

請記住,該參數必須是構造函數(例如上面的model_fn ),而不是已經構造的實例,以便可以在TFF控制的上下文中進行模型的構造(如果您對這樣做的原因感到好奇)為此,我們建議您閱讀有關自定義算法的後續教程)。

關於下面的聯合平均算法的一個重要說明,有2個優化器:_client 優化器和_server 優化器 。 _client 優化器僅用於計算每個客戶端上的本地模型更新。 _server 優化器將平均更新應用於服務器上的全局模型。特別是,這意味著所使用的優化程序和學習速率的選擇可能需要與您在標準iid數據集上訓練模型所使用的選擇不同。我們建議從常規SGD開始,學習速度可能會比平常小。我們使用的學習率尚未經過仔細調整,可以隨時嘗試。

52

剛才發生了什麼? TFF構建了一對聯合計算 ,並將它們打包到一個tff.templates.IterativeProcess中,其中,這些計算可以作為一對屬性initializenext

簡而言之, 聯邦計算是以TFF的內部語言編寫的程序,可以表示多種聯邦算法(您可以在自定義算法教程中找到有關此的更多信息)。在這種情況下,生成並打包到iterative_process的兩個計算實現了Federated Averaging

TFF的目標是以一種可以在真正的聯合學習設置中執行計算的方式定義計算,但是目前僅實現本地執行模擬運行時。要在模擬器中執行計算,您只需像Python函數一樣調用它即可。這個默認的解釋環境不是為高性能而設計的,但足以滿足本教程的要求。我們希望提供更高性能的仿真運行時,以促進將來版本中的大規模研究。

讓我們從initialize計算開始。與所有聯合計算一樣,您可以將其視為一個函數。該計算不接受任何參數,並返回一個結果-服務器上聯合平均進程狀態的表示。雖然我們不想深入了解TFF的細節,但了解這種狀態可能是有益的。您可以將其可視化如下。

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<>,model_broadcast_state=<>>@SERVER)'

雖然上述類型簽名乍看之下似乎有點晦澀難懂,但您可以認識到服務器狀態包括一個model (MNIST的初始模型參數,它將分配給所有設備)和optimizer_state (服務器維護的其他信息,例如用於超參數時間表的輪數等)。

讓我們調用initialize計算來構造服務器狀態。

state = iterative_process.initialize()

聯邦計算對中的第二個代表next一輪聯邦平均,包括將服務器狀態(包括模型參數)推送給客戶端,對本地數據進行設備上的培訓,收集並平均模型更新,並在服務器上生成新的更新模型。

從概念上講,您可以認為next具有如下所示的功能類型簽名。

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

特別是,不應將next()視為運行在服務器上的函數,而應視為整個分散計算的聲明性函數表示形式-一些輸入由服務器提供( SERVER_STATE ),但每個輸入設備貢獻自己的本地數據集。

讓我們進行一次單輪訓練並可視化結果。我們可以將上面已經生成的聯合數據用於用戶樣本。

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.12037037312984467,loss=3.0108425617218018>>

讓我們再進行幾輪。如前所述,通常在這一點上,您將為每個回合從隨機選擇的新用戶樣本中選擇一部分模擬數據,以模擬現實的部署,在該部署中,用戶不斷地來去去去,但是在此交互式筆記本中,為了演示起見,我們將重複使用相同的用戶,以便系統快速收斂。

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.14814814925193787,loss=2.8865506649017334>>
round  3, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.148765429854393,loss=2.9079062938690186>>
round  4, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.17633745074272156,loss=2.724686622619629>>
round  5, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.20226337015628815,loss=2.6334855556488037>>
round  6, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.22427983582019806,loss=2.5482592582702637>>
round  7, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.24094650149345398,loss=2.4472343921661377>>
round  8, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.259876549243927,loss=2.3809611797332764>>
round  9, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.29814815521240234,loss=2.156442403793335>>
round 10, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.31687241792678833,loss=2.122845411300659>>

在每輪聯合訓練之後,訓練損失都在減少,這表明該模型正在收斂。這些訓練指標有一些重要的警告,但是,請參閱本教程後面的“ 評估 ”部分。

在TensorBoard中顯示模型指標

接下來,讓我們使用Tensorboard可視化來自這些聯合計算的指標。

讓我們首先創建目錄和相應的摘要編寫器以將度量寫入其中。


logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

使用相同的摘要編寫器繪製相關的標量度量。


with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics.train._asdict().items():
      tf.summary.scalar(name, value, step=round_num)

使用上面指定的根日誌目錄啟動TensorBoard。加載數據可能需要幾秒鐘。


%tensorboard --logdir /tmp/logs/scalars/ --port=0

# Run this this cell to clean your directory of old output for future graphs from this directory.
!rm -R /tmp/logs/scalars/*

為了以相同的方式查看評估指標,您可以創建一個單獨的eval文件夾,例如“ logs / scalars / eval”,以寫入TensorBoard。

定制模型實施

Keras是TensorFlow推薦高級模型API ,我們鼓勵盡可能在TFF中使用tff.learning.from_keras_model模型(通過tff.learning.from_keras_model )。

但是, tff.learning提供了一個較低級的模型接口tff.learning.Model ,它公開了使用模型進行聯合學習所需的最小功能。直接實現此接口(可能仍在使用諸如tf.keras.layers類的構造塊)可以實現最大程度的自定義,而無需修改聯合學習算法的內部結構。

因此,讓我們從頭開始重新做一遍。

定義模型變量,正向傳遞和度量

第一步是確定我們將要使用的TensorFlow變量。為了使以下代碼更清晰易懂,讓我們定義一個數據結構來表示整個集合。這將包括變量,如weightsbias ,我們將培訓,以及變量,將舉行各種累積的統計數據和計數器,我們將在培訓期間更新,如loss_sumaccuracy_sumnum_examples

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

這是創建變量的方法。為了簡單起見,我們將所有統計信息表示為tf.float32 ,因為這將在以後消除類型轉換的需要。將變量初始化程序包裝為lambda是資源變量施加的要求。

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

有了模型參數的變量和累積統計信息,我們現在可以定義向前傳遞方法,該方法可以計算損失,發出預測並更新單批輸入數據的累積統計信息,如下所示。

def mnist_forward_pass(variables, batch):
  y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

接下來,我們再次使用TensorFlow定義一個返回一組本地指標的函數。這些是值(除了自動更新的模型更新之外),可以在聯合學習或評估過程中將這些值聚合到服務器。

在這裡,我們僅返回平均lossaccuracy以及num_examples ,在計算聯合聚合時,我們需要正確地加權不同用戶的貢獻。

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

最後,我們需要確定如何通過get_local_mnist_metrics匯總每個設備發出的本地指標。這是代碼中唯一沒有用TensorFlow編寫的部分-這是用TFF表示的聯合計算 。如果您想更深入地學習,請瀏覽自定義算法教程,但是在大多數應用程序中,您並不需要。下面顯示的模式的變體就足夠了。看起來是這樣的:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))
  

輸入metrics參數對應於上述get_local_mnist_metrics返回的OrderedDict ,但關鍵是值不再是tf.Tensors它們被裝箱為tff.Value ,為了清楚tff.Value ,您無法再使用TensorFlow對其進行操作,但是使用TFF的聯合運算符,例如tff.federated_meantff.federated_sum 。返回的全局聚合字典定義了一組在服務器上可用的度量。

構造一個tff.learning.Model的實例

完成上述所有操作後,我們就可以構建一個與TFF一起使用的模型表示,類似於讓TFF提取Keras模型時為您生成的模型表示。

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

如您所見, tff.learning.Model定義的抽象方法和屬性與上一節中介紹變量並定義損失和統計信息的代碼段相對應。

這裡有幾點值得強調:

  • 您的模型將使用的所有狀態都必須捕獲為TensorFlow變量,因為TFF在運行時不使用Python(請記住,您的代碼應該編寫為可以部署到移動設備上;有關詳細信息,請參閱自定義算法教程)原因說明)。
  • 您的模型應描述其接受的數據形式( input_spec ),通常,TFF是強類型環境,並且希望確定所有組件的類型簽名。聲明模型輸入的格式是其中必不可少的一部分。
  • 儘管從技術上講不是必需的,但我們建議將所有TensorFlow邏輯(正向傳遞,度量計算等) tf.functiontf.function ,因為這有助於確保TensorFlow可以序列化,並且不需要顯式的控件依賴項。

以上對於評估和算法(例如聯合SGD)就足夠了。但是,對於聯合平均,我們需要指定模型如何在每個批次上進行本地訓練。在構建聯合平均算法時,我們將指定一個本地優化器。

使用新模型模擬聯合培訓

完成上述所有操作後,其餘過程看起來就像我們已經看到的一樣-只需將模型構造函數替換為新模型類的構造函數,然後在您創建的迭代過程中使用兩個聯合計算來循環訓練回合。

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.9713594913482666,accuracy=0.13518518209457397>>

for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.975412607192993,accuracy=0.14032921195030212>>
round  3, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.9395227432250977,accuracy=0.1594650149345398>>
round  4, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.710164785385132,accuracy=0.17139917612075806>>
round  5, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.5891618728637695,accuracy=0.20267489552497864>>
round  6, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.5148487091064453,accuracy=0.21666666865348816>>
round  7, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.2816808223724365,accuracy=0.2580246925354004>>
round  8, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.3656885623931885,accuracy=0.25884774327278137>>
round  9, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.23549222946167,accuracy=0.28477364778518677>>
round 10, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=1.974222183227539,accuracy=0.35329216718673706>>

要在TensorBoard中查看這些指標,請參考上面“在TensorBoard中顯示模型指標”中列出的步驟。

評價

到目前為止,我們所有的實驗僅提供了聯邦訓練指標-整個回合中所有客戶訓練的所有數據批次的平均指標。這就引入了關於過度擬合的通常問題,特別是因為為簡單起見,我們在每一輪中都使用了相同的客戶端集,但是在針對聯合平均算法的訓練指標中還有過度擬合的概念。這很容易看出我們是否想像每個客戶端都有一整批數據,並且在該批數據上進行了多次迭代(曆元)訓練。在這種情況下,本地模型將很快完全適合該批次,因此我們平均的本地精度度量將接近1.0。因此,這些訓練指標可被視為訓練正在進行中的標誌,但僅此而已。

要對聯合數據執行評估,您可以使用tff.learning.build_federated_evaluation函數構造另一個為此目的而設計的聯合計算 ,並將模型構造函數作為參數傳入。請注意,與使用MnistTrainableModel聯合平均不同,它足以傳遞MnistModel 。評估不執行梯度下降,因此不需要構造優化器。

對於實驗和研究,當有集中式測試數據集可用時, 聯邦學習用於文本生成演示了另一個評估選項:從聯邦學習中獲取訓練後的權重,將其應用於標準Keras模型,然後簡單地調用tf.keras.models.Model.evaluate()在集中數據集上。

evaluation = tff.learning.build_federated_evaluation(MnistModel)

您可以按照以下步驟檢查評估函數的抽像類型簽名。

str(evaluation.type_signature)
'(<<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,{<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <num_examples=float32@SERVER,loss=float32@SERVER,accuracy=float32@SERVER>)'

此時無需擔心細節,只需注意它採用以下通用形式,類似於tff.templates.IterativeProcess.next但有兩個重要區別。首先,我們不返回服務器狀態,因為評估不會修改模型或狀態的任何其他方面-您可以將其視為無狀態。其次,評估只需要模型,不需要服務器狀態的任何其他部分,例如優化程序變量,這些部分都可能與培訓相關聯。

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

讓我們對培訓期間達到的最新狀態進行評估。為了從服務器狀態中提取最新的經過訓練的模型,您只需訪問.model成員,如下所示。

train_metrics = evaluation(state.model, federated_train_data)

這就是我們得到的。請注意,這些數字看起來比上面的最後一輪培訓報告的數字略好。按照慣例,由迭代訓練過程報告的訓練指標通常會在訓練回合開始時反映模型的性能,因此評估指標將始終領先一步。

str(train_metrics)
'<num_examples=4860.0,loss=1.7142657041549683,accuracy=0.38683128356933594>'

現在,讓我們編譯一個聯邦數據的測試樣本,並對測試數據重新進行評估。數據將來自真實用戶的相同樣本,但來自截然不同的保留數據集。

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
'<num_examples=580.0,loss=1.861915111541748,accuracy=0.3362068831920624>'

本教程到此結束。我們鼓勵您使用參數(例如批大小,用戶數量,時期,學習率等),修改上面的代碼以模擬每輪用戶隨機樣本的訓練,並探索其他教程。我們已經開發了。