Esta página foi traduzida pela API Cloud Translation.
Switch to English

Algoritmos federados personalizados, parte 2: implementação da média federada

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub

Este tutorial é a segunda parte de uma série de duas partes que demonstra como implementar tipos personalizados de algoritmos federados em TFF usando Federated Core (FC) , que serve como base para a camada Federated Learning (FL) ( tff.learning ) .

Nós o encorajamos a ler primeiro a primeira parte desta série , que apresenta alguns dos principais conceitos e abstrações de programação usados ​​aqui.

Esta segunda parte da série usa os mecanismos introduzidos na primeira parte para implementar uma versão simples de treinamento federado e algoritmos de avaliação.

Incentivamos você a revisar os tutoriais de classificação de imagem e geração de texto para uma introdução de nível superior e mais suave às APIs Federated Learning da TFF, pois eles o ajudarão a contextualizar os conceitos que descrevemos aqui.

Antes de começarmos

Antes de começar, tente executar o seguinte exemplo "Hello World" para certificar-se de que seu ambiente está configurado corretamente. Se não funcionar, consulte o guia de instalação para obter instruções.


!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio

import nest_asyncio
nest_asyncio.apply()
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# TODO(b/148678573,b/148685415): must use the reference context because it
# supports unbounded references and tff.sequence_* intrinsics.
tff.backends.reference.set_reference_context()
@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
'Hello, World!'

Implementando a média federada

Como em Aprendizado federado para classificação de imagens , usaremos o exemplo MNIST, mas como este é um tutorial de baixo nível, vamos ignorar a API Keras e tff.simulation , escrever o código do modelo bruto e construir um conjunto de dados federados do zero.

Preparando conjuntos de dados federados

Para fins de demonstração, vamos simular um cenário em que temos dados de 10 usuários, e cada um dos usuários contribui com o conhecimento de como reconhecer um dígito diferente. Isso é o mais não- iid possível.

Primeiro, vamos carregar os dados MNIST padrão:

mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step

[(x.dtype, x.shape) for x in mnist_train]
[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]

Os dados vêm como matrizes Numpy, uma com imagens e outra com rótulos de dígitos, ambas com a primeira dimensão passando pelos exemplos individuais. Vamos escrever uma função auxiliar que a formate de forma compatível com a forma como alimentamos sequências federadas em cálculos TFF, ou seja, como uma lista de listas - a lista externa variando sobre os usuários (dígitos), as internas variando sobre lotes de dados em a sequência de cada cliente. Como de costume, estruturaremos cada lote como um par de tensores denominados x e y , cada um com a dimensão de lote principal. Enquanto isso, também achataremos cada imagem em um vetor de 784 elementos e redimensionaremos os pixels dentro do intervalo 0..1 , de modo que não tenhamos que desorganizar a lógica do modelo com conversões de dados.

NUM_EXAMPLES_PER_USER = 1000
BATCH_SIZE = 100


def get_data_for_digit(source, digit):
  output_sequence = []
  all_samples = [i for i, d in enumerate(source[1]) if d == digit]
  for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_USER), BATCH_SIZE):
    batch_samples = all_samples[i:i + BATCH_SIZE]
    output_sequence.append({
        'x':
            np.array([source[0][i].flatten() / 255.0 for i in batch_samples],
                     dtype=np.float32),
        'y':
            np.array([source[1][i] for i in batch_samples], dtype=np.int32)
    })
  return output_sequence


federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]

federated_test_data = [get_data_for_digit(mnist_test, d) for d in range(10)]

Como uma verificação rápida de sanidade, vamos olhar para o tensor Y no último lote de dados fornecido pelo quinto cliente (aquele que corresponde ao dígito 5 ).

federated_train_data[5][-1]['y']
array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)

Para ter certeza, vamos também olhar para a imagem correspondente ao último elemento daquele lote.

from matplotlib import pyplot as plt

plt.imshow(federated_train_data[5][-1]['x'][-1].reshape(28, 28), cmap='gray')
plt.grid(False)
plt.show()

png

