¿Tengo una pregunta? Conéctese con la comunidad en el Foro de visita del foro de TensorFlow

Algoritmos federados personalizados, parte 2: implementación de promedios federados

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Este tutorial es la segunda parte de una serie de dos partes que demuestra cómo implementar tipos personalizados de algoritmos federados en TFF utilizando Federated Core (FC) , que sirve como base para la capa de aprendizaje federado (FL) ( tff.learning ) .

Le recomendamos que lea primero la primera parte de esta serie , que presenta algunos de los conceptos clave y abstracciones de programación que se utilizan aquí.

Esta segunda parte de la serie utiliza los mecanismos introducidos en la primera parte para implementar una versión simple de algoritmos de capacitación y evaluación federados.

Le recomendamos que revise los tutoriales de clasificación de imágenes y generación de texto para obtener una introducción de nivel superior y más suave a las API de aprendizaje federado de TFF, ya que lo ayudarán a poner los conceptos que describimos aquí en contexto.

Antes que empecemos

Antes de comenzar, intente ejecutar el siguiente ejemplo de "Hola mundo" para asegurarse de que su entorno esté configurado correctamente. Si no funciona, consulte la guía de instalación para obtener instrucciones.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# TODO(b/148678573,b/148685415): must use the reference context because it
# supports unbounded references and tff.sequence_* intrinsics.
tff.backends.reference.set_reference_context()
@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
'Hello, World!'

Implementación de promedios federados

Al igual que en el aprendizaje federado para la clasificación de imágenes , vamos a utilizar el ejemplo de MNIST, pero dado que se trata de un tutorial de bajo nivel, vamos a omitir la API de Keras y tff.simulation , escribir código de modelo sin procesar y construir un conjunto de datos federados desde cero.

Preparar conjuntos de datos federados

A modo de demostración, vamos a simular un escenario en el que tenemos datos de 10 usuarios y cada uno de los usuarios aporta conocimientos sobre cómo reconocer un dígito diferente. Esto es lo más no- iid posible .

Primero, carguemos los datos MNIST estándar:

mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
[(x.dtype, x.shape) for x in mnist_train]
[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]

Los datos vienen como matrices Numpy, una con imágenes y otra con etiquetas de dígitos, ambas con la primera dimensión repasando los ejemplos individuales. Escribamos una función auxiliar que la formatee de una manera compatible con la forma en que alimentamos secuencias federadas en los cálculos de TFF, es decir, como una lista de listas: la lista externa se extiende sobre los usuarios (dígitos), las internas abarcan lotes de datos en la secuencia de cada cliente. Como es habitual, Estructuraremos cada lote como un par de tensores nombrados x y y , cada uno con la dimensión principal de lote. Mientras lo hacemos, también aplanaremos cada imagen en un vector de 784 elementos y cambiaremos la escala de los píxeles en el rango 0..1 , para que no tengamos que saturar la lógica del modelo con conversiones de datos.

NUM_EXAMPLES_PER_USER = 1000
BATCH_SIZE = 100


def get_data_for_digit(source, digit):
  output_sequence = []
  all_samples = [i for i, d in enumerate(source[1]) if d == digit]
  for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_USER), BATCH_SIZE):
    batch_samples = all_samples[i:i + BATCH_SIZE]
    output_sequence.append({
        'x':
            np.array([source[0][i].flatten() / 255.0 for i in batch_samples],
                     dtype=np.float32),
        'y':
            np.array([source[1][i] for i in batch_samples], dtype=np.int32)
    })
  return output_sequence


federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]

federated_test_data = [get_data_for_digit(mnist_test, d) for d in range(10)]

Como prueba rápida de cordura, veamos el tensor Y en el último lote de datos aportados por el quinto cliente (el que corresponde al dígito 5 ).

federated_train_data[5][-1]['y']
array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)

Solo para estar seguros, veamos también la imagen correspondiente al último elemento de ese lote.

from matplotlib import pyplot as plt

plt.imshow(federated_train_data[5][-1]['x'][-1].reshape(28, 28), cmap='gray')
plt.grid(False)
plt.show()

png

