Lưu ngày! Google I / O hoạt động trở lại từ ngày 18 đến 20 tháng 5 Đăng ký ngay
Trang này được dịch bởi Cloud Translation API.
Switch to English

Học liên kết để phân loại hình ảnh

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Trong hướng dẫn này, chúng tôi sử dụng ví dụ đào tạo MNIST cổ điển để giới thiệu lớp API Học liên kết (FL) của TFF, tff.learning - một tập hợp các giao diện cấp cao hơn có thể được sử dụng để thực hiện các loại nhiệm vụ học liên kết phổ biến, chẳng hạn như đào tạo liên kết, chống lại các mô hình do người dùng cung cấp được triển khai trong TensorFlow.

Hướng dẫn này và API học liên kết, chủ yếu dành cho những người dùng muốn kết nối các mô hình TensorFlow của riêng họ vào TFF, coi mô hình sau chủ yếu là một hộp đen. Để hiểu sâu hơn về TFF và cách triển khai các thuật toán học liên kết của riêng bạn, hãy xem hướng dẫn về FC Core API - Thuật toán liên kết tùy chỉnh Phần 1Phần 2 .

Để biết thêm về tff.learning , hãy tiếp tục với Học liên kết cho Tạo văn bản , hướng dẫn ngoài việc bao gồm các mô hình lặp lại, còn trình bày việc tải mô hình Keras tuần tự được đào tạo trước để tinh chỉnh với học liên kết kết hợp với đánh giá bằng Keras.

Trước khi chúng ta bắt đầu

Trước khi chúng tôi bắt đầu, vui lòng chạy phần sau để đảm bảo rằng môi trường của bạn được thiết lập chính xác. Nếu bạn không thấy lời chào, vui lòng tham khảo Hướng dẫn cài đặt để được hướng dẫn.

# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
!pip install --quiet --upgrade tb-nightly  # or tensorboard, but not both

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
Fetching TensorBoard MPM... done.
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

Chuẩn bị dữ liệu đầu vào

Hãy bắt đầu với dữ liệu. Học liên kết yêu cầu một tập dữ liệu được liên kết, tức là tập hợp dữ liệu từ nhiều người dùng. Dữ liệu Federated thường không iid , trong đó đặt ra một bộ duy nhất của những thách thức.

Để tạo điều kiện thuận lợi cho quá trình thử nghiệm, chúng tôi đã đưa vào kho lưu trữ TFF một số tập dữ liệu, bao gồm phiên bản liên kết của MNIST có chứa phiên bản của tập dữ liệu NIST gốc đã được xử lý lại bằng cách sử dụng Leaf để dữ liệu được người viết ban đầu của các chữ số. Vì mỗi người viết có một phong cách riêng, tập dữ liệu này thể hiện loại hành vi không ổn định như mong đợi của các tập dữ liệu liên kết.

