Algoritmos federados personalizados, parte 2: implementación del promedio federado

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Este tutorial es la segunda parte de una serie de dos partes que muestra cómo implementar tipos de algoritmos personalizados federados en TFF usando el Federados Core (FC) , que sirve de base para el Federados de Aprendizaje (FL) capa ( tff.learning ) .

Le recomendamos que lea primero la primera parte de esta serie , que introducen algunos de los conceptos clave y las abstracciones de programación utilizado aquí.

Esta segunda parte de la serie utiliza los mecanismos introducidos en la primera parte para implementar una versión simple de algoritmos de capacitación y evaluación federados.

Lo invitamos a revisar la clasificación de imágenes y la generación de texto tutoriales para un nivel más alto y una introducción más suave para federados API de aprendizaje de TFF, ya que le ayudará a poner los conceptos que describimos aquí en su contexto.

Antes que empecemos

Antes de comenzar, intente ejecutar el siguiente ejemplo de "Hola mundo" para asegurarse de que su entorno esté configurado correctamente. Si esto no funciona, por favor refiérase a la instalación de guía para obtener instrucciones.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

# TODO(b/148678573,b/148685415): must use the reference context because it
# supports unbounded references and tff.sequence_* intrinsics.
tff.backends.reference.set_reference_context()
@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
'Hello, World!'

Implementación de promedios federados

Al igual que en Federados de Aprendizaje para la Clasificación de Imágenes , vamos a utilizar el ejemplo MNIST, pero ya que esta pretende ser un tutorial de bajo nivel, vamos a la API de derivación Keras y tff.simulation , escribimos código de modelo en bruto, y un constructo conjunto de datos federados desde cero.

Preparar conjuntos de datos federados

A modo de demostración, vamos a simular un escenario en el que tenemos datos de 10 usuarios y cada uno de los usuarios aporta conocimientos sobre cómo reconocer un dígito diferente. Esto es lo más no iid como se pone.

Primero, carguemos los datos MNIST estándar:

mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
11501568/11490434 [==============================] - 0s 0us/step
[(x.dtype, x.shape) for x in mnist_train]
[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]

Los datos vienen como matrices Numpy, una con imágenes y otra con etiquetas de dígitos, ambas con la primera dimensión repasando los ejemplos individuales. Escribamos una función auxiliar que la formatee de una manera compatible con la forma en que alimentamos secuencias federadas en los cálculos de TFF, es decir, como una lista de listas: la lista externa se extiende sobre los usuarios (dígitos), las internas abarcan lotes de datos en la secuencia de cada cliente. Como es habitual, Estructuraremos cada lote como un par de tensores nombrados x y y , cada uno con la dimensión principal de lote. Mientras que en él, también vamos a aplanar cada imagen en un vector de 784 elementos y cambiar la escala de los píxeles en que en el 0..1 rango, por lo que no tenemos el desorden de la lógica del modelo con las conversiones de datos.

NUM_EXAMPLES_PER_USER = 1000
BATCH_SIZE = 100


def get_data_for_digit(source, digit):
  output_sequence = []
  all_samples = [i for i, d in enumerate(source[1]) if d == digit]
  for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_USER), BATCH_SIZE):
    batch_samples = all_samples[i:i + BATCH_SIZE]
    output_sequence.append({
        'x':
            np.array([source[0][i].flatten() / 255.0 for i in batch_samples],
                     dtype=np.float32),
        'y':
            np.array([source[1][i] for i in batch_samples], dtype=np.int32)
    })
  return output_sequence


federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]

federated_test_data = [get_data_for_digit(mnist_test, d) for d in range(10)]

Como una comprobación de validez rápida, Echemos un vistazo a la Y tensor en el último lote de datos aportados por el quinto cliente (el que corresponde al dígito 5 ).

federated_train_data[5][-1]['y']
array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], dtype=int32)

Solo para estar seguros, veamos también la imagen correspondiente al último elemento de ese lote.

from matplotlib import pyplot as plt

plt.imshow(federated_train_data[5][-1]['x'][-1].reshape(28, 28), cmap='gray')
plt.grid(False)
plt.show()

png

Sobre la combinación de TensorFlow y TFF