Sobre a combinação de TensorFlow e TFF

Neste tutorial, para compactação, decoramos imediatamente as funções que apresentam a lógica do tff.tf_computation com tff.tf_computation . No entanto, para lógicas mais complexas, esse não é o padrão que recomendamos. Depurar o TensorFlow já pode ser um desafio, e depurar o TensorFlow depois de ser totalmente serializado e reimportado necessariamente perde alguns metadados e limita a interatividade, tornando a depuração ainda mais desafiadora.

Portanto, é altamente recomendável escrever lógica TF complexa como funções autônomas do Python (ou seja, sem decoração tff.tf_computation ). Dessa forma, a lógica do TensorFlow pode ser desenvolvida e testada usando as melhores práticas e ferramentas do TF (como o modo ansioso), antes de serializar a computação para TFF (por exemplo, invocando tff.tf_computation com uma função Python como argumento).

Definindo uma função de perda

Agora que temos os dados, vamos definir uma função de perda que podemos usar para treinamento. Primeiro, vamos definir o tipo de entrada como uma TFF chamada tupla. Como o tamanho dos lotes de dados pode variar, definimos a dimensão do lote como None para indicar que o tamanho desta dimensão é desconhecido.

BATCH_SPEC = collections.OrderedDict(
    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    y=tf.TensorSpec(shape=[None], dtype=tf.int32))
BATCH_TYPE = tff.to_type(BATCH_SPEC)

str(BATCH_TYPE)
'<x=float32[?,784],y=int32[?]>'

Você pode estar se perguntando por que não podemos simplesmente definir um tipo Python comum. Lembre-se da discussão na parte 1 , onde explicamos que, embora possamos expressar a lógica dos cálculos TFF usando Python, nos bastidores os cálculos TFF não são Python. O símbolo BATCH_TYPE definido acima representa uma especificação abstrata de tipo TFF. É importante distinguir esse tipo abstrato de TFF dos tipos de representação concretos do Python, por exemplo, contêineres como dict ou collections.namedtuple que podem ser usados ​​para representar o tipo TFF no corpo de uma função Python. Ao contrário do Python, o TFF tem um único construtor de tipo abstrato tff.StructType para tff.StructType tipo tupla, com elementos que podem ser nomeados individualmente ou deixados sem nome. Esse tipo também é usado para modelar parâmetros formais de cálculos, já que os cálculos TFF podem declarar formalmente apenas um parâmetro e um resultado - você verá exemplos disso em breve.

Vamos agora definir o tipo de parâmetros do modelo TFF, novamente como um TFF denominado tupla de pesos e polarização .

MODEL_SPEC = collections.OrderedDict(
    weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),
    bias=tf.TensorSpec(shape=[10], dtype=tf.float32))
MODEL_TYPE = tff.to_type(MODEL_SPEC)

print(MODEL_TYPE)
<weights=float32[784,10],bias=float32[10]>

Com essas definições em vigor, agora podemos definir a perda para um determinado modelo, em um único lote. Observe o uso do decorador @tf.function dentro do decorador @tff.tf_computation . Isso nos permite escrever TF usando Python como a semântica, mesmo dentro de um contexto tf.Graph criado pelo decorador tff.tf_computation .

# NOTE: `forward_pass` is defined separately from `batch_loss` so that it can 
# be later called from within another tf.function. Necessary because a
# @tf.function  decorated method cannot invoke a @tff.tf_computation.

@tf.function
def forward_pass(model, batch):
  predicted_y = tf.nn.softmax(
      tf.matmul(batch['x'], model['weights']) + model['bias'])
  return -tf.reduce_mean(
      tf.reduce_sum(
          tf.one_hot(batch['y'], 10) * tf.math.log(predicted_y), axis=[1]))

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)
def batch_loss(model, batch):
  return forward_pass(model, batch)

Como esperado, computação batch_loss retornos float32 perda dado o modelo e um único lote de dados. Observe como o MODEL_TYPE e o BATCH_TYPE foram agrupados em duas tuplas de parâmetros formais; você pode reconhecer o tipo de batch_loss como (<MODEL_TYPE,BATCH_TYPE> -> float32) .