Đây là cách chúng tôi có thể tải nó.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Các tập dữ liệu được trả về bởi load_data() là các thể hiện của tff.simulation.ClientData , một giao diện cho phép bạn liệt kê nhóm người dùng, để tạo mộttf.data.Dataset đại diện cho dữ liệu của một người dùng cụ thể và để truy vấn cấu trúc của các phần tử riêng lẻ. Đây là cách bạn có thể sử dụng giao diện này để khám phá nội dung của tập dữ liệu. Hãy nhớ rằng mặc dù giao diện này cho phép bạn lặp lại các id máy khách, nhưng đây chỉ là một tính năng của dữ liệu mô phỏng. Như bạn sẽ thấy ngay sau đây, danh tính khách hàng không được sử dụng bởi khung học tập liên kết - mục đích duy nhất của chúng là cho phép bạn chọn các tập hợp con của dữ liệu để mô phỏng.

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None)), ('label', TensorSpec(shape=(), dtype=tf.int32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

Khám phá sự không đồng nhất trong dữ liệu được liên kết

Dữ liệu Federated thường không iid , người dùng thường có sự phân bố dữ liệu khác nhau tùy thuộc vào thói quen sử dụng. Một số khách hàng có thể có ít ví dụ đào tạo hơn trên thiết bị, do dữ liệu bị mờ cục bộ, trong khi một số khách hàng sẽ có nhiều ví dụ đào tạo hơn. Hãy cùng khám phá khái niệm về tính không đồng nhất dữ liệu điển hình của một hệ thống liên kết với dữ liệu EMNIST mà chúng tôi có sẵn. Điều quan trọng cần lưu ý là phân tích sâu về dữ liệu của khách hàng chỉ có sẵn cho chúng tôi vì đây là môi trường mô phỏng nơi tất cả dữ liệu có sẵn cho chúng tôi tại địa phương. Trong môi trường liên kết sản xuất thực, bạn sẽ không thể kiểm tra dữ liệu của một khách hàng.

Đầu tiên, hãy lấy mẫu dữ liệu của một khách hàng để có cảm nhận về các ví dụ trên một thiết bị mô phỏng. Bởi vì tập dữ liệu chúng tôi đang sử dụng đã được khóa bởi người viết duy nhất, dữ liệu của một khách hàng đại diện cho chữ viết tay của một người cho một mẫu các chữ số từ 0 đến 9, mô phỏng "kiểu sử dụng" duy nhất của một người dùng.

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

Bây giờ chúng ta hãy hình dung số lượng ví dụ trên mỗi máy khách cho mỗi nhãn chữ số MNIST. Trong môi trường liên kết, số lượng ví dụ trên mỗi máy khách có thể khác nhau khá nhiều, tùy thuộc vào hành vi của người dùng.

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

Bây giờ hãy hình dung hình ảnh trung bình trên mỗi khách hàng cho mỗi nhãn MNIST. Mã này sẽ tạo ra giá trị trung bình của mỗi pixel cho tất cả các ví dụ của người dùng cho một nhãn. Chúng ta sẽ thấy rằng hình ảnh trung bình của một khách hàng cho một chữ số sẽ trông khác với hình ảnh trung bình của một khách hàng khác cho cùng một chữ số, do kiểu chữ viết tay độc đáo của mỗi người. Chúng tôi có thể tìm hiểu về cách mỗi vòng đào tạo địa phương sẽ thúc đẩy mô hình theo một hướng khác nhau đối với từng khách hàng, vì chúng tôi đang học hỏi từ dữ liệu duy nhất của chính người dùng đó trong vòng đào tạo tại địa phương đó. Ở phần sau của hướng dẫn, chúng ta sẽ xem cách chúng ta có thể nhận từng bản cập nhật cho mô hình từ tất cả các khách hàng và tổng hợp chúng lại với nhau thành mô hình toàn cầu mới của chúng tôi, mô hình này đã học được từ dữ liệu duy nhất của mỗi khách hàng của chúng tôi.

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

Dữ liệu người dùng có thể bị nhiễu và được gắn nhãn không đáng tin cậy. Ví dụ: nhìn vào dữ liệu của Khách hàng số 2 ở trên, chúng ta có thể thấy rằng đối với nhãn 2, có thể có một số ví dụ được gắn nhãn sai tạo ra một hình ảnh trung bình ồn ào hơn.

Xử lý trước dữ liệu đầu vào

Vì dữ liệu đã làtf.data.Dataset , nên việc xử lý trước có thể được thực hiện bằng cách sử dụng các phép biến đổi Dataset. Ở đây, chúng tôi làm phẳng các hình ảnh 28x28 thành các mảng 784 28x28 , xáo trộn các ví dụ riêng lẻ, sắp xếp chúng thành các lô và đổi tên các đối tượng địa lý từ pixelslabel thành xy để sử dụng với Keras. Chúng tôi cũng repeat tập dữ liệu để chạy một số kỷ nguyên.

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

Hãy xác minh điều này đã hoạt động.

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[0],
       [5],
       [0],
       [1],
       [3],
       [0],
       [5],
       [4],
       [1],
       [7],
       [0],
       [4],
       [0],
       [1],
       [7],
       [2],
       [2],
       [0],
       [7],
       [1]], dtype=int32))])

