¡Reserva! Google I / O regresa del 18 al 20 de mayo Regístrese ahora
Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Creación de su propio algoritmo de aprendizaje federado

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Antes que empecemos

Antes de comenzar, ejecute lo siguiente para asegurarse de que su entorno esté configurado correctamente. Si no ve un saludo, consulte la Guía de instalación para obtener instrucciones.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

En los tutoriales de clasificación de imágenes y generación de texto , aprendimos cómo configurar modelos y canalizaciones de datos para el aprendizaje federado (FL) y tff.learning capacitación federada a través de la capa de API tff.learning de TFF.

Esta es solo la punta del iceberg en lo que respecta a la investigación de FL. En este tutorial, discutimos cómo implementar algoritmos de aprendizaje federado sin tff.learning a la API tff.learning . Nuestro objetivo es lograr lo siguiente:

Metas:

  • Comprender la estructura general de los algoritmos de aprendizaje federado.
  • Explore el núcleo federado de TFF.
  • Utilice el núcleo federado para implementar el promedio federado directamente.

Si bien este tutorial es autónomo, recomendamos leer primero los tutoriales de clasificación de imágenes y generación de texto .

Preparando los datos de entrada

Primero cargamos y preprocesamos el conjunto de datos EMNIST incluido en TFF. Para obtener más detalles, consulte el tutorial de clasificación de imágenes .

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Para alimentar el conjunto de datos en nuestro modelo, aplanamos los datos y convertimos cada ejemplo en una tupla del formulario (flattened_image_vector, label) .

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

Ahora tomamos muestras de una pequeña cantidad de clientes y aplicamos el procesamiento previo anterior a sus conjuntos de datos.

client_ids = np.random.choice(emnist_train.client_ids, size=NUM_CLIENTS, replace=False)

federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]

Preparando el modelo

Usamos el mismo modelo que en el tutorial de clasificación de imágenes . Este modelo (implementado a través de tf.keras ) tiene una sola capa oculta, seguida de una capa softmax.

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

Para utilizar este modelo en TFF, tff.learning.Model modelo de Keras como tff.learning.Model . Esto nos permite realizar el pase directo del modelo dentro de TFF y extraer los resultados del modelo . Para obtener más detalles, consulte también el tutorial de clasificación de imágenes .

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Si bien usamos tf.keras para crear un tff.learning.Model , TFF admite modelos mucho más generales. Estos modelos tienen los siguientes atributos relevantes que capturan los pesos del modelo:

  • trainable_variables : un iterable de los tensores correspondientes a las capas entrenables.
  • non_trainable_variables : un iterable de los tensores correspondientes a capas no entrenables.

Para nuestros propósitos, solo usaremos las trainable_variables . (¡ya que nuestro modelo solo tiene esos!).

Construyendo su propio algoritmo de aprendizaje federado

Si bien la API tff.learning permite crear muchas variantes de Promedio federado, existen otros algoritmos federados que no encajan perfectamente en este marco. Por ejemplo, es posible que desee agregar regularización, recorte o algoritmos más complicados, como el entrenamiento de GAN federado . En su lugar, también puede estar interesado en la analítica federada .

Para estos algoritmos más avanzados, tendremos que escribir nuestro propio algoritmo personalizado usando TFF. En muchos casos, los algoritmos federados tienen 4 componentes principales:

  1. Un paso de transmisión de servidor a cliente.
  2. Un paso de actualización del cliente local.
  3. Un paso de carga de cliente a servidor.
  4. Un paso de actualización del servidor.

En TFF, generalmente representamos los algoritmos federados como un tff.templates.IterativeProcess (al que nos referimos como un IterativeProcess todo momento). Esta es una clase que contiene las funciones initialize y next . Aquí, initialize se usa para inicializar el servidor, y next realizará una ronda de comunicación del algoritmo federado. Escribamos un esqueleto de cómo debería ser nuestro proceso iterativo para FedAvg.

Primero, tenemos una función de inicialización que simplemente crea un tff.learning.Model y devuelve sus pesos entrenables.

def initialize_fn():
  model = model_fn()
  return model.trainable_variables

Esta función se ve bien, pero como veremos más adelante, necesitaremos hacer una pequeña modificación para convertirla en un "cálculo TFF".

También queremos esbozar el next_fn .

def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights

Nos centraremos en implementar estos cuatro componentes por separado. Primero nos enfocamos en las partes que se pueden implementar en TensorFlow puro, es decir, los pasos de actualización del cliente y del servidor.

Bloques de TensorFlow

Actualización del cliente

Usaremos nuestro tff.learning.Model para realizar el entrenamiento del cliente de la misma manera que entrenarías un modelo de TensorFlow. En particular, usaremostf.GradientTape para calcular el gradiente en lotes de datos, luego aplicaremos estos gradientes usando un client_optimizer . Nos enfocamos solo en los pesos entrenables.

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights

Actualización del servidor

La actualización del servidor para FedAvg es más simple que la actualización del cliente. Implementaremos un promedio federado "vanilla", en el que simplemente reemplazamos los pesos del modelo de servidor por el promedio de los pesos del modelo de cliente. Nuevamente, solo nos enfocamos en los pesos entrenables.

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.trainable_variables
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights

El fragmento podría simplificarse simplemente devolviendo mean_client_weights . Sin embargo, las implementaciones más avanzadas de Promedio federado utilizan mean_client_weights con técnicas más sofisticadas, como el impulso o la adaptabilidad.

Desafío : implementar una versión de server_update que actualice los pesos del servidor para que sea el punto medio de model_weights y mean_client_weights. (Nota: ¡Este tipo de enfoque de "punto medio" es análogo al trabajo reciente en el optimizador Lookahead !).

Hasta ahora, solo hemos escrito código puro de TensorFlow. Esto es por diseño, ya que TFF le permite usar gran parte del código de TensorFlow con el que ya está familiarizado. Sin embargo, ahora tenemos que especificar la lógica de orquestación , es decir, la lógica que dicta lo que el servidor transmite al cliente y lo que el cliente carga al servidor.

Esto requerirá el núcleo federado de TFF.

Introducción al núcleo federado

Federated Core (FC) es un conjunto de interfaces de nivel inferior que sirven como base para la API tff.learning . Sin embargo, estas interfaces no se limitan al aprendizaje. De hecho, se pueden utilizar para análisis y muchos otros cálculos sobre datos distribuidos.

En un nivel alto, el núcleo federado es un entorno de desarrollo que permite que la lógica del programa expresada de manera compacta combine el código de TensorFlow con operadores de comunicaciones distribuidas (como sumas distribuidas y transmisiones). El objetivo es dar a los investigadores y profesionales un control explícito sobre la comunicación distribuida en sus sistemas, sin requerir detalles de implementación del sistema (como especificar intercambios de mensajes de red punto a punto).

Un punto clave es que TFF está diseñado para preservar la privacidad. Por lo tanto, permite un control explícito sobre dónde residen los datos, para evitar la acumulación no deseada de datos en la ubicación del servidor centralizado.

Datos federados

Un concepto clave en TFF es "datos federados", que se refiere a una colección de elementos de datos alojados en un grupo de dispositivos en un sistema distribuido (por ejemplo, conjuntos de datos de clientes o pesos del modelo de servidor). Modelamos toda la colección de elementos de datos en todos los dispositivos como un único valor federado .

Por ejemplo, supongamos que tenemos dispositivos cliente que tienen cada uno un flotador que representa la temperatura de un sensor. Podríamos representarlo como un flotador federado por

federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)

Los tipos federados se especifican mediante un tipo T de sus componentes miembros (por ejemplo, tf.float32 ) y un grupo G de dispositivos. Nos centraremos en los casos en los que G es tff.CLIENTS o tff.SERVER . Dicho tipo federado se representa como {T}@G , como se muestra a continuación.

str(federated_float_on_clients)
'{float32}@CLIENTS'

¿Por qué nos preocupan tanto las ubicaciones? Un objetivo clave de TFF es permitir la escritura de código que podría implementarse en un sistema distribuido real. Esto significa que es vital razonar sobre qué subconjuntos de dispositivos ejecutan qué código y dónde residen los diferentes datos.

TFF se centra en tres cosas: datos , dónde se colocan los datos y cómo se transforman los datos. Los dos primeros están encapsulados en tipos federados, mientras que el último está encapsulado en cálculos federados .

Cálculos federados

TFF es un entorno de programación funcional fuertemente tipado cuyas unidades básicas son cálculos federados . Estas son piezas de lógica que aceptan valores federados como entrada y devuelven valores federados como salida.

Por ejemplo, supongamos que quisiéramos promediar las temperaturas en los sensores de nuestros clientes. Podríamos definir lo siguiente (usando nuestro flotador federado):

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

Podría preguntar, ¿en qué se diferencia del decorador tf.function en TensorFlow? La respuesta clave es que el código generado por tff.federated_computation no es código TensorFlow ni Python; Es una especificación de un sistema distribuido en un lenguaje adhesivo interno independiente de la plataforma.

Si bien esto puede parecer complicado, puede pensar en los cálculos de TFF como funciones con firmas de tipo bien definidas. Estas firmas de tipo se pueden consultar directamente.

str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'

Este tff.federated_computation acepta argumentos de tipo federado {float32}@CLIENTS , y devuelve valores de tipo federado {float32}@SERVER . Los cálculos federados también pueden ir de servidor a cliente, de cliente a cliente o de servidor a servidor. Los cálculos federados también se pueden componer como funciones normales, siempre que sus firmas de tipo coincidan.