str(batch_loss.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>> -> float32)'

Para verificar a integridade, vamos construir um modelo inicial preenchido com zeros e calcular a perda sobre o lote de dados que visualizamos acima.

initial_model = collections.OrderedDict(
    weights=np.zeros([784, 10], dtype=np.float32),
    bias=np.zeros([10], dtype=np.float32))

sample_batch = federated_train_data[5][-1]

batch_loss(initial_model, sample_batch)
2.3025854

Observe que alimentamos o cálculo TFF com o modelo inicial definido como um dict , embora o corpo da função Python que o define consuma parâmetros do model['weight'] como model['weight'] e model['bias'] . Os argumentos da chamada para batch_loss não são simplesmente passados ​​para o corpo dessa função.

O que acontece quando invocamos batch_loss ? O corpo Python de batch_loss já foi rastreado e serializado na célula acima onde foi definido. TFF atua como o chamador para batch_loss no momento da definição de computação e como o destino da invocação no momento em que batch_loss é invocado. Em ambas as funções, a TFF serve como ponte entre o sistema de tipo abstrato da TFF e os tipos de representação Python. No momento da invocação, o TFF aceitará a maioria dos tipos de contêineres Python padrão ( dict , list , tuple , collections.namedtuple , etc.) como representações concretas de tuplas TFF abstratas. Além disso, embora conforme observado acima, os cálculos TFF formalmente aceitem apenas um único parâmetro, você pode usar a sintaxe de chamada Python familiar com argumentos posicionais e / ou de palavra-chave no caso em que o tipo do parâmetro é uma tupla - funciona conforme o esperado.

Descida gradiente em um único lote

Agora, vamos definir um cálculo que usa essa função de perda para realizar uma única etapa de descida do gradiente. Observe como, ao definir esta função, usamos batch_loss como um subcomponente. Você pode invocar um cálculo construído com tff.tf_computation dentro do corpo de outro cálculo, embora normalmente isso não seja necessário - conforme observado acima, porque a serialização perde algumas informações de depuração, muitas vezes é preferível para cálculos mais complexos escrever e testar todos os TensorFlow sem o decorador tff.tf_computation .

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
def batch_train(initial_model, batch, learning_rate):
  # Define a group of model variables and set them to `initial_model`. Must
  # be defined outside the @tf.function.
  model_vars = collections.OrderedDict([
      (name, tf.Variable(name=name, initial_value=value))
      for name, value in initial_model.items()
  ])
  optimizer = tf.keras.optimizers.SGD(learning_rate)

  @tf.function
  def _train_on_batch(model_vars, batch):
    # Perform one step of gradient descent using loss from `batch_loss`.
    with tf.GradientTape() as tape:
      loss = forward_pass(model_vars, batch)
    grads = tape.gradient(loss, model_vars)
    optimizer.apply_gradients(
        zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
    return model_vars

  return _train_on_batch(model_vars, batch)
str(batch_train.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>,float32> -> <weights=float32[784,10],bias=float32[10]>)'

Quando você invoca uma função Python decorada com tff.tf_computation dentro do corpo de outra função, a lógica da computação TFF interna é embutida (essencialmente, embutida) na lógica da externa. Conforme observado acima, se você estiver escrevendo ambos os cálculos, provavelmente é preferível fazer a função interna ( batch_loss neste caso) um Python regular ou tf.function vez de um tff.tf_computation . No entanto, aqui ilustramos que chamar um tff.tf_computation dentro de outro funciona basicamente como esperado. Isso pode ser necessário se, por exemplo, você não tiver o código Python definindo batch_loss , mas apenas sua representação TFF serializada.

Agora, vamos aplicar esta função algumas vezes ao modelo inicial para ver se a perda diminui.

model = initial_model
losses = []
for _ in range(5):
  model = batch_train(model, sample_batch, 0.1)
  losses.append(batch_loss(model, sample_batch))
losses
[0.19690022, 0.13176313, 0.10113226, 0.082738124, 0.0703014]

Gradiente descendente em uma sequência de dados locais

Agora, como batch_train parece funcionar, vamos escrever uma função de treinamento semelhante local_train que consome toda a sequência de todos os lotes de um usuário, em vez de apenas um único lote. O novo cálculo agora precisará consumir tff.SequenceType(BATCH_TYPE) vez de BATCH_TYPE .

LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)