Chúng tôi có gần như tất cả các khối xây dựng để xây dựng các tập dữ liệu được liên kết.

Một trong những cách để cung cấp dữ liệu liên kết cho TFF trong mô phỏng đơn giản là dưới dạng danh sách Python, với mỗi phần tử của danh sách chứa dữ liệu của một người dùng cá nhân, cho dù dưới dạng danh sách haytf.data.Dataset . Vì chúng ta đã có một giao diện cung cấp giao diện thứ hai, hãy sử dụng nó.

Đây là một hàm trợ giúp đơn giản sẽ tạo danh sách các tập dữ liệu từ một nhóm người dùng nhất định làm đầu vào cho một vòng đào tạo hoặc đánh giá.

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

Bây giờ, chúng ta chọn khách hàng như thế nào?

Trong một kịch bản đào tạo liên hợp điển hình, chúng tôi đang đối phó với một lượng lớn thiết bị người dùng tiềm năng, chỉ một phần nhỏ trong số đó có thể sẵn sàng để đào tạo tại một thời điểm nhất định. Đây là trường hợp, ví dụ, khi các thiết bị khách hàng là điện thoại di động tham gia đào tạo chỉ khi được cắm vào nguồn điện, không kết nối với mạng đo lường, và nếu không thì không hoạt động.

Tất nhiên, chúng tôi đang ở trong một môi trường mô phỏng và tất cả dữ liệu đều có sẵn tại địa phương. Thông thường, khi chạy mô phỏng, chúng tôi chỉ cần lấy mẫu ngẫu nhiên một tập hợp con khách hàng tham gia vào mỗi vòng đào tạo, nói chung là khác nhau trong mỗi vòng.

Điều đó nói rằng, như bạn có thể tìm hiểu bằng cách nghiên cứu bài báo về thuật toán Trung bình liên kết , đạt được sự hội tụ trong một hệ thống với các tập hợp con khách hàng được lấy mẫu ngẫu nhiên trong mỗi vòng có thể mất một lúc và sẽ không thực tế nếu phải chạy hàng trăm vòng trong hướng dẫn tương tác này.

Thay vào đó, những gì chúng tôi sẽ làm là lấy mẫu nhóm khách hàng một lần và sử dụng lại nhóm khách hàng tương tự qua các vòng để tăng tốc độ hội tụ (cố ý phù hợp quá mức với dữ liệu của một số người dùng này). Chúng tôi để nó như một bài tập cho người đọc để sửa đổi hướng dẫn này để mô phỏng lấy mẫu ngẫu nhiên - nó khá dễ thực hiện (một khi bạn làm vậy, hãy nhớ rằng việc đưa mô hình hội tụ có thể mất một lúc).

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

Tạo mô hình với Keras

Nếu bạn đang sử dụng Keras, bạn có thể đã có mã xây dựng mô hình Keras. Đây là một ví dụ về một mô hình đơn giản sẽ đáp ứng đủ nhu cầu của chúng tôi.

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

Để sử dụng bất kỳ mô hình nào với TFF, nó cần được bao bọc trong một thể hiện của giao diện tff.learning.Model , giao diện này hiển thị các phương thức đóng dấu chuyển tiếp của mô hình, thuộc tính siêu dữ liệu, v.v., tương tự như Keras, nhưng cũng giới thiệu thêm các yếu tố, chẳng hạn như các cách kiểm soát quá trình tính toán các chỉ số liên hợp. Bây giờ chúng ta đừng lo lắng về điều này; nếu bạn có một mô hình Keras giống như mô hình mà chúng tôi vừa xác định ở trên, bạn có thể yêu cầu TFF bọc nó cho bạn bằng cách gọi tff.learning.from_keras_model , chuyển mô hình và một lô dữ liệu mẫu làm đối số, như được hiển thị bên dưới.

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Đào tạo mô hình trên dữ liệu liên kết

Bây giờ chúng ta có một mô hình được bao bọc dưới dạng tff.learning.Model để sử dụng với TFF, chúng ta có thể để TFF xây dựng thuật toán Trung bình liên kết bằng cách gọi hàm trợ giúp tff.learning.build_federated_averaging_process , như sau.