Sobre la combinación de TensorFlow y TFF

En este tutorial, para simplificar, decoramos inmediatamente funciones que introducen la lógica de TensorFlow con tff.tf_computation . Sin embargo, para una lógica más compleja, este no es el patrón que recomendamos. La depuración de TensorFlow ya puede ser un desafío, y la depuración de TensorFlow después de que se haya serializado por completo y luego reimportado necesariamente pierde algunos metadatos y limita la interactividad, lo que hace que la depuración sea aún más desafiante.

Por lo tanto, recomendamos encarecidamente escribir lógica TF compleja como funciones de Python independientes (es decir, sin decoración tff.tf_computation ). De esta manera, la lógica de TensorFlow se puede desarrollar y probar usando las mejores prácticas y herramientas de TF (como el modo ansioso), antes de serializar el cálculo para TFF (por ejemplo, invocando tff.tf_computation con una función de Python como argumento).

Definición de una función de pérdida

Ahora que tenemos los datos, definamos una función de pérdida que podamos usar para el entrenamiento. Primero, definamos el tipo de entrada como una TFF llamada tupla. Dado que el tamaño de los lotes de datos puede variar, establecemos la dimensión del lote en None para indicar que se desconoce el tamaño de esta dimensión.

BATCH_SPEC = collections.OrderedDict(
    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    y=tf.TensorSpec(shape=[None], dtype=tf.int32))
BATCH_TYPE = tff.to_type(BATCH_SPEC)

str(BATCH_TYPE)
'<x=float32[?,784],y=int32[?]>'

Quizás se pregunte por qué no podemos simplemente definir un tipo de Python ordinario. Recuerde la discusión en la parte 1 , donde explicamos que si bien podemos expresar la lógica de los cálculos de TFF usando Python, los cálculos de TFF bajo el capó no son Python. El símbolo BATCH_TYPE definido anteriormente representa una especificación de tipo TFF abstracta. Es importante distinguir este tipo de TFF abstracto de los tipos de representación de Python concretos, por ejemplo, contenedores como dict o collections.namedtuple que pueden usarse para representar el tipo de TFF en el cuerpo de una función de Python. A diferencia de Python, TFF tiene un solo constructor de tipo abstracto tff.StructType para contenedores tipo tupla, con elementos que se pueden nombrar individualmente o dejar sin nombre. Este tipo también se utiliza para modelar parámetros formales de cálculos, ya que los cálculos de TFF solo pueden declarar formalmente un parámetro y un resultado; verá ejemplos de esto en breve.

Definamos ahora el tipo de parámetros del modelo TFF, nuevamente como un TFF denominado tupla de pesos y sesgo .

MODEL_SPEC = collections.OrderedDict(
    weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),
    bias=tf.TensorSpec(shape=[10], dtype=tf.float32))
MODEL_TYPE = tff.to_type(MODEL_SPEC)

print(MODEL_TYPE)
<weights=float32[784,10],bias=float32[10]>

Con esas definiciones en su lugar, ahora podemos definir la pérdida para un modelo dado, en un solo lote. Tenga en cuenta el uso del decorador @tf.function dentro del decorador @tff.tf_computation . Esto nos permite escribir TF usando Python como semántica a pesar de estar dentro de un contexto tf.Graph creado por el decorador tff.tf_computation .

# NOTE: `forward_pass` is defined separately from `batch_loss` so that it can 
# be later called from within another tf.function. Necessary because a
# @tf.function  decorated method cannot invoke a @tff.tf_computation.

@tf.function
def forward_pass(model, batch):
  predicted_y = tf.nn.softmax(
      tf.matmul(batch['x'], model['weights']) + model['bias'])
  return -tf.reduce_mean(
      tf.reduce_sum(
          tf.one_hot(batch['y'], 10) * tf.math.log(predicted_y), axis=[1]))

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)
def batch_loss(model, batch):
  return forward_pass(model, batch)

Como se esperaba, el cálculo batch_loss devuelve la pérdida de float32 dado el modelo y un solo lote de datos. Observe cómo MODEL_TYPE y BATCH_TYPE se han agrupado en una tupla de 2 parámetros formales; puede reconocer el tipo de batch_loss como (<MODEL_TYPE,BATCH_TYPE> -> float32) .