En este tutorial, por compacidad decoramos inmediatamente funciones que introducen la lógica TensorFlow con tff.tf_computation . Sin embargo, para una lógica más compleja, este no es el patrón que recomendamos. La depuración de TensorFlow ya puede ser un desafío, y la depuración de TensorFlow después de que se haya serializado por completo y luego reimportado necesariamente pierde algunos metadatos y limita la interactividad, lo que hace que la depuración sea aún más desafiante.

Por lo tanto, es muy recomendable escribir la lógica de TF complejo como las funciones de Python independiente (es decir, sin tff.tf_computation decoración). De esta manera la lógica TensorFlow puede ser desarrollado y probado usando las mejores prácticas TF y herramientas (como el modo ansioso), antes de serializar el cálculo para TFF (por ejemplo, mediante la invocación de tff.tf_computation con una función de Python como el argumento).

Definición de una función de pérdida

Ahora que tenemos los datos, definamos una función de pérdida que podamos usar para el entrenamiento. Primero, definamos el tipo de entrada como una TFF llamada tupla. Dado que el tamaño de los lotes de datos puede variar, fijamos la dimensión lote a None para indicar que el tamaño de esta dimensión es desconocido.

BATCH_SPEC = collections.OrderedDict(
    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    y=tf.TensorSpec(shape=[None], dtype=tf.int32))
BATCH_TYPE = tff.to_type(BATCH_SPEC)

str(BATCH_TYPE)
'<x=float32[?,784],y=int32[?]>'

Quizás se pregunte por qué no podemos simplemente definir un tipo de Python ordinario. Recordemos la discusión en la parte 1 , donde explicamos que mientras que podemos expresar la lógica de cálculos TFF usando Python, bajo los cálculos de FFT campana no son Python. El símbolo BATCH_TYPE definido anteriormente representa una especificación de tipo de TFF abstracto. Es importante distinguir este tipo TFF extracto de hormigón tipos de representación Python, por ejemplo, recipientes tales como dict o collections.namedtuple que pueden utilizarse para representar el tipo TFF en el cuerpo de una función Python. A diferencia de Python, TFF tiene un único constructor de tipo abstracto tff.StructType de tupla-como contenedores, con elementos que pueden ser nombradas individualmente o dejados sin nombre. Este tipo también se utiliza para modelar parámetros formales de cálculos, ya que los cálculos de TFF solo pueden declarar formalmente un parámetro y un resultado; verá ejemplos de esto en breve.

Ahora vamos a definir el tipo de TFF de los parámetros del modelo, de nuevo como TFF llamado tupla de pesos y sesgos.

MODEL_SPEC = collections.OrderedDict(
    weights=tf.TensorSpec(shape=[784, 10], dtype=tf.float32),
    bias=tf.TensorSpec(shape=[10], dtype=tf.float32))
MODEL_TYPE = tff.to_type(MODEL_SPEC)

print(MODEL_TYPE)
<weights=float32[784,10],bias=float32[10]>

Con esas definiciones en su lugar, ahora podemos definir la pérdida para un modelo dado, en un solo lote. Tenga en cuenta el uso de @tf.function decorador interior de la @tff.tf_computation decorador. Esto nos permite escribir TF usando Python como la semántica a pesar de que estábamos dentro de un tf.Graph contexto creado por la tff.tf_computation decorador.

# NOTE: `forward_pass` is defined separately from `batch_loss` so that it can 
# be later called from within another tf.function. Necessary because a
# @tf.function  decorated method cannot invoke a @tff.tf_computation.

@tf.function
def forward_pass(model, batch):
  predicted_y = tf.nn.softmax(
      tf.matmul(batch['x'], model['weights']) + model['bias'])
  return -tf.reduce_mean(
      tf.reduce_sum(
          tf.one_hot(batch['y'], 10) * tf.math.log(predicted_y), axis=[1]))

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE)
def batch_loss(model, batch):
  return forward_pass(model, batch)

Como era de esperar, el cómputo batch_loss rendimientos float32 pérdida dado el modelo y un único lote de datos. Nota cómo el MODEL_TYPE y BATCH_TYPE han sido agrupados en un 2-tupla de parámetros formales; se puede reconocer el tipo de batch_loss como (<MODEL_TYPE,BATCH_TYPE> -> float32) .

str(batch_loss.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>> -> float32)'

Como prueba de cordura, construyamos un modelo inicial lleno de ceros y calculemos la pérdida sobre el lote de datos que visualizamos arriba.