@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)
def local_train(initial_model, learning_rate, all_batches):

  # Mapping function to apply to each batch.
  @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
  def batch_fn(model, batch):
    return batch_train(model, batch, learning_rate)

  return tff.sequence_reduce(all_batches, initial_model, batch_fn)
str(local_train.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,float32,<x=float32[?,784],y=int32[?]>*> -> <weights=float32[784,10],bias=float32[10]>)'

Existem alguns detalhes enterrados nesta pequena seção de código, vamos examiná-los um por um.

Em primeiro lugar, embora pudéssemos ter implementado essa lógica inteiramente no TensorFlow, contando com tf.data.Dataset.reduce para processar a sequência de maneira semelhante a como fizemos anteriormente, optamos desta vez para expressar a lógica na linguagem de cola , como um tff.federated_computation . Usamos o operador federado tff.sequence_reduce para realizar a redução.

O operador tff.sequence_reduce é usado de maneira semelhante a tf.data.Dataset.reduce . Você pode pensar nele como essencialmente o mesmo que tf.data.Dataset.reduce , mas para uso em cálculos federados, que, como você deve se lembrar, não podem conter código do TensorFlow. É um operador de modelo com um parâmetro formal 3-tupla que consiste em uma sequência de elementos do tipo T , o estado inicial da redução (vamos nos referir a ele abstratamente como zero ) de algum tipo U e o operador de redução de tipo (<U,T> -> U) que altera o estado da redução processando um único elemento. O resultado é o estado final da redução, após processar todos os elementos em uma ordem sequencial. Em nosso exemplo, o estado da redução é o modelo treinado em um prefixo dos dados e os elementos são lotes de dados.

Em segundo lugar, observe que usamos novamente um cálculo ( batch_train ) como um componente dentro de outro ( local_train ), mas não diretamente. Não podemos usá-lo como um operador de redução porque leva um parâmetro adicional - a taxa de aprendizagem. Para resolver isso, definimos uma computação federada embutida batch_fn que se liga ao parâmetro learning_rate do local_train em seu corpo. É permitido a um cálculo filho definido desta forma capturar um parâmetro formal de seu pai, desde que o cálculo filho não seja invocado fora do corpo de seu pai. Você pode pensar neste padrão como um equivalente de functools.partial em Python.

A implicação prática de capturar learning_rate dessa forma é, obviamente, que o mesmo valor de taxa de aprendizagem é usado em todos os lotes.

Agora, vamos tentar a função de treinamento local recém-definida em toda a sequência de dados do mesmo usuário que contribuiu com o lote de amostra (dígito 5 ).

locally_trained_model = local_train(initial_model, 0.1, federated_train_data[5])

Funcionou? Para responder a esta pergunta, precisamos implementar avaliação.

Avaliação local

Esta é uma maneira de implementar a avaliação local somando as perdas em todos os lotes de dados (poderíamos ter calculado a média da mesma forma; deixaremos isso como um exercício para o leitor).

@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def local_eval(model, all_batches):
  # TODO(b/120157713): Replace with `tff.sequence_average()` once implemented.
  return tff.sequence_sum(
      tff.sequence_map(
          tff.federated_computation(lambda b: batch_loss(model, b), BATCH_TYPE),
          all_batches))
str(local_eval.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>*> -> float32)'

Novamente, existem alguns novos elementos ilustrados por este código, vamos examiná-los um por um.

Primeiro, usamos dois novos operadores federados para processar sequências: tff.sequence_map que recebe uma função de mapeamento T->U e uma sequência de T , e emite uma sequência de U obtida aplicando a função de mapeamento tff.sequence_sum , e tff.sequence_sum que apenas adiciona todos os elementos. Aqui, mapeamos cada lote de dados para um valor de perda e, em seguida, adicionamos os valores de perda resultantes para calcular a perda total.

