Trabalhando com ClientData de tff.

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

A noção de um conjunto de dados codificado por clientes (por exemplo, usuários) é essencial para a computação federada, conforme modelado no TFF. TFF fornece a interface tff.simulation.datasets.ClientData para abstrato sobre este conceito, e os conjuntos de dados que abriga TFF ( stackoverflow , shakespeare , emnist , cifar100 e gldv2 ) tudo implementar essa interface.

Se você estiver trabalhando na aprendizagem federado com seu próprio conjunto de dados, TFF encoraja fortemente que você implementar tanto o ClientData de interface ou usar uma das funções auxiliares de TFF para gerar um ClientData que representa seus dados no disco, por exemplo, tff.simulation.datasets.ClientData.from_clients_and_fn .

Como a maioria de exemplos end-to-end da TFF começar com ClientData objetos, implementando o ClientData de interface com o seu conjunto de dados personalizado irá torná-lo mais fácil de spelunk através de código existente escrito com TFF. Além disso, os tf.data.Datasets que ClientData construções podem ser iterado para se obter directamente as estruturas de numpy matrizes, de modo que ClientData objectos podem ser utilizados com qualquer quadro ML baseado em Python antes de se mudar para TFF.

Existem vários padrões com os quais você pode tornar sua vida mais fácil se você pretende dimensionar suas simulações para muitas máquinas ou implantá-las. Abaixo vamos percorrer algumas das maneiras que podemos usar ClientData e TFF para tornar a nossa pequena escala iteração-to-larga escala experimentação-to produção experiência de implantação o mais suave possível.

Qual padrão devo usar para passar ClientData para TFF?

Vamos discutir dois usos da TFF ClientData em profundidade; se você se encaixa em uma das duas categorias abaixo, você claramente preferirá uma em vez da outra. Caso contrário, você pode precisar de uma compreensão mais detalhada dos prós e contras de cada um para fazer uma escolha com mais nuances.

  • Quero iterar o mais rápido possível em uma máquina local; Não preciso tirar proveito facilmente do tempo de execução distribuído da TFF.

    • Você quer passar tf.data.Datasets para TFF diretamente.
    • Isso permite que você programar imperativamente com tf.data.Dataset objetos, e processá-los arbitrariamente.
    • Ele oferece mais flexibilidade do que a opção abaixo; enviar a lógica aos clientes requer que essa lógica seja serializável.
  • Desejo executar minha computação federada no tempo de execução remoto da TFF ou pretendo fazê-lo em breve.

    • Nesse caso, você deseja mapear a construção e o pré-processamento do conjunto de dados para os clientes.
    • Isso resulta em você passando simplesmente uma lista de client_ids diretamente para o seu cálculo federado.
    • Empurrar a construção e o pré-processamento do conjunto de dados para os clientes evita gargalos na serialização e aumenta significativamente o desempenho com centenas de milhares de clientes.

Configurar ambiente de código aberto

Pacotes de importação

Manipulando um objeto ClientData

Vamos começar por carga e explorar da TFF EMNIST ClientData :

client_data, _ = tff.simulation.datasets.emnist.load_data()
Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:19<00:00, 8831921.67it/s]
2021-10-01 11:17:58.718735: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Inspecionando o primeiro conjunto de dados pode nos dizer que tipo de exemplos estão no ClientData .

first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

Note-se que os rendimentos do conjunto de dados collections.OrderedDict objetos que têm pixels e label chaves, onde pixels é um tensor com forma [28, 28] . Suponha que queremos achatar nossos entradas para forma [784] . Uma maneira possível nós podemos fazer isso seria aplicar uma função pré-processamento para o nosso ClientData objeto.

def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)


preprocessed_client_data = client_data.preprocess(preprocess_dataset)

# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

Além disso, podemos desejar realizar um pré-processamento mais complexo (e possivelmente com estado), por exemplo, embaralhamento.

def preprocess_and_shuffle(dataset):
  """Applies `preprocess_dataset` above and shuffles the result."""
  preprocessed = preprocess_dataset(dataset)
  return preprocessed.shuffle(buffer_size=5)

preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)

# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

Interface com um tff.Computation

Agora que podemos executar algumas manipulações básicas com ClientData objetos, estamos prontos para dados de alimentação para um tff.Computation . Nós definimos um tff.templates.IterativeProcess que implementa Federated Média , e explorar diferentes métodos de passá-lo dados.

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
  ])
  return tff.learning.from_keras_model(
      model,
      # Note: input spec is the _batched_ shape, and includes the 
      # label tensor which will be passed to the loss function. This model is
      # therefore configured to accept data _after_ it has been preprocessed.
      input_spec=collections.OrderedDict(
          x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
          y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

trainer = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))

Antes de começar a trabalhar com este IterativeProcess , um comentário sobre a semântica de ClientData está em ordem. A ClientData objeto representa a totalidade da população disponível para o treinamento federado, que em geral é não disponível para o ambiente de execução de um sistema de FL produção e é específico para simulação. ClientData fato dá ao usuário a capacidade de desvio de computação federada totalmente e simplesmente treinar um modelo do lado do servidor, como de costume via ClientData.create_tf_dataset_from_all_clients .

O ambiente de simulação da TFF coloca o pesquisador em controle total do loop externo. Em particular, isso implica que as considerações sobre a disponibilidade do cliente, o abandono do cliente, etc., devem ser tratadas pelo usuário ou pelo script do driver Python. Poderíamos, por exemplo, modelo de desistência dos clientes, ajustando a distribuição de amostragem sobre seus ClientData's client_ids tais que os usuários com mais dados (e, correspondentemente, a longo executar cálculos locais) seriam selecionados com menor probabilidade.