Hãy nhớ rằng đối số cần phải là một phương thức khởi tạo (chẳng hạn như model_fn ở trên), không phải là một phiên bản đã được xây dựng, để việc xây dựng mô hình của bạn có thể xảy ra trong ngữ cảnh do TFF kiểm soát (nếu bạn tò mò về lý do điều này, chúng tôi khuyến khích bạn đọc hướng dẫn tiếp theo về các thuật toán tùy chỉnh ).

Một lưu ý quan trọng về thuật toán Trung bình liên kết bên dưới, có 2 trình tối ưu hóa: trình tối ưu hóa _client và trình tối ưu hóa _server. Trình tối ưu hóa _client chỉ được sử dụng để tính toán các bản cập nhật mô hình cục bộ trên mỗi máy khách. Trình tối ưu hóa _server áp dụng cập nhật trung bình cho mô hình chung tại máy chủ. Đặc biệt, điều này có nghĩa là lựa chọn trình tối ưu hóa và tốc độ học được sử dụng có thể cần phải khác với lựa chọn bạn đã sử dụng để đào tạo mô hình trên tập dữ liệu iid tiêu chuẩn. Chúng tôi khuyên bạn nên bắt đầu với SGD thông thường, có thể với tỷ lệ học tập nhỏ hơn bình thường. Tỷ lệ học tập mà chúng tôi sử dụng chưa được điều chỉnh cẩn thận, hãy thoải mái thử nghiệm.

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Chuyện gì vừa xảy ra vậy? TFF đã xây dựng một cặp tính toán liên hợp và đóng gói chúng thành một tff.templates.IterativeProcess trong đó những tính toán này có sẵn dưới dạng một cặp thuộc tính initializenext .

Tóm lại, tính toán liên hợp là các chương trình bằng ngôn ngữ nội bộ của TFF có thể thể hiện các thuật toán liên hợp khác nhau (bạn có thể tìm thêm về điều này trong hướng dẫn thuật toán tùy chỉnh ). Trong trường hợp này, hai phép tính được tạo và đóng gói thành iterative_process thực hiện Trung bình Liên kết .

Mục tiêu của TFF là xác định các tính toán theo cách mà chúng có thể được thực thi trong cài đặt học tập liên kết thực, nhưng hiện tại chỉ thời gian chạy mô phỏng thực thi cục bộ mới được thực hiện. Để thực thi một phép tính trong trình mô phỏng, bạn chỉ cần gọi nó giống như một hàm Python. Môi trường thông dịch mặc định này không được thiết kế cho hiệu suất cao, nhưng nó sẽ đủ cho hướng dẫn này; chúng tôi hy vọng sẽ cung cấp thời gian chạy mô phỏng hiệu suất cao hơn để tạo điều kiện cho nghiên cứu quy mô lớn hơn trong các bản phát hành trong tương lai.

Hãy bắt đầu với tính toán initialize . Như trường hợp của tất cả các phép tính liên hợp, bạn có thể coi nó như một hàm. Tính toán không có đối số và trả về một kết quả - biểu diễn trạng thái của quá trình Tính trung bình liên kết trên máy chủ. Mặc dù chúng tôi không muốn đi sâu vào chi tiết của TFF, nhưng có thể là hướng dẫn để xem trạng thái này trông như thế nào. Bạn có thể hình dung nó như sau.

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER)'

Mặc dù chữ ký kiểu trên thoạt đầu có vẻ hơi khó hiểu, nhưng bạn có thể nhận ra rằng trạng thái máy chủ bao gồm một model (các thông số mô hình ban đầu cho MNIST sẽ được phân phối cho tất cả các thiết bị) và optimizer_state (thông tin bổ sung được duy trì bởi máy chủ, chẳng hạn như số vòng để sử dụng cho lịch biểu siêu tham số, v.v.).

Hãy gọi tính toán initialize để xây dựng trạng thái máy chủ.

state = iterative_process.initialize()