Para respaldar el desarrollo, TFF le permite invocar un tff.federated_computation como una función de Python. Por ejemplo, podemos llamar

get_average_temperature([68.5, 70.3, 69.8])
69.53334

Cálculos no ansiosos y TensorFlow

Hay dos restricciones clave a tener en cuenta. Primero, cuando el intérprete de Python encuentra un decorador tff.federated_computation , la función se rastrea una vez y se serializa para uso futuro. Debido a la naturaleza descentralizada del aprendizaje federado, este uso futuro puede ocurrir en otros lugares, como un entorno de ejecución remota. Por lo tanto, los cálculos de TFF son fundamentalmente no ansiosos . Este comportamiento es algo análogo al del decorador tf.function en TensorFlow.

En segundo lugar, un cálculo federado solo puede constar de operadores federados (como tff.federated_mean ), no pueden contener operaciones de TensorFlow. El código de TensorFlow debe limitarse a bloques decorados con tff.tf_computation . La mayoría del código de TensorFlow común se puede decorar directamente, como la siguiente función que toma un número y le agrega 0.5 .

@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

Estos también tienen firmas tipográficas, pero sin ubicaciones . Por ejemplo, podemos llamar

str(add_half.type_signature)
'(float32 -> float32)'

Aquí vemos una diferencia importante entre tff.federated_computation y tff.tf_computation . El primero tiene colocaciones explícitas, mientras que el segundo no.

Podemos usar bloques tff.tf_computation en cálculos federados especificando ubicaciones. Creemos una función que agregue la mitad, pero solo a flotantes federados en los clientes. Podemos hacer esto usando tff.federated_map , que aplica una determinada tff.tf_computation , mientras conserva la ubicación.

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

Esta función es casi idéntica a add_half , excepto que solo acepta valores con ubicación en tff.CLIENTS y devuelve valores con la misma ubicación. Podemos ver esto en su firma de tipo:

str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'

En resumen:

  • TFF opera con valores federados.
  • Cada valor federado tiene un tipo federado , con un tipo (por ejemplo, tf.float32 ) y una ubicación (por ejemplo, tff.CLIENTS ).
  • Los valores federados se pueden transformar mediante cálculos federados , que deben decorarse con tff.federated_computation y una firma de tipo federado.
  • El código de TensorFlow debe estar contenido en bloques con decoradores tff.tf_computation .
  • Luego, estos bloques se pueden incorporar en cálculos federados.

Construyendo su propio algoritmo de aprendizaje federado, revisado

Ahora que hemos echado un vistazo al núcleo federado, podemos crear nuestro propio algoritmo de aprendizaje federado. Recuerde que arriba, definimos un initialize_fn y next_fn para nuestro algoritmo. next_fn hará uso de client_update y server_update que definimos usando código puro de TensorFlow.

Sin embargo, para que nuestro algoritmo sea un cálculo federado, necesitaremos tanto next_fn como initialize_fn para que cada uno sea un tff.federated_computation .

Bloques federados de TensorFlow

Creando el cálculo de inicialización

La función de inicialización será bastante simple: crearemos un modelo usando model_fn . Sin embargo, recuerde que debemos separar nuestro código de TensorFlow usando tff.tf_computation .

@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables

Luego podemos pasar esto directamente a un cálculo federado usando tff.federated_value .

@tff.federated_computation
def initialize_fn():
  return tff.federated_value(server_init(), tff.SERVER)

Creando el next_fn

Ahora usamos nuestro código de actualización de cliente y servidor para escribir el algoritmo real. Primero, convertiremos nuestro client_update en un tff.tf_computation que acepta conjuntos de datos de cliente y pesos de servidor, y genera un tensor de pesos de cliente actualizado.

Necesitaremos los tipos correspondientes para decorar adecuadamente nuestra función. Afortunadamente, el tipo de ponderaciones del servidor se puede extraer directamente de nuestro modelo.

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

Veamos la firma del tipo del conjunto de datos. Recuerde que tomamos imágenes de 28 por 28 (con etiquetas enteras) y las aplanamos.

str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'

También podemos extraer el tipo de ponderaciones del modelo utilizando nuestra función server_init anterior.

model_weights_type = server_init.type_signature.result

¡Examinando la firma de tipo, podremos ver la arquitectura de nuestro modelo!

str(model_weights_type)
'<float32[784,10],float32[10]>'

Ahora podemos crear nuestro tff.tf_computation para la actualización del cliente.

@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

La versión tff.tf_computation de la actualización del servidor se puede definir de manera similar, utilizando los tipos que ya extrajimos.

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

