Làm việc với ClientData của tff.

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Khái niệm về tập dữ liệu do khách hàng (ví dụ người dùng) khóa là cần thiết cho tính toán liên hợp như được mô hình hóa trong TFF. TFF cung cấp giao diện tff.simulation.datasets.ClientData để trừu tượng hơn khái niệm này, và tập hợp dữ liệu mà host TFF ( stackoverflow , shakespeare , emnist , cifar100 , và gldv2 ) tất cả các cài đặt giao diện này.

Nếu bạn đang làm việc trên học tập liên với bộ dữ liệu của riêng bạn, TFF mạnh mẽ khuyến khích bạn hoặc thực hiện các ClientData một giao diện hoặc sử dụng các hàm helper của TFF để tạo ra một ClientData đại diện cho dữ liệu của bạn trên đĩa, ví dụ như tff.simulation.datasets.ClientData.from_clients_and_fn .

Như hầu hết các ví dụ end-to-end TFF bắt đầu với ClientData đối tượng, thực hiện ClientData giao diện với dữ liệu tùy chỉnh của bạn sẽ làm cho nó dễ dàng hơn để spelunk qua mã hiện bằng văn bản với TFF. Hơn nữa, tf.data.DatasetsClientData cấu trúc có thể được lặp trên trực tiếp để mang lại cấu trúc của numpy mảng, vì vậy ClientData đối tượng có thể được sử dụng với bất kỳ khuôn khổ ML Python dựa trên trước khi chuyển đến TFF.

Có một số mẫu mà bạn có thể làm cho cuộc sống của mình dễ dàng hơn nếu bạn có ý định mở rộng mô phỏng của mình cho nhiều máy hoặc triển khai chúng. Dưới đây chúng tôi sẽ đi bộ qua một vài trong số những cách chúng tôi có thể sử dụng ClientData và TFF để làm cho quy mô nhỏ lặp-to quy mô lớn thử nghiệm để sản xuất kinh nghiệm triển khai của chúng tôi như mịn càng tốt.

Tôi nên sử dụng mẫu nào để chuyển ClientData vào TFF?

Chúng tôi sẽ thảo luận về hai tập quán của TFF của ClientData sâu; nếu bạn phù hợp với một trong hai loại dưới đây, rõ ràng bạn sẽ thích loại này hơn loại kia. Nếu không, bạn có thể cần hiểu chi tiết hơn về ưu và nhược điểm của từng loại để đưa ra lựa chọn phù hợp hơn.

  • Tôi muốn lặp lại càng nhanh càng tốt trên một máy cục bộ; Tôi không cần phải có thể dễ dàng tận dụng thời gian chạy phân tán của TFF.

    • Bạn muốn vượt qua tf.data.Datasets để TFF trực tiếp.
    • Điều này cho phép bạn để chương trình phải nhất thiết với tf.data.Dataset đối tượng, và xử lý chúng tùy tiện.
    • Nó cung cấp tính linh hoạt hơn tùy chọn bên dưới; đẩy logic đến các máy khách yêu cầu logic này phải được tuần tự hóa.
  • Tôi muốn chạy tính toán liên hợp của mình trong thời gian chạy từ xa của TFF hoặc tôi dự định làm như vậy sớm.

    • Trong trường hợp này, bạn muốn ánh xạ việc xây dựng và xử lý trước tập dữ liệu cho các máy khách.
    • Kết quả trong bạn này đi qua đơn giản là một danh sách các client_ids trực tiếp đến tính liên của bạn.
    • Việc đẩy quá trình xây dựng và xử lý trước tập dữ liệu tới các máy khách sẽ tránh tắc nghẽn trong tuần tự hóa và tăng đáng kể hiệu suất với hàng trăm đến hàng nghìn máy khách.

Thiết lập môi trường nguồn mở

Nhập gói

Thao tác một đối tượng ClientData

Chúng ta hãy bắt đầu bằng cách bốc thăm dò EMNIST TFF của ClientData :

client_data, _ = tff.simulation.datasets.emnist.load_data()
Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:19<00:00, 8831921.67it/s]
2021-10-01 11:17:58.718735: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Kiểm tra việc tập dữ liệu đầu tiên có thể cho chúng tôi biết những gì loại ví dụ là trong ClientData .

first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

Lưu ý rằng sản lượng dữ liệu collections.OrderedDict đối tượng mà có pixelslabel phím, nơi pixel là một tensor với hình dạng [28, 28] . Giả sử chúng ta muốn san bằng đầu vào của chúng tôi ra hình dạng [784] . Một cách có thể chúng ta có thể làm được điều này sẽ được áp dụng một chức năng tiền xử lý để chúng tôi ClientData đối tượng.

def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)


preprocessed_client_data = client_data.preprocess(preprocess_dataset)

# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

Chúng tôi có thể muốn ngoài việc thực hiện một số tiền xử lý phức tạp hơn (và có thể là trạng thái), chẳng hạn như xáo trộn.

def preprocess_and_shuffle(dataset):
  """Applies `preprocess_dataset` above and shuffles the result."""
  preprocessed = preprocess_dataset(dataset)
  return preprocessed.shuffle(buffer_size=5)

preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)

# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

Giao diện với một tff.Computation

Bây giờ chúng ta có thể thực hiện một số thao tác cơ bản với ClientData các đối tượng, chúng tôi đã sẵn sàng để dữ liệu thức ăn chăn nuôi đến một tff.Computation . Chúng ta định nghĩa một tff.templates.IterativeProcess mà thực hiện Federated trung bình , và khám phá các phương pháp khác nhau đi qua nó dữ liệu.

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
  ])
  return tff.learning.from_keras_model(
      model,
      # Note: input spec is the _batched_ shape, and includes the 
      # label tensor which will be passed to the loss function. This model is
      # therefore configured to accept data _after_ it has been preprocessed.
      input_spec=collections.OrderedDict(
          x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
          y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

trainer = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))

Trước khi chúng tôi bắt đầu làm việc với điều này IterativeProcess , một bình luận về ngữ nghĩa của ClientData là theo thứ tự. Một ClientData đối tượng đại diện cho toàn bộ dân số sẵn sàng cho đào tạo liên kết, mà nói chung là không có sẵn cho các môi trường thực thi của một sản xuất hệ thống FL và là đặc trưng cho mô phỏng. ClientData thực sự mang đến cho người dùng khả năng bỏ qua máy tính liên hoàn toàn và chỉ cần đào tạo một mô hình server-side như bình thường thông qua ClientData.create_tf_dataset_from_all_clients .

Môi trường mô phỏng của TFF đặt nhà nghiên cứu kiểm soát hoàn toàn vòng lặp bên ngoài. Đặc biệt, điều này ngụ ý các cân nhắc về tính khả dụng của ứng dụng khách, tình trạng khách hàng bỏ học, v.v., phải được giải quyết bởi người dùng hoặc tập lệnh trình điều khiển Python. Một ví dụ có thể mô hình client bỏ học bằng cách điều chỉnh sự phân bố lấy mẫu trên bạn ClientData's client_ids như vậy mà người dùng với nhiều dữ liệu (và tương ứng còn chạy tính toán địa phương) sẽ được chọn với xác suất thấp hơn.

Tuy nhiên, trong một hệ thống liên hợp thực sự, người huấn luyện mô hình không thể chọn khách hàng một cách rõ ràng; việc lựa chọn khách hàng được ủy quyền cho hệ thống đang thực hiện tính toán liên hợp.

Đi qua tf.data.Datasets trực tiếp đến TFF

Một lựa chọn chúng tôi đã cho interfacing giữa một ClientData và một IterativeProcess là xây dựng tf.data.Datasets bằng Python, và đi qua các tập hợp dữ liệu để TFF.

Chú ý rằng nếu chúng ta sử dụng tiền xử lý của chúng tôi ClientData các tập hợp dữ liệu, chúng tôi mang lại là các loại thích hợp dự kiến theo mô hình của chúng tôi được xác định ở trên.

selected_client_ids = preprocessed_and_shuffled.client_ids[:10]

preprocessed_data_for_clients = [
    preprocessed_and_shuffled.create_tf_dataset_for_client(
        selected_client_ids[i]) for i in range(10)
]

state = trainer.initialize()
for _ in range(5):
  t1 = time.time()
  state, metrics = trainer.next(state, preprocessed_data_for_clients)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
loss 2.9005744457244873, round time 4.576513767242432
loss 3.113278388977051, round time 0.49641919136047363
loss 2.7581865787506104, round time 0.4904160499572754
loss 2.87259578704834, round time 0.48976993560791016
loss 3.1202380657196045, round time 0.6724586486816406

Nếu chúng ta lấy đường này, tuy nhiên, chúng tôi sẽ không thể trivially chuyển sang mô phỏng multimachine. Các bộ dữ liệu chúng ta xây dựng trong thời gian chạy TensorFlow địa phương có thể nắm bắt được trạng thái từ môi trường xung quanh trăn, và thất bại trong serialization hoặc deserialization khi họ cố gắng để nhà nước tham khảo mà không còn có sẵn cho họ. Điều này có thể biểu hiện ví dụ như trong các lỗi khó hiểu từ TensorFlow của tensor_util.cc :

Check failed: DT_VARIANT == input.dtype() (21 vs. 20)

Lập bản đồ xây dựng và xử lý trước qua các khách hàng