Phần thứ hai của cặp phép tính liên kết, next , đại diện cho một vòng duy nhất của Tính trung bình liên kết, bao gồm việc đẩy trạng thái máy chủ (bao gồm các thông số mô hình) cho khách hàng, đào tạo trên thiết bị về dữ liệu cục bộ của họ, thu thập và tính trung bình các bản cập nhật mô hình và tạo ra một mô hình cập nhật mới tại máy chủ.

Về mặt khái niệm, bạn có thể nghĩ next là có một chữ ký kiểu chức năng trông như sau.

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

Đặc biệt, người ta nên nghĩ về next() không phải là một hàm chạy trên máy chủ, mà là một biểu diễn chức năng khai báo của toàn bộ tính toán phi tập trung - một số đầu vào được cung cấp bởi máy chủ ( SERVER_STATE ), nhưng mỗi đầu vào đều tham gia thiết bị đóng góp tập dữ liệu cục bộ của riêng nó.

Hãy chạy một vòng đào tạo và hình dung kết quả. Chúng tôi có thể sử dụng dữ liệu được liên kết mà chúng tôi đã tạo ở trên cho một mẫu người dùng.

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11502057), ('loss', 3.244929)]))])

Hãy chạy thêm vài vòng nữa. Như đã lưu ý trước đó, thông thường tại thời điểm này, bạn sẽ chọn một tập hợp con dữ liệu mô phỏng của mình từ một mẫu người dùng mới được chọn ngẫu nhiên cho mỗi vòng để mô phỏng một triển khai thực tế trong đó người dùng liên tục đến và đi, nhưng trong sổ ghi chép tương tác này, cho vì mục đích minh chứng là chúng tôi sẽ chỉ sử dụng lại những người dùng giống nhau, để hệ thống hội tụ nhanh chóng.

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14609054), ('loss', 2.9141645)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.15205762), ('loss', 2.9237952)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.18600823), ('loss', 2.7629454)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.20884773), ('loss', 2.622908)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.21872428), ('loss', 2.543587)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2372428), ('loss', 2.4210362)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.28209877), ('loss', 2.2297976)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2685185), ('loss', 2.195803)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.33868313), ('loss', 2.0523348)]))])

Tổn thất trong huấn luyện đang giảm dần sau mỗi đợt huấn luyện liên đoàn, cho thấy mô hình đang hội tụ. Tuy nhiên, có một số lưu ý quan trọng với các chỉ số đào tạo này, hãy xem phần Đánh giá ở phần sau của hướng dẫn này.

Hiển thị số liệu mô hình trong TensorBoard

Tiếp theo, hãy hình dung các số liệu từ các phép tính liên hợp này bằng cách sử dụng Tensorboard.

Hãy bắt đầu bằng cách tạo thư mục và trình viết tóm tắt tương ứng để ghi các số liệu vào.

logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

Vẽ biểu đồ các chỉ số vô hướng có liên quan với cùng một người viết tóm tắt.

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics['train'].items():
      tf.summary.scalar(name, value, step=round_num)

Khởi động TensorBoard với thư mục nhật ký gốc được chỉ định ở trên. Có thể mất vài giây để tải dữ liệu.

!ls {logdir}
%tensorboard --logdir {logdir} --port=0
events.out.tfevents.1604020204.isim77-20020ad609500000b02900f40f27a5f6.prod.google.com.686098.10633.v2
events.out.tfevents.1604020602.isim77-20020ad609500000b02900f40f27a5f6.prod.google.com.794554.10607.v2
Launching TensorBoard...
<IPython.core.display.Javascript at 0x7fc5e8d3c128>
# Uncomment and run this this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if 
# you do a "Runtime > Run all" you don't lose your results.

# !rm -R /tmp/logs/scalars/*

Để xem các chỉ số đánh giá theo cách tương tự, bạn có thể tạo một thư mục eval riêng, như "nhật ký / vô hướng / eval", để ghi vào TensorBoard.

Tùy chỉnh việc triển khai mô hình

Keras là API mô hình cấp cao được đề xuất cho TensorFlow và chúng tôi khuyến khích sử dụng mô hình Keras (thông qua tff.learning.from_keras_model ) trong TFF bất cứ khi nào có thể.

