Ajuda a proteger a Grande Barreira de Corais com TensorFlow em Kaggle Junte Desafio

Algoritmos federados personalizados, parte 2: implementação da média federada

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

Este tutorial é a segunda parte de uma série de duas partes que demonstra como implementar tipos de mercadorias de algoritmos federados em TFF utilizando o Federados núcleo (FC) , o qual serve como uma base para o Federados Aprendizagem (FL) camada ( tff.learning ) .

Nós encorajamos você a primeira ler a primeira parte desta série , que introduzem alguns dos conceitos-chave e abstrações de programação usada aqui.

Esta segunda parte da série usa os mecanismos introduzidos na primeira parte para implementar uma versão simples de treinamento federado e algoritmos de avaliação.

Nós encorajamos você a rever a classificação de imagens e geração de texto tutoriais para um nível mais alto e introdução mais suave para APIs de Aprendizagem Federados da TFF, como eles vão ajudar você a colocar os conceitos que descrevem aqui no contexto.

Antes de começarmos

Antes de começar, tente executar o seguinte exemplo "Hello World" para certificar-se de que seu ambiente está configurado corretamente. Se isso não funcionar, consulte a instalação guia para obter instruções.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# TODO(b/148678573,b/148685415): must use the reference context because it
# supports unbounded references and tff.sequence_* intrinsics.
tff.backends.reference.set_reference_context()
@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
'Hello, World!'

Implementando a média federada

Como em Federated Aprendizagem para Imagem classificação , vamos usar o exemplo MNIST, mas uma vez que este pretende ser um tutorial de baixo nível, nós estamos indo para ignorar a API Keras e tff.simulation , escrever código modelo de cru, e construir um conjunto de dados federados do zero.

Preparando conjuntos de dados federados

Para fins de demonstração, vamos simular um cenário em que temos dados de 10 usuários, e cada um dos usuários contribui com o conhecimento de como reconhecer um dígito diferente. Isto é sobre como não iid quanto ele ganha.

Primeiro, vamos carregar os dados MNIST padrão:

mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
[(x.dtype, x.shape) for x in mnist_train]
[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]

Os dados vêm como matrizes Numpy, uma com imagens e outra com rótulos de dígitos, ambas com a primeira dimensão passando pelos exemplos individuais. Vamos escrever uma função auxiliar que a formate de forma compatível com a forma como alimentamos sequências federadas em cálculos TFF, ou seja, como uma lista de listas - a lista externa variando sobre os usuários (dígitos), as internas variando sobre lotes de dados em a sequência de cada cliente. Como é habitual, vamos estruturar cada lote como um par de tensores nomeados x e y , cada uma com a dimensão levando lote. Enquanto a ele, nós também vamos achatar cada imagem em um vetor 784-elemento e redimensionar os pixels em-lo para o 0..1 gama, de modo que não temos para atravancar a lógica do modelo com as conversões de dados.

NUM_EXAMPLES_PER_USER = 1000
BATCH_SIZE = 100


def get_data_for_digit(source, digit):
  output_sequence = []
  all_samples = [i for i, d in enumerate(source[1]) if d == digit]
  for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_USER), BATCH_SIZE):
    batch_samples = all_samples[i:i + BATCH_SIZE]
    output_sequence.append({
        'x':
            np.array([source[0][i].flatten() / 255.0 for i in batch_samples],
                     dtype=np.float32),
        'y':
            np.array([source[1][i] for i in batch_samples], dtype=np.int32)
    })
  return output_sequence


federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]

federated_test_data = [get_data_for_digit(mnist_test, d) for d in range(10)]

Como uma verificação de sanidade rápida, Vamos olhar o Y tensor no último lote de dados contribuídos pelo quinto cliente (aquele correspondente ao dígito 5 ).

federated_train_data[5][-1]['y']
array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)

Só para ter certeza, vamos olhar também a imagem correspondente ao último elemento daquele lote.

from matplotlib import pyplot as plt

plt.imshow(federated_train_data[5][-1]['x'][-1].reshape(28, 28), cmap='gray')
plt.grid(False)
plt.show()