str(batch_loss.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>> -> float32)'

Como prueba de cordura, construyamos un modelo inicial lleno de ceros y calculemos la pérdida sobre el lote de datos que visualizamos arriba.

initial_model = collections.OrderedDict(
    weights=np.zeros([784, 10], dtype=np.float32),
    bias=np.zeros([10], dtype=np.float32))

sample_batch = federated_train_data[5][-1]

batch_loss(initial_model, sample_batch)
2.3025854

Tenga en cuenta que alimentamos el cálculo TFF con el modelo inicial definido como un dict , aunque el cuerpo de la función de Python que lo define consume parámetros del model['weight'] como model['weight'] y model['bias'] . Los argumentos de la llamada a batch_loss no se pasan simplemente al cuerpo de esa función.

¿Qué sucede cuando invocamos batch_loss ? El cuerpo de Python de batch_loss ya se ha rastreado y serializado en la celda anterior donde se definió. TFF actúa como el llamador para batch_loss en el momento de la definición del cálculo, y como el objetivo de la invocación en el momento en que se invoca batch_loss . En ambos roles, TFF sirve como puente entre el sistema de tipos abstractos de TFF y los tipos de representación de Python. En el momento de la invocación, TFF aceptará la mayoría de los tipos de contenedores estándar de Python ( dict , list , tuple , collections.namedtuple , etc.) como representaciones concretas de tuplas TFF abstractas. Además, aunque como se señaló anteriormente, los cálculos de TFF formalmente solo aceptan un único parámetro, puede usar la sintaxis de llamada familiar de Python con argumentos posicionales y / o de palabras clave en caso de que el tipo de parámetro sea una tupla; funciona como se esperaba.

Descenso de gradiente en un solo lote

Ahora, definamos un cálculo que use esta función de pérdida para realizar un solo paso de descenso de gradiente. Observe cómo al definir esta función, usamos batch_loss como subcomponente. Puede invocar un cálculo construido con tff.tf_computation dentro del cuerpo de otro cálculo, aunque normalmente esto no es necesario; como se indicó anteriormente, debido a que la serialización pierde parte de la información de depuración, a menudo es preferible que los cálculos más complejos escriban y prueben todo el TensorFlow sin el decorador tff.tf_computation .

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
def batch_train(initial_model, batch, learning_rate):
  # Define a group of model variables and set them to `initial_model`. Must
  # be defined outside the @tf.function.
  model_vars = collections.OrderedDict([
      (name, tf.Variable(name=name, initial_value=value))
      for name, value in initial_model.items()
  ])
  optimizer = tf.keras.optimizers.SGD(learning_rate)

  @tf.function
  def _train_on_batch(model_vars, batch):
    # Perform one step of gradient descent using loss from `batch_loss`.
    with tf.GradientTape() as tape:
      loss = forward_pass(model_vars, batch)
    grads = tape.gradient(loss, model_vars)
    optimizer.apply_gradients(
        zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
    return model_vars

  return _train_on_batch(model_vars, batch)
str(batch_train.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>,float32> -> <weights=float32[784,10],bias=float32[10]>)'

Cuando invoca una función de Python decorada con tff.tf_computation dentro del cuerpo de otra función de este tipo, la lógica del cálculo de TFF interno está incrustada (esencialmente, en línea) en la lógica de la externa. Como se señaló anteriormente, si está escribiendo ambos cálculos, es probable que sea preferible hacer que la función interna ( batch_loss en este caso) sea una Python normal o tf.function lugar de tff.tf_computation . Sin embargo, aquí ilustramos que llamar a un tff.tf_computation dentro de otro básicamente funciona como se esperaba. Esto puede ser necesario si, por ejemplo, no tiene el código Python que define batch_loss , sino solo su representación TFF serializada.

Ahora, apliquemos esta función varias veces al modelo inicial para ver si la pérdida disminuye.

model = initial_model
losses = []
for _ in range(5):
  model = batch_train(model, sample_batch, 0.1)
  losses.append(batch_loss(model, sample_batch))
losses
[0.19690022, 0.13176313, 0.10113226, 0.082738124, 0.0703014]

Descenso de gradiente en una secuencia de datos locales

Ahora, dado que batch_train parece funcionar, escribamos una función de entrenamiento similar local_train que consume la secuencia completa de todos los lotes de un usuario en lugar de un solo lote. El nuevo cálculo deberá consumir ahora tff.SequenceType(BATCH_TYPE) lugar de BATCH_TYPE .

LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)

