tff의 ClientData 작업.

TensorFlow.org에서 보기 Google Colab에서 실행 GitHub에서 소스 보기 노트북 다운로드

클라이언트(예: 사용자)에 의해 키가 지정된 데이터 세트의 개념은 TFF에서 모델링된 연합 계산에 필수적입니다. TFF는 인터페이스 제공 tff.simulation.datasets.ClientData 이 개념을 통해 추상적으로, 어느 TFF 호스트 (데이터 집합 유래 , 셰익스피어 , emnist , cifar100gldv2는 ) 모든 인터페이스를 구현합니다.

당신이 당신의 자신의 세트와 함께 연합 학습에서 작업하는 경우, TFF는 강하게 당신이 중 하나를 구현하기 위해 권장 ClientData 생성 TFF의 도우미 기능의 인터페이스 또는 사용을 ClientData 예를 들어 디스크에 데이터를 나타냅니다, tff.simulation.datasets.ClientData.from_clients_and_fn .

TFF의 엔드 - 투 - 엔드 대부분의 예제로 시작으로 ClientData 이행, 객체 ClientData TFF로 작성된 기존 코드를 통해 spelunk을 쉽게 만들 것입니다 사용자 정의 데이터 세트와 인터페이스를. 또한, tf.data.Datasets ClientData 구조물의 구조체가 수득 바로 위에 반복 될 수 numpy 그렇게 배열을 ClientData 개체 이동 TFF 전에 파이썬 ML 기반 프레임 워크를 사용할 수있다.

시뮬레이션을 여러 시스템으로 확장하거나 배포하려는 경우 삶을 더 쉽게 만들 수 있는 몇 가지 패턴이 있습니다. 우리는 우리가 사용할 수있는 방법 중 몇 가지를 걸어 아래 ClientData 우리의 작은 규모의 대규모 반복-에 생산 실험-에 배포 경험을하고 TFF를 가능한 한 부드럽게.

ClientData를 TFF로 전달하려면 어떤 패턴을 사용해야 합니까?

우리는 TFF의 두 가지 용도에 대해 설명합니다 ClientData 깊이를; 아래 두 범주 중 하나에 해당한다면 분명히 다른 범주보다 선호할 것입니다. 그렇지 않은 경우 보다 미묘한 선택을 하기 위해 각각의 장단점을 보다 자세히 이해해야 할 수 있습니다.

  • 로컬 시스템에서 가능한 한 빨리 반복하고 싶습니다. TFF의 분산 런타임을 쉽게 활용할 필요가 없습니다.

    • 당신은 전달하려는 tf.data.Datasets 직접 TFF에에.
    • 이것은 당신이 긴급하게 프로그래밍 할 수 tf.data.Dataset 객체, 임의로이를 처리합니다.
    • 아래 옵션보다 더 많은 유연성을 제공합니다. 클라이언트에 논리를 푸시하려면 이 논리를 직렬화할 수 있어야 합니다.
  • TFF의 원격 런타임에서 연합 계산을 실행하고 싶거나 곧 그렇게 할 계획입니다.

    • 이 경우 데이터 세트 구성 및 사전 처리를 클라이언트에 매핑하려고 합니다.
    • 당신의이 결과는 단순히 목록 통과 client_ids 직접 페더 레이 티드 계산에.
    • 클라이언트에 데이터 세트 구성 및 전처리를 푸시하면 직렬화의 병목 현상이 방지되고 수백에서 수천 개의 클라이언트에서 성능이 크게 향상됩니다.

오픈 소스 환경 설정

패키지 가져오기

ClientData 객체 조작

의로드 및 TFF의 EMNIST 탐험에 의해 시작하자 ClientData :

client_data, _ = tff.simulation.datasets.emnist.load_data()
Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:19<00:00, 8831921.67it/s]
2021-10-01 11:17:58.718735: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

첫 번째 데이터 집합을 검사하면에있는 예제의 어떤 종류의 우리를 말할 수 ClientData .

first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

데이터 집합 수율은 참고 collections.OrderedDict 있는 개체 pixelslabel 화소 형상의 텐서 키, [28, 28] . 우리는 모양에서 우리의 입력을 평평하게한다고합니다 [784] . 우리는이 작업을 수행 할 수있는 한 가지 방법은 우리에 미리 처리 기능을 적용하는 것입니다 ClientData 객체입니다.

def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)


preprocessed_client_data = client_data.preprocess(preprocess_dataset)

# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

예를 들어 셔플링과 같이 좀 더 복잡한(그리고 아마도 상태 저장) 전처리를 수행하는 것 외에 추가로 원할 수 있습니다.

def preprocess_and_shuffle(dataset):
  """Applies `preprocess_dataset` above and shuffles the result."""
  preprocessed = preprocess_dataset(dataset)
  return preprocessed.shuffle(buffer_size=5)

preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)

# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

와 인터페이스 tff.Computation

이제 우리가 몇 가지 기본적인 조작을 수행 할 수있는 ClientData 객체, 우리는에 피드 데이터에 대한 준비가 tff.Computation . 우리는 정의 tff.templates.IterativeProcess 구현 연합 평균을 , 그것에게 데이터를 전달하는 다른 방법을 탐구한다.

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
  ])
  return tff.learning.from_keras_model(
      model,
      # Note: input spec is the _batched_ shape, and includes the 
      # label tensor which will be passed to the loss function. This model is
      # therefore configured to accept data _after_ it has been preprocessed.
      input_spec=collections.OrderedDict(
          x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
          y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

trainer = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))

우리는이 작업을 시작하기 전에 IterativeProcess 의 의미에 대한 한 의견 ClientData 순서입니다. ClientData 객체는 일반에 연합 훈련에 사용할 수있는 인구의 전체 나타내는 생산 FL 시스템의 실행 환경에 사용할 수없는 및 시뮬레이션에 고유합니다. ClientData 실제로 완전히 사용자에게 우회 연합 컴퓨팅의 용량을 제공하며, 단순히 통해 평소와 같이 서버 측 모델을 학습 ClientData.create_tf_dataset_from_all_clients .

TFF의 시뮬레이션 환경은 연구원이 외부 루프를 완전히 제어할 수 있도록 합니다. 특히 이것은 클라이언트 가용성, 클라이언트 드롭아웃 등의 고려 사항이 사용자 또는 Python 드라이버 스크립트에서 해결되어야 함을 의미합니다. 하나는 이상 샘플링 분포를 조정하여 예 모델 클라이언트 드롭 아웃을 위해 할 수 ClientData's client_ids 더 많은 데이터와 사용자 (대응 및 지역 계산을 더 이상은 실행)하도록 낮은 확률로 선택 될 것이다.

그러나 실제 연합 시스템에서는 모델 트레이너가 클라이언트를 명시적으로 선택할 수 없습니다. 클라이언트 선택은 연합 계산을 실행하는 시스템에 위임됩니다.

전달 tf.data.Datasets TFF에 직접

우리가 사이의 인터페이스를위한이 하나 개의 옵션 ClientDataIterativeProcess 그 구성이다 tf.data.Datasets 파이썬에서, 그리고 TFF 이러한 데이터 세트를 전달합니다.

우리는 우리의 전처리 사용하는 경우주의하는 것이 ClientData 우리가 얻을 데이터 셋은 위에서 정의 된 우리의 모델에 의해 예상 적절한 유형입니다.

selected_client_ids = preprocessed_and_shuffled.client_ids[:10]

preprocessed_data_for_clients = [
    preprocessed_and_shuffled.create_tf_dataset_for_client(
        selected_client_ids[i]) for i in range(10)
]

state = trainer.initialize()
for _ in range(5):
  t1 = time.time()
  state, metrics = trainer.next(state, preprocessed_data_for_clients)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
loss 2.9005744457244873, round time 4.576513767242432
loss 3.113278388977051, round time 0.49641919136047363
loss 2.7581865787506104, round time 0.4904160499572754
loss 2.87259578704834, round time 0.48976993560791016
loss 3.1202380657196045, round time 0.6724586486816406

우리가이 길을 경우, 그러나, 우리는 사소 multimachine 시뮬레이션으로 이동 할없습니다. 우리가 지방 TensorFlow 런타임에 구성 데이터 세트는 주변 파이썬 환경에서 상태를 캡처하고 그들이 더 이상 사용할 수없는 그들이다 기준 상태로 시도 할 때 직렬화 또는 직렬화 복원에 실패 할 수 있습니다. 이 TensorFlow의에서 헤아릴 수없는 오류의 예를 들어 나타날 수 tensor_util.cc :

Check failed: DT_VARIANT == input.dtype() (21 vs. 20)

클라이언트에 대한 매핑 구성 및 사전 처리