Observe que poderíamos ter usado tff.sequence_reduce novamente, mas essa não seria a melhor escolha - o processo de redução é, por definição, sequencial, enquanto o mapeamento e a soma podem ser calculados em paralelo. Quando tiver uma escolha, é melhor ficar com os operadores que não restringem as opções de implementação, de modo que quando nosso cálculo TFF for compilado no futuro para ser implementado em um ambiente específico, seja possível aproveitar todas as oportunidades potenciais para um processo mais rápido execução mais escalonável e eficiente em termos de recursos.

Em segundo lugar, observe que, assim como em local_train , a função de componente de que precisamos ( batch_loss ) leva mais parâmetros do que o que o operador federado ( tff.sequence_map ) espera, portanto, definimos novamente um parcial, desta vez em linha envolvendo diretamente um lambda como um tff.federated_computation . Usar wrappers inline com uma função como argumento é a maneira recomendada de usar tff.tf_computation para incorporar a lógica do TensorFlow no TFF.

Agora, vamos ver se nosso treinamento funcionou.

print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[5]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[5]))
initial_model loss = 23.025854
locally_trained_model loss = 0.4348469

Na verdade, a perda diminuiu. Mas o que acontece se avaliarmos nos dados de outro usuário?

print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[0]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[0]))
initial_model loss = 23.025854
locally_trained_model loss = 74.50075

Como esperado, as coisas pioraram. O modelo foi treinado para reconhecer 5 e nunca viu um 0 . Isso traz a questão - como o treinamento local impactou a qualidade do modelo do ponto de vista global?

Avaliação federada

Este é o ponto em nossa jornada em que finalmente voltamos aos tipos federados e às computações federadas - o tópico com o qual começamos. Aqui está um par de definições de tipos de TFF para o modelo que se origina no servidor e os dados que permanecem nos clientes.

SERVER_MODEL_TYPE = tff.FederatedType(MODEL_TYPE, tff.SERVER)
CLIENT_DATA_TYPE = tff.FederatedType(LOCAL_DATA_TYPE, tff.CLIENTS)

Com todas as definições introduzidas até agora, expressar a avaliação federada no TFF é uma linha única - distribuímos o modelo aos clientes, deixamos que cada cliente invoque a avaliação local em sua porção local de dados e, em seguida, calculamos a perda. Aqui está uma maneira de escrever isso.

@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)
def federated_eval(model, data):
  return tff.federated_mean(
      tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))

Já vimos exemplos de tff.federated_mean e tff.federated_map em cenários mais simples e, no nível intuitivo, eles funcionam como esperado, mas há mais nesta seção de código do que aparenta, então vamos examiná-lo com cuidado.

Primeiro, vamos dividir o modo de permitir que cada cliente invoque a avaliação local em sua parte local da parte de dados . Como você deve se lembrar das seções anteriores, local_eval tem uma assinatura de tipo do formulário (<MODEL_TYPE, LOCAL_DATA_TYPE> -> float32) .

O operador federado tff.federated_map é um modelo que aceita como parâmetro uma 2 tupla que consiste na função de mapeamento de algum tipo T->U e um valor federado do tipo {T}@CLIENTS (ou seja, com constituintes membros do mesmo tipo que o parâmetro da função de mapeamento) e retorna um resultado do tipo {U}@CLIENTS .

Como estamos alimentando local_eval como uma função de mapeamento para aplicar por cliente, o segundo argumento deve ser de um tipo federado {<MODEL_TYPE, LOCAL_DATA_TYPE>}@CLIENTS , ou seja, na nomenclatura das seções anteriores, ele deve ser uma tupla federada. Cada cliente deve conter um conjunto completo de argumentos para local_eval como membro constituinte. Em vez disso, estamos alimentando-o com uma list Python de 2 elementos. O que está acontecendo aqui?