initial_model = collections.OrderedDict(
    weights=np.zeros([784, 10], dtype=np.float32),
    bias=np.zeros([10], dtype=np.float32))

sample_batch = federated_train_data[5][-1]

batch_loss(initial_model, sample_batch)
2.3025854

Nota que alimentamos el cálculo TFF con el modelo inicial definido como un dict , aunque el cuerpo de la función Python que define que consume los parámetros del modelo como model['weight'] y model['bias'] . Los argumentos de la llamada a batch_loss no simplemente pasan al cuerpo de esa función.

¿Qué pasa cuando invocamos batch_loss ? El cuerpo Python de batch_loss ya ha sido trazada y serializado en la célula por encima de donde se definió. TFF actúa como la persona que llama a batch_loss en el momento de cálculo definición, y como el destino de la invocación en el momento batch_loss se invoca. En ambos roles, TFF sirve como puente entre el sistema de tipos abstractos de TFF y los tipos de representación de Python. En el momento de la invocación, TFF aceptará tipos de contenedores más Python estándar ( dict , list , tuple , collections.namedtuple , etc.) como representaciones concretas de tuplas TFF abstractos. Además, aunque como se señaló anteriormente, los cálculos de TFF formalmente solo aceptan un único parámetro, puede usar la sintaxis de llamada familiar de Python con argumentos posicionales y / o de palabras clave en caso de que el tipo de parámetro sea una tupla; funciona como se esperaba.

Descenso de gradiente en un solo lote

Ahora, definamos un cálculo que use esta función de pérdida para realizar un solo paso de descenso de gradiente. Nota cómo en la definición de esta función, utilizamos batch_loss como un subcomponente. Puede invocar un cálculo construido con tff.tf_computation en el interior del cuerpo de otro cálculo, pero en general esto no es necesario - como se señaló anteriormente, debido a la serialización pierde alguna información de depuración, a menudo es preferible para los cálculos más complejos para escribir y probar todos los TensorFlow sin el tff.tf_computation decorador.

@tff.tf_computation(MODEL_TYPE, BATCH_TYPE, tf.float32)
def batch_train(initial_model, batch, learning_rate):
  # Define a group of model variables and set them to `initial_model`. Must
  # be defined outside the @tf.function.
  model_vars = collections.OrderedDict([
      (name, tf.Variable(name=name, initial_value=value))
      for name, value in initial_model.items()
  ])
  optimizer = tf.keras.optimizers.SGD(learning_rate)

  @tf.function
  def _train_on_batch(model_vars, batch):
    # Perform one step of gradient descent using loss from `batch_loss`.
    with tf.GradientTape() as tape:
      loss = forward_pass(model_vars, batch)
    grads = tape.gradient(loss, model_vars)
    optimizer.apply_gradients(
        zip(tf.nest.flatten(grads), tf.nest.flatten(model_vars)))
    return model_vars

  return _train_on_batch(model_vars, batch)
str(batch_train.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>,float32> -> <weights=float32[784,10],bias=float32[10]>)'

Cuando se invoca una función de Python decorado con tff.tf_computation dentro del cuerpo del otro tal función, la lógica de la computación TFF interior está incrustado (esencialmente, en línea) en la lógica de la exterior. Como se señaló anteriormente, si está escribiendo ambos cálculos, lo más probable es preferible hacer la función interna ( batch_loss en este caso) un pitón regular o tf.function en lugar de un tff.tf_computation . Sin embargo, aquí nos ilustran que llamar a uno tff.tf_computation dentro de otro funciona básicamente como se esperaba. Esto puede ser necesario si, por ejemplo, usted no tiene el código Python que define batch_loss , pero sólo su representación serializada TFF.

Ahora, apliquemos esta función varias veces al modelo inicial para ver si la pérdida disminuye.

model = initial_model
losses = []
for _ in range(5):
  model = batch_train(model, sample_batch, 0.1)
  losses.append(batch_loss(model, sample_batch))
losses
[0.19690022, 0.13176313, 0.10113226, 0.082738124, 0.0703014]

Descenso de gradiente en una secuencia de datos locales

Ahora, ya batch_train parece trabajo, vamos a escribir una función de formación similar local_train que consume toda la secuencia de todos los lotes de un usuario en lugar de sólo un único lote. El nuevo cálculo tendrá que consumen ahora tff.SequenceType(BATCH_TYPE) en lugar de BATCH_TYPE .