@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)
def local_train(initial_model, learning_rate, all_batches):

  # Mapping function to apply to each batch.
  @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
  def batch_fn(model, batch):
    return batch_train(model, batch, learning_rate)

  return tff.sequence_reduce(all_batches, initial_model, batch_fn)
str(local_train.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,float32,<x=float32[?,784],y=int32[?]>*> -> <weights=float32[784,10],bias=float32[10]>)'

Hay bastantes detalles enterrados en esta pequeña sección de código, repasemos uno por uno.

Primero, si bien podríamos haber implementado esta lógica por completo en TensorFlow, confiando en tf.data.Dataset.reduce para procesar la secuencia de manera similar a como lo hicimos anteriormente, hemos optado esta vez por expresar la lógica en el lenguaje adhesivo. , como tff.federated_computation . Hemos utilizado el operador federado tff.sequence_reduce para realizar la reducción.

El operador tff.sequence_reduce se usa de manera similar a tf.data.Dataset.reduce . Puede pensar en él como esencialmente lo mismo que tf.data.Dataset.reduce , pero para usarlo dentro de cálculos federados, que, como recordará, no pueden contener código de TensorFlow. Es un operador de plantilla con un parámetro formal de 3 tuplas que consta de una secuencia de elementos de tipo T , el estado inicial de la reducción (nos referiremos a él de forma abstracta como cero ) de algún tipo U y el operador de reducción de tipo (<U,T> -> U) que altera el estado de la reducción procesando un solo elemento. El resultado es el estado final de la reducción, después de procesar todos los elementos en orden secuencial. En nuestro ejemplo, el estado de la reducción es el modelo entrenado en un prefijo de los datos y los elementos son lotes de datos.

En segundo lugar, tenga en cuenta que hemos vuelto a utilizar un cálculo ( batch_train ) como un componente dentro de otro ( local_train ), pero no directamente. No podemos usarlo como un operador de reducción porque requiere un parámetro adicional: la tasa de aprendizaje. Para resolver esto, definimos un cálculo federado integrado batch_fn que se une al parámetro learning_rate local_train en su cuerpo. Se permite que un cálculo hijo definido de esta manera capture un parámetro formal de su padre siempre que el cálculo hijo no se invoque fuera del cuerpo de su padre. Puede pensar en este patrón como un equivalente de functools.partial en Python.

La implicación práctica de capturar la tasa de learning_rate esta manera es, por supuesto, que se utiliza el mismo valor de tasa de aprendizaje en todos los lotes.

Ahora, probemos la función de entrenamiento local recién definida en toda la secuencia de datos del mismo usuario que contribuyó con el lote de muestra (dígito 5 ).

locally_trained_model = local_train(initial_model, 0.1, federated_train_data[5])

¿Funcionó? Para responder a esta pregunta, debemos implementar la evaluación.

Evaluación local

Aquí hay una forma de implementar la evaluación local sumando las pérdidas en todos los lotes de datos (podríamos haber calculado igualmente el promedio; lo dejaremos como un ejercicio para el lector).

@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def local_eval(model, all_batches):
  # TODO(b/120157713): Replace with `tff.sequence_average()` once implemented.
  return tff.sequence_sum(
      tff.sequence_map(
          tff.federated_computation(lambda b: batch_loss(model, b), BATCH_TYPE),
          all_batches))
str(local_eval.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>*> -> float32)'

Nuevamente, hay algunos elementos nuevos ilustrados por este código, repasemos uno por uno.