Tuy nhiên, tff.learning cung cấp giao diện mô hình cấp thấp hơn, tff.learning.Model , hiển thị chức năng tối thiểu cần thiết để sử dụng mô hình cho việc học liên kết. Việc triển khai trực tiếp giao diện này (có thể vẫn sử dụng các khối xây dựng nhưtf.keras.layers ) cho phép tùy chỉnh tối đa mà không cần sửa đổi nội bộ của các thuật toán học liên hợp.

Vì vậy, chúng ta hãy làm lại từ đầu.

Xác định các biến mô hình, chuyển tiếp và số liệu

Bước đầu tiên là xác định các biến TensorFlow mà chúng ta sẽ làm việc với. Để làm cho đoạn mã sau dễ đọc hơn, hãy xác định cấu trúc dữ liệu để đại diện cho toàn bộ tập hợp. Điều này sẽ bao gồm các biến như weightsbias rằng chúng tôi sẽ đào tạo, cũng như các biến mà sẽ tổ chức thống kê khác nhau tích lũy và quầy chúng tôi sẽ cập nhật trong thời gian đào tạo, chẳng hạn như loss_sum , accuracy_sum , và num_examples .

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

Đây là một phương pháp tạo các biến. Để đơn giản, chúng tôi biểu thị tất cả các thống kê dưới dạng tf.float32 , vì điều đó sẽ loại bỏ nhu cầu chuyển đổi kiểu ở giai đoạn sau. Gói các bộ khởi tạo biến dưới dạng lambdas là một yêu cầu do các biến tài nguyên áp đặt.

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

Với các biến cho tham số mô hình và thống kê tích lũy đã có sẵn, giờ đây chúng ta có thể xác định phương pháp chuyển tiếp tính toán tổn thất, đưa ra dự đoán và cập nhật thống kê tích lũy cho một lô dữ liệu đầu vào, như sau.

def mnist_forward_pass(variables, batch):
  y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

Tiếp theo, chúng tôi xác định một hàm trả về một tập hợp các chỉ số cục bộ, một lần nữa bằng cách sử dụng TensorFlow. Đây là các giá trị (ngoài các bản cập nhật mô hình, được xử lý tự động) đủ điều kiện để được tổng hợp vào máy chủ trong quá trình học tập hoặc đánh giá được liên kết.

Ở đây, chúng tôi chỉ trả về lossaccuracy trung bình, cũng như num_examples , chúng tôi sẽ cần cân nhắc chính xác các đóng góp từ những người dùng khác nhau khi tính toán tổng hợp được liên kết.

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

Cuối cùng, chúng ta cần xác định cách tổng hợp các chỉ số cục bộ do mỗi thiết bị phát ra thông qua get_local_mnist_metrics . Đây là phần duy nhất của mã không được viết trong TensorFlow - đó là một phép tính liên hợp được thể hiện trong TFF. Nếu bạn muốn tìm hiểu sâu hơn, hãy đọc lướt qua hướng dẫn thuật toán tùy chỉnh , nhưng trong hầu hết các ứng dụng, bạn sẽ không thực sự cần; các biến thể của mẫu hiển thị bên dưới là đủ. Đây là những gì nó trông giống như:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))

Đối metrics đầu vào tương ứng với OrderedDict trả về bởi get_local_mnist_metrics ở trên, nhưng quan trọng là các giá trị không còn là tf.Tensors - chúng được "đóng hộp" dưới dạng tff.Value s, để làm rõ rằng bạn không còn có thể thao tác chúng bằng TensorFlow nữa, mà chỉ sử dụng các toán tử liên kết của TFF như tff.federated_meantff.federated_sum . Từ điển tổng hợp toàn cầu được trả về xác định tập hợp số liệu sẽ có sẵn trên máy chủ.

Xây dựng một phiên bản của tff.learning.Model

Với tất cả những điều trên, chúng tôi đã sẵn sàng xây dựng một biểu diễn mô hình để sử dụng với TFF tương tự như một biểu diễn được tạo cho bạn khi bạn cho phép TFF nhập mô hình Keras.

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