이 문제를 방지하려면, TFF는 사용자가 데이터 세트 인스턴스 및 각 클라이언트에 로컬로 발생하는 무언가로 사전 처리를 고려하고, TFF의 헬퍼를 사용하거나 권장 federated_map 명시 적으로 각 클라이언트에서이 사전 처리 코드를 실행합니다.

개념적으로 이것을 선호하는 이유는 분명합니다. TFF의 로컬 런타임에서 전체 연합 오케스트레이션이 단일 시스템에서 발생하기 때문에 클라이언트는 "우연히" 글로벌 Python 환경에 액세스할 수 있습니다. 이 시점에서 유사한 사고가 TFF의 교차 플랫폼, 항상 직렬화 가능, 기능 철학을 발생시킨다는 점에 주목할 가치가 있습니다.

TFF는 통해 이러한 변화를 간단하게 ClientData's 속성 dataset_computation 하는 tff.Computation 소요 client_id 하고 관련 반환 tf.data.Dataset .

참고 preprocess 단순히이 작동 dataset_computation ; dataset_computation 전처리의 속성 ClientData 우리가 정의 된 전체 전처리 파이프 라인을 포함 :

print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing:
(string -> <label=int32,pixels=float32[28,28]>*)


dataset computation with preprocessing:
(string -> <x=float32[?,784],y=int64[?,1]>*)

우리는 호출 할 수 dataset_computation 하고 파이썬 런타임에 열심 인 데이터 집합을받을 수 있지만 우리는 반복적 인 과정 또는 모두에서 글로벌 열망 런타임에서 이러한 데이터 세트를 구체화 피하기 위해 다른 계산으로 구성 할 때이 방법의 진정한 힘이 발휘된다. TFF는 도우미 기능 제공 tff.simulation.compose_dataset_computation_with_iterative_process 정확히이 작업을 수행하는 데 사용할 수 있습니다.

trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
    preprocessed_and_shuffled.dataset_computation, trainer)

모두이 tff.templates.IterativeProcesses 실행 같은 방법으로 위의 하나; 그러나 전 전처리 클라이언트 데이터 세트를 수용하고, 후자는 클라이언트 ID를 나타내는 문자열을 모두 집합 구축 처리와 체에 전처리 수용 - 사실 state 둘 사이를 통과 할 수있다.

for _ in range(5):
  t1 = time.time()
  state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
loss 2.8417396545410156, round time 1.6707067489624023
loss 2.7670371532440186, round time 0.5207102298736572
loss 2.665048122406006, round time 0.5302855968475342
loss 2.7213189601898193, round time 0.5313887596130371
loss 2.580148935317993, round time 0.5283482074737549

많은 수의 클라이언트로 확장

trainer_accepting_ids 즉시 TFF의 multimachine 런타임에서 사용할 수 있으며,을 피가 구체화 tf.data.Datasets 와 컨트롤러를 (따라서를 직렬화하고 노동자에게 발송).

이것은 특히 많은 수의 클라이언트에서 분산 시뮬레이션의 속도를 크게 높이고 중간 집계를 활성화하여 유사한 직렬화/역직렬화 오버헤드를 방지합니다.

선택적 심층 분석: TFF에서 수동으로 사전 처리 논리 구성

TFF는 처음부터 구성을 위해 설계되었습니다. TFF의 도우미가 방금 수행한 구성 유형은 사용자로서 완전히 제어할 수 있습니다. 우리는 우리가 정의 전처리 계산을 수동으로이 구성 할 수있는 트레이너 자신의 next 아주 간단하게 :

selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)

@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)
def new_next(server_state, selected_clients):
  preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)
  return trainer.next(server_state, preprocessed_data)

manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)

사실, 이것은 우리가 사용한 도우미가 내부에서 수행하는 작업입니다(또한 적절한 유형 검사 및 조작 수행). 우리는 심지어 직렬화에 의해, 약간 다르게 같은 논리를 표명 한 수 preprocess_and_shuffletff.Computation 및 분해 federated_map 않은 전처리 데이터 세트 및 실행이되는 또 다른 구축 한 단계에 preprocess_and_shuffle 각 클라이언트에 있습니다.

보다 수동적인 이 경로는 TFF의 도우미(모듈로 매개변수 이름)와 동일한 유형 서명을 사용하여 계산을 수행하는지 확인할 수 있습니다.

print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,federated_dataset={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,selected_clients={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)