png

Sobre a combinação de TensorFlow e TFF

Neste tutorial, para compacidade nós imediatamente decorar funções que introduzem a lógica TensorFlow com tff.tf_computation . No entanto, para lógicas mais complexas, esse não é o padrão que recomendamos. Depurar o TensorFlow já pode ser um desafio, e depurar o TensorFlow depois de ser totalmente serializado e reimportado necessariamente perde alguns metadados e limita a interatividade, tornando a depuração ainda mais desafiadora.

Portanto, recomendamos escrever lógica TF complexo como funções Python stand-alone (isto é, sem tff.tf_computation decoração). Desta forma, a lógica TensorFlow podem ser desenvolvidas e testadas utilizando as melhores práticas e ferramentas TF (como modo ansioso), antes de serialização o cálculo para TFF (por exemplo, invocando tff.tf_computation com uma função de Python como o argumento).

Definindo uma função de perda

Agora que temos os dados, vamos definir uma função de perda que podemos usar para treinamento. Primeiro, vamos definir o tipo de entrada como uma TFF chamada tupla. Como o tamanho dos lotes de dados pode variar, vamos definir a dimensão de lote para None para indicar que o tamanho desta dimensão é desconhecida.

BATCH_SPEC = collections.OrderedDict(
    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    y=tf.TensorSpec(shape=[None], dtype=tf.int32))
BATCH_TYPE = tff.to_type(BATCH_SPEC)

str(BATCH_TYPE)
'<x=float32[?,784],y=int32[?]>'

Você pode estar se perguntando por que não podemos simplesmente definir um tipo Python comum. Lembre-se da discussão na parte 1 , onde explicou que, enquanto podemos expressar a lógica de cálculos TFF utilizando Python, nos cálculos capô TFF não são Python. O símbolo BATCH_TYPE definido acima representa uma especificação resumo tipo TFF. É importante distinguir este tipo TFF resumo de betão tipos de representação pitão, por exemplo, recipientes, tais como dict ou collections.namedtuple que podem ser utilizados para representar o tipo TFF no corpo de uma função Python. Ao contrário do Python, TFF tem um único tipo construtor resumo tff.StructType para tupla-como recipientes, com elementos que podem ser individualmente nomeados ou deixadas sem nome. Este tipo também é usado para modelar parâmetros formais de cálculos, já que os cálculos TFF podem declarar formalmente apenas um parâmetro e um resultado - você verá exemplos disso em breve.

Vamos agora definir o tipo TFF dos parâmetros do modelo, novamente como um TFF chamado tupla de pesos e preconceito.

MODEL_SPEC = collections.OrderedDict(
    weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),
    bias=tf.TensorSpec(shape=[10], dtype=tf.float32))
MODEL_TYPE = tff.to_type(MODEL_SPEC)

print(MODEL_TYPE)
<weights=float32[784,10],bias=float32[10]>

Com essas definições em vigor, agora podemos definir a perda para um determinado modelo, em um único lote. Observe o uso de @tf.function decorador interior do @tff.tf_computation decorador. Isso nos permite escrever TF usando Python como semântica, embora estivesse dentro de um tf.Graph contexto criado pela tff.tf_computation decorador.

# NOTE: `forward_pass` is defined separately from `batch_loss` so that it can 
# be later called from within another tf.function. Necessary because a
# @tf.function  decorated method cannot invoke a @tff.tf_computation.

@tf.function
def forward_pass(model, batch):
  predicted_y = tf.nn.softmax(
      tf.matmul(batch['x'], model['weights']) + model['bias'])
  return -tf.reduce_mean(
      tf.reduce_sum(
          tf.one_hot(batch['y'], 10) * tf.math.log(predicted_y), axis=[1]))

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)
def batch_loss(model, batch):
  return forward_pass(model, batch)

Como esperado, computação batch_loss retornos float32 perda dado o modelo e um único lote de dados. Note como o MODEL_TYPE e BATCH_TYPE foram agrupados em um 2-tupla de parâmetros formais; você pode reconhecer o tipo de batch_loss como (<MODEL_TYPE,BATCH_TYPE> -> float32) .