Em um sistema federado real, entretanto, os clientes não podem ser selecionados explicitamente pelo treinador do modelo; a seleção de clientes é delegada ao sistema que está executando a computação federada.

Passando tf.data.Datasets diretamente para TFF

Uma opção que temos para fazer a interface entre um ClientData e um IterativeProcess é a de construir tf.data.Datasets em Python, e passando esses conjuntos de dados para TFF.

Observe que, se usarmos os nossos pré-processadas ClientData os conjuntos de dados que produzem são do tipo apropriado esperado pelo nosso modelo definido acima.

selected_client_ids = preprocessed_and_shuffled.client_ids[:10]

preprocessed_data_for_clients = [
    preprocessed_and_shuffled.create_tf_dataset_for_client(
        selected_client_ids[i]) for i in range(10)
]

state = trainer.initialize()
for _ in range(5):
  t1 = time.time()
  state, metrics = trainer.next(state, preprocessed_data_for_clients)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
loss 2.9005744457244873, round time 4.576513767242432
loss 3.113278388977051, round time 0.49641919136047363
loss 2.7581865787506104, round time 0.4904160499572754
loss 2.87259578704834, round time 0.48976993560791016
loss 3.1202380657196045, round time 0.6724586486816406

Se tomarmos este caminho, no entanto, não será capaz de mover-se trivialmente a simulação multimáquina. Os conjuntos de dados que construímos no tempo de execução TensorFlow local pode capturar o estado do ambiente python circundante, e falhar em serialização ou desserialização quando tentam estado de referência que não está mais disponível para eles é. Isto pode se manifestar por exemplo, o erro inescrutáveis de TensorFlow tensor_util.cc :

Check failed: DT_VARIANT == input.dtype() (21 vs. 20)

Mapeamento de construção e pré-processamento sobre os clientes

Para evitar esse problema, TFF recomenda aos seus utilizadores a considerar dataset instanciação e pré-processamento como algo que acontece localmente em cada cliente, e usar ajudantes de TFF ou federated_map para executar explicitamente este código pré-processamento em cada cliente.

Conceitualmente, a razão para preferir isso é clara: no tempo de execução local da TFF, os clientes apenas "acidentalmente" têm acesso ao ambiente Python global devido ao fato de que toda a orquestração federada está acontecendo em uma única máquina. Vale a pena observar neste ponto que pensamento semelhante dá origem à filosofia funcional de plataforma cruzada, sempre serializável e da TFF.

TFF faz essa simples mudança através ClientData's atributo dataset_computation , um tff.Computation que leva um client_id e retorna o associado tf.data.Dataset .

Note-se que preprocess simplesmente trabalha com dataset_computation ; o dataset_computation atributo do pré-processados ClientData incorpora toda a pipeline de pré-processamento que acabamos de definir:

print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing:
(string -> <label=int32,pixels=float32[28,28]>*)


dataset computation with preprocessing:
(string -> <x=float32[?,784],y=int64[?,1]>*)

Poderíamos invocar dataset_computation e receber um conjunto de dados ansioso no tempo de execução Python, mas o verdadeiro poder desta abordagem é exercido quando compor com um processo iterativo ou de outra computação para evitar a materialização destes conjuntos de dados no tempo de execução ansioso mundial em tudo. TFF fornece uma função auxiliar tff.simulation.compose_dataset_computation_with_iterative_process que pode ser usado para fazer exatamente isso.

trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
    preprocessed_and_shuffled.dataset_computation, trainer)

Tanto este tff.templates.IterativeProcesses e aquele acima executado da mesma maneira; mas primeiro aceita conjuntos de dados de clientes pré-processadas, eo último aceita strings representando IDs de cliente, manipulação tanto a construção do conjunto de dados e pré-processamento em seu corpo - na verdade state podem ser passados entre os dois.

for _ in range(5):
  t1 = time.time()
  state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
loss 2.8417396545410156, round time 1.6707067489624023
loss 2.7670371532440186, round time 0.5207102298736572
loss 2.665048122406006, round time 0.5302855968475342
loss 2.7213189601898193, round time 0.5313887596130371
loss 2.580148935317993, round time 0.5283482074737549

Escalonando para um grande número de clientes

trainer_accepting_ids pode imediatamente ser usado em tempo de execução multimáquina da TFF e evita materializando tf.data.Datasets e o controlador (e, portanto, a serialização-los e enviá-los para os trabalhadores).

Isso acelera significativamente as simulações distribuídas, especialmente com um grande número de clientes, e permite a agregação intermediária para evitar sobrecarga de serialização / desserialização semelhante.

Deepdive opcional: compondo manualmente a lógica de pré-processamento no TFF

TFF é projetado para composicionalidade desde o início; o tipo de composição executado pelo ajudante da TFF está totalmente sob nosso controle como usuários. Poderíamos ter manualmente compor o cálculo de pré-processamento que acabamos de definir com o treinador próprio next simplesmente:

selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)

@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)
def new_next(server_state, selected_clients):
  preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)
  return trainer.next(server_state, preprocessed_data)

manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)

Na verdade, isso é efetivamente o que o auxiliar que usamos está fazendo nos bastidores (além de executar a verificação de tipo e manipulação apropriadas). Poderíamos até ter expressado a mesma lógica ligeiramente diferente, por serialização preprocess_and_shuffle em um tff.Computation , e decompondo o federated_map em um passo que constrói conjuntos de dados pré-processados-un e outro que corre preprocess_and_shuffle em cada cliente.

Podemos verificar que este caminho mais manual resulta em cálculos com a mesma assinatura de tipo que o auxiliar de TFF (nomes de parâmetros de módulo):

print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,federated_dataset={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,selected_clients={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)