Na verdade, este é um exemplo de conversão de tipo implícito em TFF, semelhante a conversão de tipo implícito que você pode ter encontrado em outro lugar, por exemplo, quando você alimenta um int para uma função que aceita um float . A conversão implícita raramente é usada neste ponto, mas planejamos torná-la mais difundida na TFF como uma forma de minimizar o clichê.

A conversão implícita aplicada neste caso é a equivalência entre tuplas federadas da forma {<X,Y>}@Z e tuplas de valores federados <{X}@Z,{Y}@Z> . Embora formalmente esses dois sejam assinaturas de tipo diferentes, olhando para isso da perspectiva dos programadores, cada dispositivo em Z contém duas unidades de dados X e Y O que acontece aqui não é diferente do zip em Python e, de fato, oferecemos um operador tff.federated_zip que permite realizar tais conversões explicitamente. Quando o tff.federated_map encontra uma tupla como um segundo argumento, ele simplesmente invoca tff.federated_zip para você.

Diante do exposto, agora você deve ser capaz de reconhecer a expressão tff.federated_broadcast(model) como representando um valor do tipo TFF {MODEL_TYPE}@CLIENTS , e data como um valor do tipo TFF {LOCAL_DATA_TYPE}@CLIENTS (ou simplesmente CLIENT_DATA_TYPE ) , os dois sendo filtrados juntos por meio de um tff.federated_zip implícito para formar o segundo argumento para tff.federated_map .

O operador tff.federated_broadcast , como você esperaria, simplesmente transfere os dados do servidor para os clientes.

Agora, vamos ver como nosso treinamento local afetou a perda média no sistema.

print('initial_model loss =', federated_eval(initial_model,
                                             federated_train_data))
print('locally_trained_model loss =',
      federated_eval(locally_trained_model, federated_train_data))
initial_model loss = 23.025852
locally_trained_model loss = 54.432625

Na verdade, como esperado, a perda aumentou. Para melhorar o modelo para todos os usuários, precisaremos treinar nos dados de todos.

Treino federado

A maneira mais simples de implementar o treinamento federado é treinar localmente e, em seguida, calcular a média dos modelos. Ele usa os mesmos blocos de construção e padrões que já discutimos, como você pode ver abaixo.

SERVER_FLOAT_TYPE = tff.FederatedType(tf.float32, tff.SERVER)


@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,
                           CLIENT_DATA_TYPE)
def federated_train(model, learning_rate, data):
  return tff.federated_mean(
      tff.federated_map(local_train, [
          tff.federated_broadcast(model),
          tff.federated_broadcast(learning_rate), data
      ]))

Observe que na implementação completa de Federated Averaging fornecida por tff.learning , em vez de fazer a média dos modelos, preferimos fazer a média dos deltas do modelo, por uma série de razões, por exemplo, a capacidade de cortar as normas de atualização, para compressão, etc. .

Vamos ver se o treinamento funciona executando algumas rodadas de treinamento e comparando a perda média antes e depois.

model = initial_model
learning_rate = 0.1
for round_num in range(5):
  model = federated_train(model, learning_rate, federated_train_data)
  learning_rate = learning_rate * 0.9
  loss = federated_eval(model, federated_train_data)
  print('round {}, loss={}'.format(round_num, loss))
round 0, loss=21.60552406311035
round 1, loss=20.365678787231445
round 2, loss=19.27480125427246
round 3, loss=18.31110954284668
round 4, loss=17.45725440979004

Para completar, vamos agora também executar os dados de teste para confirmar se nosso modelo generaliza bem.

print('initial_model test loss =',
      federated_eval(initial_model, federated_test_data))
print('trained_model test loss =', federated_eval(model, federated_test_data))
initial_model test loss = 22.795593
trained_model test loss = 17.278767

Isso conclui nosso tutorial.

É claro que nosso exemplo simplificado não reflete uma série de coisas que você precisa fazer em um cenário mais realista - por exemplo, não calculamos métricas além da perda. Incentivamos você a estudar a implementação da média federada em tff.learning como um exemplo mais completo e como uma forma de demonstrar algumas das práticas de codificação que gostaríamos de encorajar.