str(batch_loss.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>> -> float32)'

Para verificar a integridade, vamos construir um modelo inicial preenchido com zeros e calcular a perda sobre o lote de dados que visualizamos acima.

initial_model = collections.OrderedDict(
    weights=np.zeros([784, 10], dtype=np.float32),
    bias=np.zeros([10], dtype=np.float32))

sample_batch = federated_train_data[5][-1]

batch_loss(initial_model, sample_batch)
2.3025854

Note-se que nós alimentar o cálculo TFF com o modelo inicial definido como um dict , mesmo que o corpo da função Python que define que consome parâmetros do modelo como model['weight'] e model['bias'] . Os argumentos da chamada para batch_loss não são simplesmente passado para o corpo dessa função.

O que acontece quando invocamos batch_loss ? O corpo do Python batch_loss já foi rastreada e serializados na célula acima onde foi definido. TFF atua como o chamador para batch_loss no momento de definição computação, e como o destino de invocação no momento batch_loss é invocado. Em ambas as funções, a TFF serve como ponte entre o sistema de tipo abstrato da TFF e os tipos de representação Python. Na época invocação, TFF aceitará tipos de contêineres Python mais padrão ( dict , list , tuple , collections.namedtuple , etc.) como representações concretas de tuplas TFF abstratos. Além disso, embora conforme observado acima, os cálculos TFF formalmente aceitem apenas um único parâmetro, você pode usar a sintaxe de chamada Python familiar com argumentos posicionais e / ou de palavra-chave no caso em que o tipo do parâmetro é uma tupla - funciona conforme o esperado.

Descida gradiente em um único lote

Agora, vamos definir um cálculo que usa essa função de perda para realizar uma única etapa de descida do gradiente. Note como na definição desta função, usamos batch_loss como um subcomponente. Você pode chamar um cálculo construído com tff.tf_computation dentro do corpo de outra computação, embora normalmente isso não é necessário - como mencionado acima, porque serialização perde algumas informações de depuração, muitas vezes é preferível para cálculos mais complexos para escrever e testar toda a TensorFlow sem a tff.tf_computation decorador.

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
def batch_train(initial_model, batch, learning_rate):
  # Define a group of model variables and set them to `initial_model`. Must
  # be defined outside the @tf.function.
  model_vars = collections.OrderedDict([
      (name, tf.Variable(name=name, initial_value=value))
      for name, value in initial_model.items()
  ])
  optimizer = tf.keras.optimizers.SGD(learning_rate)

  @tf.function
  def _train_on_batch(model_vars, batch):
    # Perform one step of gradient descent using loss from `batch_loss`.
    with tf.GradientTape() as tape:
      loss = forward_pass(model_vars, batch)
    grads = tape.gradient(loss, model_vars)
    optimizer.apply_gradients(
        zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
    return model_vars

  return _train_on_batch(model_vars, batch)
str(batch_train.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>,float32> -> <weights=float32[784,10],bias=float32[10]>)'

Ao chamar uma função Python decorado com tff.tf_computation dentro do corpo de um outro tal função, a lógica de cálculo TFF interior é incorporado (essencialmente, sequenciados) na lógica de um exterior. Como mencionado acima, se você estiver escrevendo ambos os cálculos, é provável preferível fazer a função interna ( batch_loss neste caso) uma Python regular ou tf.function em vez de um tff.tf_computation . No entanto, aqui ilustramos que chamar um tff.tf_computation dentro de outro, basicamente, funciona como esperado. Isso pode ser necessário se, por exemplo, você não tem o código Python que define batch_loss , mas apenas sua representação TFF serializado.

Agora, vamos aplicar essa função algumas vezes ao modelo inicial para ver se a perda diminui.

model = initial_model
losses = []
for _ in range(5):
  model = batch_train(model, sample_batch, 0.1)
  losses.append(batch_loss(model, sample_batch))
losses
[0.19690022, 0.13176313, 0.10113226, 0.082738124, 0.0703014]

Descida gradiente em uma sequência de dados locais

Agora, uma vez batch_train aparece para trabalhar, vamos escrever uma função de formação semelhante local_train que consome toda a seqüência de todos os lotes de um usuário em vez de apenas um único lote. O novo cálculo será necessário agora consomem tff.SequenceType(BATCH_TYPE) em vez de BATCH_TYPE .

LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)