LOCAL_DATA_TYPE = tff.SequenceType(BATCH_TYPE)

@tff.federated_computation(MODEL_TYPE, tf.float32, LOCAL_DATA_TYPE)
def local_train(initial_model, learning_rate, all_batches):

  # Mapping function to apply to each batch.
  @tff.federated_computation(MODEL_TYPE, BATCH_TYPE)
  def batch_fn(model, batch):
    return batch_train(model, batch, learning_rate)

  return tff.sequence_reduce(all_batches, initial_model, batch_fn)
str(local_train.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,float32,<x=float32[?,784],y=int32[?]>*> -> <weights=float32[784,10],bias=float32[10]>)'

Hay bastantes detalles enterrados en esta pequeña sección de código, repasemos uno por uno.

En primer lugar, mientras que podríamos haber implementado esta lógica en su totalidad en TensorFlow, confiando en tf.data.Dataset.reduce para procesar la secuencia de manera similar a como lo hemos hecho anteriormente, hemos optado en esta ocasión para expresar la lógica en el lenguaje de pegamento , como un tff.federated_computation . Hemos utilizado el operador federados tff.sequence_reduce para llevar a cabo la reducción.

El operador tff.sequence_reduce se utiliza de manera similar a tf.data.Dataset.reduce . Se puede pensar que es esencialmente el mismo que tf.data.Dataset.reduce , pero para el uso dentro de los cálculos federados, que como se recordará, no puede contener código TensorFlow. Es un operador de plantilla con un parámetro formal 3-tupla que consiste en una secuencia de T elementos -typed, el estado inicial de la reducción (nos referiremos a ella de forma abstracta como cero) de algún tipo U , y el operador reducción de escriba (<U,T> -> U) que altera el estado de la reducción de la transformación de un solo elemento. El resultado es el estado final de la reducción, después de procesar todos los elementos en orden secuencial. En nuestro ejemplo, el estado de la reducción es el modelo entrenado en un prefijo de los datos y los elementos son lotes de datos.

En segundo lugar, señalar que hemos utilizado otra vez un cálculo ( batch_train ) como un componente dentro de otro ( local_train ), pero no directamente. No podemos usarlo como un operador de reducción porque requiere un parámetro adicional: la tasa de aprendizaje. Para resolver esto, definimos un cómputo embebido federados batch_fn que se une al local_train 's parámetro learning_rate en su cuerpo. Se permite que un cálculo hijo definido de esta manera capture un parámetro formal de su padre siempre que el cálculo hijo no se invoque fuera del cuerpo de su padre. Se puede pensar en este patrón como un equivalente de functools.partial en Python.

La consecuencia práctica de capturar learning_rate de esta manera es, por supuesto, que el mismo valor de la tasa de aprendizaje se utiliza en todos los lotes.

Ahora, vamos a probar la función de formación local que acaba de definir en toda la secuencia de datos desde el mismo usuario que aportó el lote de muestra (dígitos 5 ).

locally_trained_model = local_train(initial_model, 0.1, federated_train_data[5])

¿Funcionó? Para responder a esta pregunta, debemos implementar la evaluación.

Evaluación local

Aquí hay una forma de implementar la evaluación local sumando las pérdidas en todos los lotes de datos (podríamos haber calculado igualmente el promedio; lo dejaremos como un ejercicio para el lector).

@tff.federated_computation(MODEL_TYPE, LOCAL_DATA_TYPE)
def local_eval(model, all_batches):
  # TODO(b/120157713): Replace with `tff.sequence_average()` once implemented.
  return tff.sequence_sum(
      tff.sequence_map(
          tff.federated_computation(lambda b: batch_loss(model, b), BATCH_TYPE),
          all_batches))
str(local_eval.type_signature)
'(<<weights=float32[784,10],bias=float32[10]>,<x=float32[?,784],y=int32[?]>*> -> float32)'

Nuevamente, hay algunos elementos nuevos ilustrados por este código, repasemos uno por uno.

En primer lugar, hemos utilizado dos nuevos operadores federados para el procesamiento de secuencias: tff.sequence_map que tiene una función de mapeo T->U y una secuencia de T , y emite una secuencia de U obtiene aplicando la función punto a punto de mapeo, y tff.sequence_sum que simplemente agrega todos los elementos. Aquí, asignamos cada lote de datos a un valor de pérdida y luego sumamos los valores de pérdida resultantes para calcular la pérdida total.