Như bạn có thể thấy, các phương thức và thuộc tính trừu tượng được xác định bởi tff.learning.Model tương ứng với các đoạn mã trong phần trước đã giới thiệu các biến và xác định tổn thất và thống kê.

Dưới đây là một số điểm đáng chú ý:

  • Tất cả trạng thái mà mô hình của bạn sẽ sử dụng phải được ghi lại dưới dạng biến TensorFlow, vì TFF không sử dụng Python trong thời gian chạy (hãy nhớ mã của bạn phải được viết sao cho nó có thể được triển khai cho các thiết bị di động; xem hướng dẫn thuật toán tùy chỉnh để biết thêm chi tiết bình luận về lý do).
  • Mô hình của bạn nên mô tả dạng dữ liệu mà nó chấp nhận ( input_spec ), vì nói chung, TFF là một môi trường được đánh máy mạnh và muốn xác định chữ ký kiểu cho tất cả các thành phần. Khai báo định dạng của đầu vào mô hình của bạn là một phần thiết yếu của nó.
  • Mặc dù về mặt kỹ thuật không bắt buộc, chúng tôi khuyên bạn nên gói tất cả logic TensorFlow (chuyển tiếp, tính toán số liệu, v.v.) dưới dạng tf.function . tf.function , vì điều này giúp đảm bảo TensorFlow có thể được tuần tự hóa và loại bỏ nhu cầu phụ thuộc điều khiển rõ ràng.

Trên đây là đủ để đánh giá và các thuật toán như Federated SGD. Tuy nhiên, đối với Tính trung bình liên kết, chúng ta cần chỉ định cách mô hình sẽ đào tạo cục bộ trên mỗi lô. Chúng tôi sẽ chỉ định một trình tối ưu hóa cục bộ khi xây dựng thuật toán Trung bình Liên kết.

Mô phỏng đào tạo liên đoàn với mô hình mới

Với tất cả những điều ở trên, phần còn lại của quá trình trông giống như những gì chúng ta đã thấy - chỉ cần thay thế hàm tạo mô hình bằng hàm tạo của lớp mô hình mới của chúng tôi và sử dụng hai phép tính được liên kết trong quy trình lặp lại mà bạn đã tạo để chuyển qua các vòng huấn luyện.

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.1527398), ('accuracy', 0.12469136)]))])
for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.941014), ('accuracy', 0.14218107)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.9052832), ('accuracy', 0.14444445)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.7491086), ('accuracy', 0.17962962)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.5129666), ('accuracy', 0.19526748)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4175923), ('accuracy', 0.23600823)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4273515), ('accuracy', 0.24176955)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.2426176), ('accuracy', 0.2802469)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1567981), ('accuracy', 0.295679)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1092515), ('accuracy', 0.30843621)]))])

Để xem các chỉ số này trong TensorBoard, hãy tham khảo các bước được liệt kê ở trên trong "Hiển thị chỉ số mô hình trong TensorBoard".

Đánh giá

Tất cả các thử nghiệm của chúng tôi cho đến nay chỉ trình bày các chỉ số đào tạo được liên kết - các chỉ số trung bình trên tất cả các lô dữ liệu được đào tạo trên tất cả các khách hàng trong vòng. Điều này dẫn đến những lo ngại bình thường về việc trang bị quá nhiều, đặc biệt là vì chúng tôi đã sử dụng cùng một nhóm khách hàng trên mỗi vòng để đơn giản hóa, nhưng có thêm khái niệm về trang bị quá mức trong các chỉ số đào tạo cụ thể cho thuật toán Trung bình liên kết. Điều này dễ thấy nhất nếu chúng ta tưởng tượng mỗi khách hàng có một lô dữ liệu duy nhất và chúng tôi đào tạo trên lô đó cho nhiều lần lặp lại (kỷ nguyên). Trong trường hợp này, mô hình cục bộ sẽ nhanh chóng phù hợp chính xác với một lô đó và do đó, chỉ số độ chính xác cục bộ mà chúng tôi trung bình sẽ đạt tới 1,0. Do đó, các số liệu đào tạo này có thể được coi là một dấu hiệu cho thấy việc đào tạo đang tiến bộ, nhưng không nhiều hơn.