Primero, hemos utilizado dos nuevos operadores federados para procesar secuencias: tff.sequence_map que toma una función de mapeo T->U y una secuencia de T , y emite una secuencia de U obtenida aplicando la función de mapeo tff.sequence_sum , y tff.sequence_sum que simplemente agrega todos los elementos. Aquí, asignamos cada lote de datos a un valor de pérdida y luego sumamos los valores de pérdida resultantes para calcular la pérdida total.

Tenga en cuenta que podríamos haber usado tff.sequence_reduce nuevamente, pero esta no sería la mejor opción: el proceso de reducción es, por definición, secuencial, mientras que el mapeo y la suma se pueden calcular en paralelo. Cuando se le da una opción, es mejor quedarse con operadores que no restrinjan las opciones de implementación, de modo que cuando nuestro cálculo TFF se compile en el futuro para implementarlo en un entorno específico, uno pueda aprovechar al máximo todas las oportunidades potenciales para una , ejecución más escalable y más eficiente en el uso de recursos.

En segundo lugar, tenga en cuenta que al igual que en local_train , la función del componente que necesitamos ( batch_loss ) toma más parámetros de los que espera el operador federado ( tff.sequence_map ), por lo que nuevamente definimos un parcial, esta vez en línea envolviendo directamente una lambda como tff.federated_computation . El uso de contenedores en línea con una función como argumento es la forma recomendada de usar tff.tf_computation para incorporar la lógica de TensorFlow en TFF.

Ahora, veamos si nuestro entrenamiento funcionó.

print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[5]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[5]))
initial_model loss = 23.025854
locally_trained_model loss = 0.4348469

De hecho, la pérdida disminuyó. Pero, ¿qué pasa si lo evaluamos sobre los datos de otro usuario?

print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[0]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[0]))
initial_model loss = 23.025854
locally_trained_model loss = 74.50075

Como era de esperar, las cosas empeoraron. El modelo fue entrenado para reconocer 5 y nunca vio un 0 . Esto trae la pregunta: ¿cómo impactó la capacitación local en la calidad del modelo desde la perspectiva global?

Evaluación federada

Este es el punto de nuestro viaje en el que finalmente volvemos a los tipos federados y los cálculos federados, el tema con el que comenzamos. Aquí hay un par de definiciones de tipos de TFF para el modelo que se origina en el servidor y los datos que permanecen en los clientes.

SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)
CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)

Con todas las definiciones introducidas hasta ahora, expresar la evaluación federada en TFF es una línea recta: distribuimos el modelo a los clientes, permitimos que cada cliente invoque la evaluación local en su porción local de datos y luego promediamos la pérdida. Aquí hay una forma de escribir esto.

@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)
def federated_eval(model, data):
  return tff.federated_mean(
      tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))

Ya hemos visto ejemplos de tff.federated_mean y tff.federated_map en escenarios más simples y, a nivel intuitivo, funcionan como se esperaba, pero hay más en esta sección de código de lo que parece, así que repasemos esto con cuidado.

Primero, analicemos la parte de dejar que cada cliente invoque la evaluación local en su parte local de datos . Como recordará de las secciones anteriores, local_eval tiene una firma de tipo del formulario (<MODEL_TYPE, LOCAL_DATA_TYPE> -> float32) .

El operador federado tff.federated_map es una plantilla que acepta como parámetro una 2-tupla que consiste en la función de mapeo de algún tipo T->U y un valor federado de tipo {T}@CLIENTS (es decir, con constituyentes miembros del mismo tipo que el parámetro de la función de mapeo) y devuelve un resultado de tipo {U}@CLIENTS .

Dado que estamos alimentando local_eval como una función de mapeo para aplicar por cliente, el segundo argumento debe ser de un tipo federado {<MODEL_TYPE, LOCAL_DATA_TYPE>}@CLIENTS , es decir, en la nomenclatura de las secciones anteriores, debe ser una tupla federada. Cada cliente debe tener un conjunto completo de argumentos para local_eval como miembro constituyente. En cambio, lo estamos alimentando con una list Python de 2 elementos. ¿Que esta pasando aqui?