Để tránh vấn đề này, TFF khuyến cáo người sử dụng để xem xét dữ liệu instantiation và tiền xử lý như một cái gì đó xảy ra cục bộ trên mỗi khách hàng, và sử dụng người giúp đỡ TFF hay federated_map để chạy một cách rõ ràng mã này tiền xử lý tại mỗi khách hàng.

Về mặt khái niệm, lý do thích điều này rất rõ ràng: trong thời gian chạy cục bộ của TFF, các máy khách chỉ "vô tình" có quyền truy cập vào môi trường Python toàn cầu do thực tế là toàn bộ điều phối liên hợp đang diễn ra trên một máy duy nhất. Điều đáng chú ý là ở điểm này, tư duy tương tự làm nảy sinh triết lý chức năng đa nền tảng, luôn có thể tuần tự hóa, của TFF.

TFF làm như vậy một sự thay đổi đơn giản thông qua ClientData's thuộc tính dataset_computation , một tff.Computation mà phải mất một client_id và trả về liên tf.data.Dataset .

Lưu ý rằng preprocess chỉ đơn giản là làm việc với dataset_computation ; các dataset_computation thuộc tính của preprocessed ClientData kết hợp toàn bộ đường ống dẫn tiền xử lý, chúng tôi chỉ được xác định:

print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing:
(string -> <label=int32,pixels=float32[28,28]>*)


dataset computation with preprocessing:
(string -> <x=float32[?,784],y=int64[?,1]>*)

Chúng ta có thể gọi dataset_computation và nhận một bộ dữ liệu háo hức trong thời gian chạy Python, nhưng sức mạnh thực sự của phương pháp này được thực hiện khi chúng ta soạn với một quá trình lặp này hay cách khác tính toán để tránh vật chất hóa những tập hợp dữ liệu trong thời gian chạy háo hức toàn cầu ở tất cả. TFF cung cấp một hàm helper tff.simulation.compose_dataset_computation_with_iterative_process mà có thể được sử dụng để thực hiện chính xác này.

trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
    preprocessed_and_shuffled.dataset_computation, trainer)

Cả hai này tff.templates.IterativeProcesses và một ở trên chạy theo cùng một cách; nhưng cựu chấp nhận bộ dữ liệu của khách hàng xử lý trước, và sau này chấp nhận chuỗi đại diện cho id khách hàng, xử lý cả hai xây dựng bộ dữ liệu và tiền xử lý trong cơ thể của nó - trong thực tế state có thể được thông qua giữa hai người.

for _ in range(5):
  t1 = time.time()
  state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
loss 2.8417396545410156, round time 1.6707067489624023
loss 2.7670371532440186, round time 0.5207102298736572
loss 2.665048122406006, round time 0.5302855968475342
loss 2.7213189601898193, round time 0.5313887596130371
loss 2.580148935317993, round time 0.5283482074737549

Mở rộng quy mô tới số lượng lớn khách hàng

trainer_accepting_ids ngay lập tức có thể được sử dụng trong thời gian chạy multimachine TFF, và tránh được vật chất hóa tf.data.Datasets và bộ điều khiển (và do đó serializing họ và gửi chúng ra để người lao động).

Điều này làm tăng tốc đáng kể các mô phỏng phân tán, đặc biệt là với một số lượng lớn máy khách và cho phép tổng hợp trung gian để tránh chi phí tuần tự hóa / giải mã hóa tương tự.

Deepdive tùy chọn: soạn logic tiền xử lý theo cách thủ công trong TFF

TFF được thiết kế cho tính tổng hợp từ cơ bản; loại bố cục do người trợ giúp của TFF thực hiện hoàn toàn nằm trong tầm kiểm soát của chúng tôi với tư cách là người dùng. Chúng ta có thể có tay soạn việc tính toán tiền xử lý, chúng tôi chỉ được xác định với các huấn luyện viên của mình next khá đơn giản:

selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)

@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)
def new_next(server_state, selected_clients):
  preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)
  return trainer.next(server_state, preprocessed_data)

manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)

Trên thực tế, đây là cách mà người trợ giúp chúng tôi sử dụng đang thực hiện một cách hiệu quả (cộng với việc thực hiện thao tác và kiểm tra loại phù hợp). Chúng tôi thậm chí có thể đã bày tỏ cùng một logic hơi khác nhau, bởi serializing preprocess_and_shuffle thành một tff.Computation , và phân hủy các federated_map vào một bước mà xây dựng bộ dữ liệu un-preprocessed và khác mà chạy preprocess_and_shuffle tại mỗi khách hàng.

Chúng tôi có thể xác minh rằng đường dẫn thủ công hơn này dẫn đến các phép tính có cùng kiểu chữ ký với trình trợ giúp của TFF (tên tham số modulo):

print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,federated_dataset={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,selected_clients={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)