@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)
def local_train(initial_model, learning_rate, all_batches):

  # Mapping function to apply to each batch.
  @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
  def batch_fn(model, batch):
    return batch_train(model, batch, learning_rate)

  return tff.sequence_reduce(all_batches, initial_model, batch_fn)
str(local_train.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,float32,<x=float32[?,784],y=int32[?]>*> -> <weights=float32[784,10],bias=float32[10]>)'

Existem alguns detalhes enterrados nesta pequena seção de código, vamos examiná-los um por um.

Em primeiro lugar, enquanto que poderíamos ter implementado essa lógica inteiramente em TensorFlow, contando com tf.data.Dataset.reduce para processar a seqüência da mesma forma como temos feito isso anteriormente, optou desta vez para expressar a lógica na língua cola , como um tff.federated_computation . Nós usamos o operador federado tff.sequence_reduce para realizar a redução.

O operador tff.sequence_reduce é utilizado de modo semelhante para tf.data.Dataset.reduce . Você pode pensar nisso como essencialmente o mesmo que tf.data.Dataset.reduce , mas para uso dentro cálculos federados, que, como você pode se lembrar, não pode conter código TensorFlow. É um operador de template com um parâmetro formal 3-tupla que consiste em uma sequência de T -typed elementos, o estado inicial da redução (vamos nos referir a ele abstratamente como zero) de algum tipo U , eo operador de redução da digite (<U,T> -> U) que altera o estado da redução de processamento de um único elemento. O resultado é o estado final da redução, após o processamento de todos os elementos em uma ordem sequencial. Em nosso exemplo, o estado da redução é o modelo treinado em um prefixo dos dados e os elementos são lotes de dados.

Em segundo lugar, note que temos novamente usou uma computação ( batch_train ) como um componente dentro de outro ( local_train ), mas não diretamente. Não podemos usá-lo como um operador de redução porque leva um parâmetro adicional - a taxa de aprendizagem. Para resolver isso, vamos definir um cálculo federado embutido batch_fn que se liga ao local_train 's parâmetro learning_rate em seu corpo. É permitido a um cálculo filho definido desta forma capturar um parâmetro formal de seu pai, desde que o cálculo filho não seja invocado fora do corpo de seu pai. Você pode pensar deste padrão, como um equivalente de functools.partial em Python.

A implicação prática de capturar learning_rate desta maneira é, naturalmente, que o mesmo valor da taxa de aprendizagem é utilizado em todos os lotes.

Agora, vamos tentar a função de formação local recém-definido em toda a seqüência de dados do mesmo usuário que contribuiu o lote de amostra (dígito 5 ).

locally_trained_model = local_train(initial_model, 0.1, federated_train_data[5])

Funcionou? Para responder a esta pergunta, precisamos implementar a avaliação.

Avaliação local

Esta é uma maneira de implementar a avaliação local somando as perdas em todos os lotes de dados (poderíamos ter calculado a média da mesma forma; deixaremos isso como um exercício para o leitor).

@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def local_eval(model, all_batches):
  # TODO(b/120157713): Replace with `tff.sequence_average()` once implemented.
  return tff.sequence_sum(
      tff.sequence_map(
          tff.federated_computation(lambda b: batch_loss(model, b), BATCH_TYPE),
          all_batches))
str(local_eval.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>*> -> float32)'

Novamente, existem alguns novos elementos ilustrados por este código, vamos examiná-los um por um.