Để thực hiện đánh giá trên dữ liệu được liên kết, bạn có thể xây dựng một phép tính liên kết khác được thiết kế cho mục đích này, bằng cách sử dụng hàm tff.learning.build_federated_evaluation và chuyển vào hàm tạo mô hình của bạn dưới dạng đối số. Lưu ý rằng không giống như với Trung bình liên kết, nơi chúng tôi đã sử dụng MnistTrainableModel , nó đủ để vượt qua MnistModel . Đánh giá không thực hiện giảm độ dốc và không cần phải xây dựng các trình tối ưu hóa.

Đối với thử nghiệm và nghiên cứu, khi bộ dữ liệu kiểm tra tập trung có sẵn, Học liên kết cho Tạo văn bản trình bày một tùy chọn đánh giá khác: lấy các trọng số được đào tạo từ học liên kết, áp dụng chúng vào mô hình Keras chuẩn, sau đó chỉ cần gọi tf.keras.models.Model.evaluate() trên tập dữ liệu tập trung.

evaluation = tff.learning.build_federated_evaluation(MnistModel)

Bạn có thể kiểm tra chữ ký kiểu trừu tượng của hàm đánh giá như sau.

str(evaluation.type_signature)
'(<server_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <num_examples=float32@SERVER,loss=float32@SERVER,accuracy=float32@SERVER>)'

Không cần quan tâm đến các chi tiết tại thời điểm này, chỉ cần lưu ý rằng nó có dạng chung sau đây, tương tự như tff.templates.IterativeProcess.next nhưng có hai điểm khác biệt quan trọng. Đầu tiên, chúng tôi không trả lại trạng thái máy chủ, vì đánh giá không sửa đổi mô hình hoặc bất kỳ khía cạnh nào khác của trạng thái - bạn có thể coi nó là trạng thái không trạng thái. Thứ hai, đánh giá chỉ cần mô hình và không yêu cầu bất kỳ phần nào khác của trạng thái máy chủ có thể được liên kết với đào tạo, chẳng hạn như các biến trình tối ưu hóa.

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

Hãy gọi đánh giá về trạng thái mới nhất mà chúng tôi đạt được trong quá trình đào tạo. Để trích xuất mô hình được đào tạo mới nhất từ ​​trạng thái máy chủ, bạn chỉ cần truy cập thành viên .model , như sau.

train_metrics = evaluation(state.model, federated_train_data)

Đây là những gì chúng tôi nhận được. Lưu ý rằng các con số trông tốt hơn một chút so với những gì được báo cáo bởi vòng đào tạo cuối cùng ở trên. Theo quy ước, các chỉ số đào tạo được báo cáo bởi quá trình đào tạo lặp đi lặp lại thường phản ánh hiệu suất của mô hình khi bắt đầu vòng đào tạo, do đó, các chỉ số đánh giá sẽ luôn đi trước một bước.

str(train_metrics)
'<num_examples=4860.0,loss=1.7142657041549683,accuracy=0.38683128356933594>'

Bây giờ, hãy biên dịch một mẫu thử nghiệm của dữ liệu liên kết và chạy lại đánh giá trên dữ liệu thử nghiệm. Dữ liệu sẽ đến từ cùng một mẫu người dùng thực, nhưng từ một tập dữ liệu riêng biệt.

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
'<num_examples=580.0,loss=1.861915111541748,accuracy=0.3362068831920624>'

Điều này kết thúc hướng dẫn. Chúng tôi khuyến khích bạn chơi với các thông số (ví dụ: kích thước lô, số lượng người dùng, kỷ nguyên, tỷ lệ học tập, v.v.), để sửa đổi mã ở trên để mô phỏng đào tạo trên các mẫu ngẫu nhiên của người dùng trong mỗi vòng và để khám phá các hướng dẫn khác chúng tôi đã phát triển.