Por último, pero no menos importante, debemos crear el tff.federated_computation que reúne todo esto. Esta función aceptará dos valores federados , uno correspondiente a los pesos del servidor (con la ubicación tff.SERVER ) y el otro correspondiente a los conjuntos de datos del cliente (con la ubicación tff.CLIENTS ).

Tenga en cuenta que ambos tipos se definieron anteriormente. Simplemente necesitamos darles la ubicación adecuada usando tff.FederatedType .

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

¿Recuerda los 4 elementos de un algoritmo FL?

  1. Un paso de transmisión de servidor a cliente.
  2. Un paso de actualización del cliente local.
  3. Un paso de carga de cliente a servidor.
  4. Un paso de actualización del servidor.

Ahora que hemos creado lo anterior, cada parte se puede representar de forma compacta como una sola línea de código TFF. Esta simplicidad es la razón por la que tuvimos que tener un cuidado especial para especificar cosas como tipos federados.

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

Ahora tenemos un tff.federated_computation tanto para la inicialización del algoritmo como para ejecutar un paso del algoritmo. Para finalizar nuestro algoritmo, los pasamos a tff.templates.IterativeProcess .

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

Veamos la firma de tipo de las funciones initialize y next de nuestro proceso iterativo.

str(federated_algorithm.initialize.type_signature)
'( -> <float32[784,10],float32[10]>@SERVER)'

Esto refleja el hecho de que federated_algorithm.initialize es una función sin argumentos que devuelve un modelo de una sola capa (con una matriz de ponderación de 784 por 10 y 10 unidades de sesgo).

str(federated_algorithm.next.type_signature)
'(<<float32[784,10],float32[10]>@SERVER,{<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'

Aquí, vemos que federated_algorithm.next acepta un modelo de servidor y datos de cliente, y devuelve un modelo de servidor actualizado.

Evaluando el algoritmo

Hagamos algunas rondas y veamos cómo cambia la pérdida. Primero, definiremos una función de evaluación usando el enfoque centralizado discutido en el segundo tutorial.

Primero creamos un conjunto de datos de evaluación centralizado y luego aplicamos el mismo procesamiento previo que usamos para los datos de entrenamiento.

Tenga en cuenta que solo take los primeros 1000 elementos por razones de eficiencia computacional, pero normalmente usamos el conjunto de datos de prueba completo.

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients().take(1000)
central_emnist_test = preprocess(central_emnist_test)

A continuación, escribimos una función que acepta un estado de servidor y usa Keras para evaluar el conjunto de datos de prueba. Si está familiarizado con tf.Keras , todo esto le resultará familiar, ¡aunque tenga en cuenta el uso de set_weights !

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)

Ahora, inicialicemos nuestro algoritmo y evaluemos en el conjunto de prueba.

server_state = federated_algorithm.initialize()
evaluate(server_state)
50/50 [==============================] - 0s 2ms/step - loss: 2.3026 - sparse_categorical_accuracy: 0.0910

Entrenemos durante algunas rondas y veamos si algo cambia.

for round in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
50/50 [==============================] - 0s 1ms/step - loss: 2.1706 - sparse_categorical_accuracy: 0.2440

Vemos una ligera disminución en la función de pérdida. Si bien el salto es pequeño, solo hemos realizado 10 rondas de entrenamiento y en un pequeño subconjunto de clientes. Para ver mejores resultados, es posible que tengamos que hacer cientos, si no miles, de rondas.

Modificando nuestro algoritmo

En este punto, detengámonos y pensemos en lo que hemos logrado. Implementamos el promedio federado directamente mediante la combinación de código puro de TensorFlow (para las actualizaciones del cliente y del servidor) con cálculos federados del núcleo federado de TFF.

Para realizar un aprendizaje más sofisticado, simplemente podemos alterar lo que tenemos arriba. En particular, al editar el código TF puro anterior, podemos cambiar la forma en que el cliente realiza el entrenamiento o cómo el servidor actualiza su modelo.

Desafío: agregue recorte de degradado a la función client_update .

Si quisiéramos hacer cambios más importantes, también podríamos hacer que el servidor almacene y transmita más datos. Por ejemplo, el servidor también podría almacenar la tasa de aprendizaje del cliente y hacer que decaiga con el tiempo. Tenga en cuenta que esto requerirá cambios en las firmas de tipo utilizadas en las llamadas tff.tf_computation anteriores.

Desafío más difícil: implementar promedios federados con disminución de la tasa de aprendizaje en los clientes.

En este punto, puede comenzar a darse cuenta de cuánta flexibilidad hay en lo que puede implementar en este marco. Para obtener ideas (incluida la respuesta al desafío más difícil anterior), puede ver el código fuente de tff.learning.build_federated_averaging_process , o consultar varios proyectos de investigación utilizando TFF.