Primeiro, usamos dois novos operadores federados para o processamento de sequências: tff.sequence_map que leva uma função de mapeamento T->U e uma sequência de T , e emite uma sequência de U obtido pela aplicação da função pointwise mapeamento e tff.sequence_sum que apenas adiciona todos os elementos. Aqui, mapeamos cada lote de dados para um valor de perda e, em seguida, adicionamos os valores de perda resultantes para calcular a perda total.

Note-se que poderíamos ter novamente utilizado tff.sequence_reduce , mas isso não seria a melhor escolha - o processo de redução é, por definição, seqüencial, enquanto o mapeamento e soma pode ser calculado em paralelo. Quando tiver uma escolha, é melhor ficar com os operadores que não restringem as opções de implementação, de modo que quando nossa computação TFF for compilada no futuro para ser implementada em um ambiente específico, seja possível aproveitar todas as oportunidades potenciais para um processo mais rápido execução mais escalonável e eficiente em termos de recursos.

Em segundo lugar, nota que, assim como em local_train , a função do componente que precisamos ( batch_loss ) leva mais parâmetros do que o que o operador federado ( tff.sequence_map ) espera, por isso, mais uma vez definir uma parcial, desta vez em linha envolvendo diretamente a lambda como um tff.federated_computation . Usando wrappers em linha com uma função como um argumento é a maneira recomendada para usar tff.tf_computation para incorporar TensorFlow lógica TFF.

Agora, vamos ver se nosso treinamento funcionou.

print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[5]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[5]))
initial_model loss = 23.025854
locally_trained_model loss = 0.4348469

Na verdade, a perda diminuiu. Mas o que acontece se avaliarmos nos dados de outro usuário?

print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[0]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[0]))
initial_model loss = 23.025854
locally_trained_model loss = 74.50075

Como esperado, as coisas pioraram. O modelo foi treinado para reconhecer 5 , e nunca viu um 0 . Isso traz a questão - como o treinamento local impactou a qualidade do modelo do ponto de vista global?

Avaliação federada

Este é o ponto em nossa jornada em que finalmente voltamos aos tipos federados e às computações federadas - o tópico com o qual começamos. Aqui está um par de definições de tipos de TFF para o modelo que se origina no servidor e os dados que permanecem nos clientes.

SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)
CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)

Com todas as definições apresentadas até agora, expressar a avaliação federada no TFF é uma linha única - distribuímos o modelo aos clientes, deixamos cada cliente invocar a avaliação local em sua porção local de dados e, em seguida, calculamos a perda. Aqui está uma maneira de escrever isso.

@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)
def federated_eval(model, data):
  return tff.federated_mean(
      tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))

Nós já vimos exemplos de tff.federated_mean e tff.federated_map em cenários mais simples, e ao nível intuitivo, eles funcionam como esperado, mas há muito mais nesta seção do código que satisfaça os olhos, então vamos passar por isso com cuidado.

Em primeiro lugar, a ruptura de deixar para baixo o que cada cliente invocar a avaliação local de sua parte local da parte de dados. Como você pode recordar das seções anteriores, local_eval tem uma assinatura de tipo da forma (<MODEL_TYPE, LOCAL_DATA_TYPE> -> float32) .

O operador federado tff.federated_map é um modelo que aceita como parâmetro a 2 tupla que consiste na função de mapeamento de algum tipo T->U e um valor federado de tipo {T}@CLIENTS (ou seja, com constituintes membros da mesmo tipo que o parâmetro da função de mapeamento), e retorna um resultado de tipo {U}@CLIENTS .

Desde estamos alimentando local_eval como uma função de mapeamento para aplicar em uma base por cliente, o segundo argumento deve ser de um tipo federado {<MODEL_TYPE, LOCAL_DATA_TYPE>}@CLIENTS , ou seja, na nomenclatura das seções anteriores, que deveria ser uma tupla federada. Cada cliente deve realizar um conjunto completo de argumentos para local_eval como um constituinte membro. Em vez disso, estamos alimentando-o um Python 2-elemento de list . O que está acontecendo aqui?