De hecho, este es un ejemplo de una conversión de tipos implícitos en TFF, similar a las conversiones de tipos implícitos que puede haber encontrado en otros lugares, por ejemplo, cuando alimenta un int a una función que acepta un float . El casting implícito se usa escasamente en este punto, pero planeamos hacerlo más omnipresente en TFF como una forma de minimizar la repetición.

La conversión implícita que se aplica en este caso es la equivalencia entre tuplas federadas de la forma {<X,Y>}@Z y tuplas de valores federados <{X}@Z,{Y}@Z> . Si bien formalmente, estos dos son firmas de tipos diferentes, mirándolo desde la perspectiva de los programadores, cada dispositivo en Z contiene dos unidades de datos X e Y Lo que sucede aquí no es diferente de zip en Python y, de hecho, ofrecemos un operador tff.federated_zip que le permite realizar tales conversiones explícitamente. Cuando tff.federated_map encuentra una tupla como segundo argumento, simplemente invoca tff.federated_zip por usted.

Dado lo anterior, ahora debería poder reconocer que la expresión tff.federated_broadcast(model) representa un valor de tipo TFF {MODEL_TYPE}@CLIENTS , y los data como un valor de tipo TFF {LOCAL_DATA_TYPE}@CLIENTS (o simplemente CLIENT_DATA_TYPE ) , los dos se filtran juntos a través de un tff.federated_zip implícito para formar el segundo argumento de tff.federated_map .

El operador tff.federated_broadcast , como era de esperar, simplemente transfiere datos desde el servidor a los clientes.

Ahora, veamos cómo nuestro entrenamiento local afectó la pérdida promedio en el sistema.

print('initial_model loss =', federated_eval(initial_model,
                                             federated_train_data))
print('locally_trained_model loss =',
      federated_eval(locally_trained_model, federated_train_data))
initial_model loss = 23.025852
locally_trained_model loss = 54.432625

De hecho, como se esperaba, la pérdida ha aumentado. Para mejorar el modelo para todos los usuarios, necesitaremos entrenarnos con los datos de todos.

Entrenamiento federado

La forma más sencilla de implementar el entrenamiento federado es entrenar localmente y luego promediar los modelos. Esto usa los mismos bloques de construcción y patrones que ya hemos discutido, como puede ver a continuación.

SERVER_FLOAT_TYPE = tff.type_at_server(tf.float32)


@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,
                           CLIENT_DATA_TYPE)
def federated_train(model, learning_rate, data):
  return tff.federated_mean(
      tff.federated_map(local_train, [
          tff.federated_broadcast(model),
          tff.federated_broadcast(learning_rate), data
      ]))

Tenga en cuenta que en la implementación completa de Promedio federado proporcionada por tff.learning , en lugar de promediar los modelos, preferimos promediar los deltas del modelo, por varias razones, por ejemplo, la capacidad de recortar las normas de actualización, por compresión, etc. .

Veamos si el entrenamiento funciona ejecutando algunas rondas de entrenamiento y comparando la pérdida promedio antes y después.

model = initial_model
learning_rate = 0.1
for round_num in range(5):
  model = federated_train(model, learning_rate, federated_train_data)
  learning_rate = learning_rate * 0.9
  loss = federated_eval(model, federated_train_data)
  print('round {}, loss={}'.format(round_num, loss))
round 0, loss=21.60552406311035
round 1, loss=20.365678787231445
round 2, loss=19.27480125427246
round 3, loss=18.31110954284668
round 4, loss=17.45725440979004

Para completar, ahora también ejecutemos los datos de prueba para confirmar que nuestro modelo se generaliza bien.

print('initial_model test loss =',
      federated_eval(initial_model, federated_test_data))
print('trained_model test loss =', federated_eval(model, federated_test_data))
initial_model test loss = 22.795593
trained_model test loss = 17.278767

Con esto concluye nuestro tutorial.

Por supuesto, nuestro ejemplo simplificado no refleja una serie de cosas que tendría que hacer en un escenario más realista; por ejemplo, no hemos calculado métricas que no sean las pérdidas. Le animamos a estudiar la implementación del promedio federado en tff.learning como un ejemplo más completo y como una forma de demostrar algunas de las prácticas de codificación que nos gustaría fomentar.