Tenga en cuenta que podríamos haber utilizado otra vez tff.sequence_reduce , pero esto no sería la mejor opción - el proceso de reducción es, por definición, secuencial, mientras que la cartografía y la suma se pueden calcular en paralelo. Cuando se le da una opción, es mejor quedarse con operadores que no restrinjan las opciones de implementación, de modo que cuando nuestro cálculo TFF se compile en el futuro para implementarse en un entorno específico, uno pueda aprovechar al máximo todas las oportunidades potenciales para una implementación más rápida. , ejecución más escalable y más eficiente en el uso de recursos.

En segundo lugar, cabe destacar que al igual que en local_train , la función de los componentes que necesitamos ( batch_loss ) realiza más parámetros que lo que el operador federados ( tff.sequence_map ) espera, por lo que de nuevo definimos un parcial, esta vez en línea envolviendo directamente una lambda como tff.federated_computation . El uso de envolturas en línea con una función como argumento es la forma recomendada para usar tff.tf_computation a incrustar TensorFlow lógica en la TFF.

Ahora, veamos si nuestro entrenamiento funcionó.

print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[5]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[5]))
initial_model loss = 23.025854
locally_trained_model loss = 0.4348469

De hecho, la pérdida disminuyó. Pero, ¿qué pasa si lo evaluamos sobre los datos de otro usuario?

print('initial_model loss =', local_eval(initial_model,
                                         federated_train_data[0]))
print('locally_trained_model loss =',
      local_eval(locally_trained_model, federated_train_data[0]))
initial_model loss = 23.025854
locally_trained_model loss = 74.50075

Como era de esperar, las cosas empeoraron. El modelo fue entrenado para reconocer 5 , y nunca ha visto un 0 . Esto trae la pregunta: ¿cómo impactó la capacitación local en la calidad del modelo desde la perspectiva global?

Evaluación federada

Este es el punto de nuestro viaje en el que finalmente volvemos a los tipos federados y los cálculos federados, el tema con el que comenzamos. Aquí hay un par de definiciones de tipos de TFF para el modelo que se origina en el servidor y los datos que permanecen en los clientes.

SERVER_MODEL_TYPE = tff.type_at_server(MODEL_TYPE)
CLIENT_DATA_TYPE = tff.type_at_clients(LOCAL_DATA_TYPE)

Con todas las definiciones introducidas hasta ahora, expresar la evaluación federada en TFF es de una sola línea: distribuimos el modelo a los clientes, permitimos que cada cliente invoque la evaluación local en su porción local de datos y luego promediamos la pérdida. Aquí hay una forma de escribir esto.

@tff.federated_computation(SERVER_MODEL_TYPE, CLIENT_DATA_TYPE)
def federated_eval(model, data):
  return tff.federated_mean(
      tff.federated_map(local_eval, [tff.federated_broadcast(model), data]))

Ya hemos visto ejemplos de tff.federated_mean y tff.federated_map en escenarios más sencillos, y en el nivel intuitivo, que funciona como se espera, pero hay más en esta sección del código que ve el ojo, por lo que vamos a repasar con cuidado.

En primer lugar, la ruptura de la defraudado dejar que cada cliente invocación de evaluación local en su porción local de la parte de datos. Como se recordará de las secciones anteriores, local_eval tiene una firma tipo de forma (<MODEL_TYPE, LOCAL_DATA_TYPE> -> float32) .

El operador federados tff.federated_map es una plantilla que acepta como parámetro un 2-tupla que consiste en la función de mapeo de algún tipo T->U y un valor Federados de tipo {T}@CLIENTS (es decir, con los constituyentes miembros de la mismo tipo que el parámetro de la función de mapeo), y devuelve un resultado de tipo {U}@CLIENTS .

Como nos estamos alimentando local_eval como una función de mapeo para aplicar en función de cada cliente, el segundo argumento debe ser de un tipo federados {<MODEL_TYPE, LOCAL_DATA_TYPE>}@CLIENTS , es decir, en la nomenclatura de las secciones anteriores, lo que debería ser una tupla federada. Cada cliente debe mantener un conjunto completo de argumentos para local_eval como consituent miembro. En cambio, estamos alimentándolo a 2 elemento de Python list . ¿Que esta pasando aqui?