Na verdade, este é um exemplo de um tipo de conversão implícita em TFF, semelhante ao gesso tipo implícito que você pode ter encontrado em outros lugares, por exemplo, quando você alimentar um int para uma função que aceita um float . A conversão implícita raramente é usada neste ponto, mas planejamos torná-la mais difundida na TFF como uma forma de minimizar o clichê.

O fundido implícito que é aplicada, neste caso, é a equivalência entre tuplos federados da forma {<X,Y>}@Z , e tuplos de bibliotecas valores <{X}@Z,{Y}@Z> . Enquanto formalmente, estes dois são diferentes assinaturas de tipo, olhando para ela a partir da perspectiva das programadores, cada dispositivo em Z ocupa duas unidades de dados X e Y . O que acontece aqui não é diferente zip em Python, e de fato, oferecemos um operador tff.federated_zip que permite executar tais conversões explicitamente. Quando o tff.federated_map encontra uma tupla como um segundo argumento, ele simplesmente invoca tff.federated_zip para você.

Face ao exposto, você deve agora ser capaz de reconhecer a expressão tff.federated_broadcast(model) como representando um valor de TFF tipo {MODEL_TYPE}@CLIENTS , e data como um valor do tipo TFF {LOCAL_DATA_TYPE}@CLIENTS (ou simplesmente CLIENT_DATA_TYPE ) , os dois em conjunto ficando filtrada através de um implícito tff.federated_zip para formar o segundo argumento para tff.federated_map .

O operador tff.federated_broadcast , como seria de esperar, simplesmente transfere dados a partir do servidor para os clientes.

Agora, vamos ver como nosso treinamento local afetou a perda média no sistema.

print('initial_model loss =', federated_eval(initial_model,
                                             federated_train_data))
print('locally_trained_model loss =',
      federated_eval(locally_trained_model, federated_train_data))
initial_model loss = 23.025852
locally_trained_model loss = 54.432625

De fato, como esperado, a perda aumentou. Para melhorar o modelo para todos os usuários, precisaremos treinar com os dados de todos.

Treino federado

A maneira mais simples de implementar o treinamento federado é treinar localmente e, em seguida, calcular a média dos modelos. Ele usa os mesmos blocos de construção e padrões que já discutimos, como você pode ver abaixo.

SERVER_FLOAT_TYPE = tff.type_at_server(tf.float32)


@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,
                           CLIENT_DATA_TYPE)
def federated_train(model, learning_rate, data):
  return tff.federated_mean(
      tff.federated_map(local_train, [
          tff.federated_broadcast(model),
          tff.federated_broadcast(learning_rate), data
      ]))

Note-se que na aplicação full-featured de Média Federated fornecido pelo tff.learning , ao invés da média dos modelos, preferimos deltas médias modelo, por uma série de razões, por exemplo, a capacidade de cortar as normas de atualização, para a compressão, etc .

Vamos ver se o treinamento funciona executando algumas rodadas de treinamento e comparando a perda média antes e depois.

model = initial_model
learning_rate = 0.1
for round_num in range(5):
  model = federated_train(model, learning_rate, federated_train_data)
  learning_rate = learning_rate * 0.9
  loss = federated_eval(model, federated_train_data)
  print('round {}, loss={}'.format(round_num, loss))
round 0, loss=21.60552406311035
round 1, loss=20.365678787231445
round 2, loss=19.27480125427246
round 3, loss=18.31110954284668
round 4, loss=17.45725440979004

Para completar, vamos agora também executar os dados de teste para confirmar se nosso modelo generaliza bem.

print('initial_model test loss =',
      federated_eval(initial_model, federated_test_data))
print('trained_model test loss =', federated_eval(model, federated_test_data))
initial_model test loss = 22.795593
trained_model test loss = 17.278767

Isso conclui nosso tutorial.

Claro, nosso exemplo simplificado não reflete uma série de coisas que você precisa fazer em um cenário mais realista - por exemplo, não calculamos métricas além da perda. Nós encorajamos você a estudar a implementação de média federado em tff.learning como um exemplo mais completo, e como uma forma de demonstrar algumas das práticas de codificação que gostaríamos de encorajar.