De hecho, este es un ejemplo de una conversión de tipo implícito en TFF, similar a los moldes de tipo implícitas que pueda haber encontrado en otro lugar, por ejemplo, cuando usted alimenta a un int a una función que acepta un float . El casting implícito se usa escasamente en este punto, pero planeamos hacerlo más omnipresente en TFF como una forma de minimizar la repetición.

La conversión implícita que se aplica en este caso es la equivalencia entre tuplas federados de la forma {<X,Y>}@Z , y tuplas de federados valores <{X}@Z,{Y}@Z> . Si bien formalmente, estos dos son diferentes firmas de tipos, mirándolo desde la perspectiva de los programadores, cada dispositivo de Z tiene dos unidades de datos de X e Y . Lo que sucede aquí no es diferente zip en Python, y de hecho, ofrecemos un operador tff.federated_zip que permite llevar a cabo tales conversiones explícitamente. Cuando el tff.federated_map se encuentra con una tupla como segundo argumento, simplemente invoca tff.federated_zip para usted.

Teniendo en cuenta lo anterior, ahora debería ser capaz de reconocer la expresión tff.federated_broadcast(model) como la representación de un valor de TFF Tipo {MODEL_TYPE}@CLIENTS , y data como un valor de tipo TFF {LOCAL_DATA_TYPE}@CLIENTS (o simplemente CLIENT_DATA_TYPE ) , los dos conseguir filtró juntos a través de un implícito tff.federated_zip para formar el segundo argumento para tff.federated_map .

El operador tff.federated_broadcast , como era de esperar, simplemente transfiere datos desde el servidor a los clientes.

Ahora, veamos cómo nuestro entrenamiento local afectó la pérdida promedio en el sistema.

print('initial_model loss =', federated_eval(initial_model,
                                             federated_train_data))
print('locally_trained_model loss =',
      federated_eval(locally_trained_model, federated_train_data))
initial_model loss = 23.025852
locally_trained_model loss = 54.432625

De hecho, como se esperaba, la pérdida ha aumentado. Para mejorar el modelo para todos los usuarios, necesitaremos entrenarnos con los datos de todos.

Entrenamiento federado

La forma más sencilla de implementar el entrenamiento federado es entrenar localmente y luego promediar los modelos. Esto usa los mismos bloques de construcción y patrones que ya hemos discutido, como puede ver a continuación.

SERVER_FLOAT_TYPE = tff.type_at_server(tf.float32)


@tff.federated_computation(SERVER_MODEL_TYPE, SERVER_FLOAT_TYPE,
                           CLIENT_DATA_TYPE)
def federated_train(model, learning_rate, data):
  return tff.federated_mean(
      tff.federated_map(local_train, [
          tff.federated_broadcast(model),
          tff.federated_broadcast(learning_rate), data
      ]))

Tenga en cuenta que en la aplicación con todas las funciones de Federados de promedio proporcionada por tff.learning , en lugar de la media de los modelos, preferimos deltas promedio de modelo, por varias razones, por ejemplo, la capacidad de cortar las normas de actualización, para la compresión, etc. .

Veamos si el entrenamiento funciona ejecutando algunas rondas de entrenamiento y comparando la pérdida promedio antes y después.

model = initial_model
learning_rate = 0.1
for round_num in range(5):
  model = federated_train(model, learning_rate, federated_train_data)
  learning_rate = learning_rate * 0.9
  loss = federated_eval(model, federated_train_data)
  print('round {}, loss={}'.format(round_num, loss))
round 0, loss=21.60552406311035
round 1, loss=20.365678787231445
round 2, loss=19.27480125427246
round 3, loss=18.31110954284668
round 4, loss=17.45725440979004

Para completar, ahora también ejecutemos los datos de prueba para confirmar que nuestro modelo se generaliza bien.

print('initial_model test loss =',
      federated_eval(initial_model, federated_test_data))
print('trained_model test loss =', federated_eval(model, federated_test_data))
initial_model test loss = 22.795593
trained_model test loss = 17.278767

Con esto concluye nuestro tutorial.

Por supuesto, nuestro ejemplo simplificado no refleja una serie de cosas que tendría que hacer en un escenario más realista; por ejemplo, no hemos calculado métricas que no sean las pérdidas. Le animamos a estudiar la aplicación de un promedio de federados en tff.learning como un ejemplo más completo, y como una manera de demostrar algunas de las prácticas de codificación que nos gustaría animar.