Se usó la API de Cloud Translation para traducir esta página.
Switch to English

Aprendizaje federado para clasificación de imágenes

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub

En este tutorial, usamos el ejemplo de entrenamiento clásico de MNIST para presentar la capa de API de aprendizaje federado (FL) de TFF, tff.learning , un conjunto de interfaces de nivel superior que se pueden usar para realizar tipos comunes de tareas de aprendizaje federado, como entrenamiento federado, contra modelos proporcionados por el usuario implementados en TensorFlow.

Este tutorial, y la API de aprendizaje federado, están destinados principalmente a usuarios que desean conectar sus propios modelos de TensorFlow a TFF, tratando este último principalmente como una caja negra. Para una comprensión más profunda de TFF y cómo implementar sus propios algoritmos de aprendizaje federado, consulte los tutoriales sobre FC Core API - Algoritmos federados personalizados Parte 1 y Parte 2 .

Para obtener más información sobre tff.learning , continúe con el aprendizaje federado para la generación de texto , tutorial que además de cubrir los modelos recurrentes, también demuestra la carga de un modelo de Keras serializado previamente entrenado para refinamiento con aprendizaje federado combinado con evaluación usando Keras.

Antes que empecemos

Antes de comenzar, ejecute lo siguiente para asegurarse de que su entorno esté configurado correctamente. Si no ve un saludo, consulte la Guía de instalación para obtener instrucciones.


!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio

import nest_asyncio
nest_asyncio.apply()

%load_ext tensorboard
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

Preparando los datos de entrada

Empecemos por los datos. El aprendizaje federado requiere un conjunto de datos federados, es decir, una colección de datos de múltiples usuarios. Los datos federados generalmente no son iid , lo que plantea un conjunto único de desafíos.

Para facilitar la experimentación, sembramos el repositorio de TFF con algunos conjuntos de datos, incluida una versión federada de MNIST que contiene una versión del conjunto de datos NIST original que ha sido reprocesado con Leaf para que los datos estén codificados por el escritor original de los dígitos. Dado que cada escritor tiene un estilo único, este conjunto de datos exhibe el tipo de comportamiento no iid que se espera de los conjuntos de datos federados.

Así es como podemos cargarlo.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Los conjuntos de datos devueltos por load_data() son instancias de tff.simulation.ClientData , una interfaz que le permite enumerar el conjunto de usuarios, construir un tf.data.Dataset que representa los datos de un usuario en particular y consultar el estructura de elementos individuales. A continuación, le mostramos cómo puede utilizar esta interfaz para explorar el contenido del conjunto de datos. Tenga en cuenta que, si bien esta interfaz le permite iterar sobre los identificadores de clientes, esta es solo una característica de los datos de simulación. Como verá en breve, el marco de aprendizaje federado no utiliza las identidades de los clientes; su único propósito es permitirle seleccionar subconjuntos de datos para simulaciones.

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None)), ('label', TensorSpec(shape=(), dtype=tf.int32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

Explorando la heterogeneidad en datos federados

Los datos federados generalmente no son iid , los usuarios generalmente tienen diferentes distribuciones de datos según los patrones de uso. Algunos clientes pueden tener menos ejemplos de entrenamiento en el dispositivo, debido a la escasez de datos a nivel local, mientras que algunos clientes tendrán ejemplos de entrenamiento más que suficientes. Exploremos este concepto de heterogeneidad de datos típico de un sistema federado con los datos EMNIST que tenemos disponibles. Es importante tener en cuenta que este análisis profundo de los datos de un cliente solo está disponible para nosotros porque este es un entorno de simulación donde todos los datos están disponibles localmente. En un entorno federado de producción real, no podría inspeccionar los datos de un solo cliente.

Primero, tomemos una muestra de los datos de un cliente para tener una idea de los ejemplos en un dispositivo simulado. Debido a que el conjunto de datos que estamos usando ha sido codificado por un escritor único, los datos de un cliente representan la escritura a mano de una persona para una muestra de los dígitos del 0 al 9, simulando el "patrón de uso" único de un usuario.

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

Ahora visualicemos la cantidad de ejemplos en cada cliente para cada etiqueta de dígito MNIST. En el entorno federado, la cantidad de ejemplos en cada cliente puede variar bastante, dependiendo del comportamiento del usuario.

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

Ahora visualicemos la imagen media por cliente para cada etiqueta MNIST. Este código producirá la media de cada valor de píxel para todos los ejemplos del usuario para una etiqueta. Veremos que la imagen media de un cliente para un dígito se verá diferente a la imagen media de otro cliente para el mismo dígito, debido al estilo de escritura único de cada persona. Podemos reflexionar sobre cómo cada ronda de capacitación local empujará el modelo en una dirección diferente en cada cliente, ya que estamos aprendiendo de los datos únicos de ese usuario en esa ronda local. Más adelante en el tutorial veremos cómo podemos tomar cada actualización del modelo de todos los clientes y agregarlas en nuestro nuevo modelo global, que ha aprendido de los datos únicos de cada uno de nuestros clientes.

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

Los datos del usuario pueden ser ruidosos y estar etiquetados de forma poco fiable. Por ejemplo, mirando los datos del Cliente # 2 arriba, podemos ver que para la etiqueta 2, es posible que haya algunos ejemplos mal etiquetados creando una imagen más ruidosa.

Preprocesar los datos de entrada

Dado que los datos ya son un tf.data.Dataset , el preprocesamiento se puede realizar mediante transformaciones de Dataset. Aquí, 28x28 imágenes de 28x28 en arreglos de 784 elementos, mezclamos los ejemplos individuales, los organizamos en lotes y cambiamos el nombre de las características de pixels y label x e y para usar con Keras. También lanzamos una repeat sobre el conjunto de datos para ejecutar varias épocas.

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER= 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

Verifiquemos que esto funcionó.

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [2],
       [3],
       [6],
       [0],
       [1],
       [4],
       [1],
       [0],
       [6],
       [9],
       [9],
       [3],
       [6],
       [1],
       [4],
       [8],
       [0],
       [2]], dtype=int32))])

Tenemos casi todos los componentes básicos para construir conjuntos de datos federados.

Una de las formas de alimentar datos federados a TFF en una simulación es simplemente como una lista de Python, con cada elemento de la lista que contiene los datos de un usuario individual, ya sea como una lista o como un tf.data.Dataset . Como ya tenemos una interfaz que proporciona lo último, usémosla.

Aquí hay una función auxiliar simple que construirá una lista de conjuntos de datos a partir del conjunto dado de usuarios como entrada para una ronda de capacitación o evaluación.

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

Ahora bien, ¿cómo elegimos a los clientes?

En un escenario típico de capacitación federado, estamos tratando con una población potencialmente muy grande de dispositivos de usuario, de los cuales solo una fracción puede estar disponible para capacitación en un momento dado. Este es el caso, por ejemplo, cuando los dispositivos del cliente son teléfonos móviles que participan en la capacitación solo cuando están conectados a una fuente de alimentación, fuera de una red con medidor y de otra manera inactivos.

Por supuesto, estamos en un entorno de simulación y todos los datos están disponibles localmente. Por lo general, entonces, al ejecutar simulaciones, simplemente tomaríamos muestras de un subconjunto aleatorio de los clientes que participarán en cada ronda de entrenamiento, generalmente diferentes en cada ronda.

Dicho esto, como puede averiguar al estudiar el artículo sobre el algoritmo de promediado federado , lograr la convergencia en un sistema con subconjuntos de clientes muestreados al azar en cada ronda puede llevar un tiempo, y no sería práctico tener que ejecutar cientos de rondas en este tutorial interactivo.

En su lugar, lo que haremos es tomar una muestra del conjunto de clientes una vez y reutilizar el mismo conjunto en todas las rondas para acelerar la convergencia (sobreajuste intencionalmente a los datos de estos pocos usuarios). Lo dejamos como ejercicio para que el lector modifique este tutorial para simular un muestreo aleatorio; es bastante fácil de hacer (una vez que lo haga, tenga en cuenta que lograr que el modelo converja puede llevar un tiempo).

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

Creando un modelo con Keras

Si está utilizando Keras, es probable que ya tenga un código que construya un modelo de Keras. He aquí un ejemplo de un modelo simple que será suficiente para nuestras necesidades.

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

Para usar cualquier modelo con TFF, debe estar envuelto en una instancia de la interfaz tff.learning.Model , que expone métodos para sellar el pase directo del modelo, propiedades de metadatos, etc., de manera similar a Keras, pero también introduce métodos adicionales. elementos, como las formas de controlar el proceso de cálculo de métricas federadas. No nos preocupemos por esto por ahora; Si tiene un modelo de Keras como el que acabamos de definir anteriormente, puede hacer que TFF lo envuelva invocando tff.learning.from_keras_model , pasando el modelo y un lote de datos de muestra como argumentos, como se muestra a continuación.

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Entrenamiento del modelo en datos federados

Ahora que tenemos un modelo envuelto como tff.learning.Model para usar con TFF, podemos permitir que TFF construya un algoritmo de promediado federado invocando la función auxiliar tff.learning.build_federated_averaging_process , como sigue.

Tenga en cuenta que el argumento debe ser un constructor (como model_fn arriba), no una instancia ya construida, de modo que la construcción de su modelo pueda ocurrir en un contexto controlado por TFF (si tiene curiosidad sobre las razones de esto, le recomendamos que lea el tutorial de seguimiento sobre algoritmos personalizados ).

Una nota crítica sobre el algoritmo de promediado federado a continuación, hay 2 optimizadores: un _client optimizer y un _server optimizer . El optimizador de _client solo se utiliza para calcular las actualizaciones del modelo local en cada cliente. El optimizador de _server aplica la actualización promediada al modelo global en el servidor. En particular, esto significa que la elección del optimizador y la tasa de aprendizaje utilizados pueden tener que ser diferentes a las que ha utilizado para entrenar el modelo en un conjunto de datos iid estándar. Recomendamos comenzar con SGD regular, posiblemente con una tasa de aprendizaje menor de lo habitual. La tasa de aprendizaje que usamos no se ha ajustado cuidadosamente, no dude en experimentar.

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

¿Lo que acaba de suceder? TFF ha construido un par de cálculos federados y los ha empaquetado en un tff.templates.IterativeProcess en el que estos cálculos están disponibles como un par de propiedades initialize y next .

En pocas palabras, los cálculos federados son programas en el lenguaje interno de TFF que pueden expresar varios algoritmos federados (puede encontrar más sobre esto en el tutorial de algoritmos personalizados ). En este caso, los dos cálculos generados y empaquetados en iterative_process implementan el promedio federado .

El objetivo de TFF es definir los cálculos de manera que puedan ejecutarse en entornos reales de aprendizaje federado, pero actualmente solo se implementa el tiempo de ejecución de simulación de ejecución local. Para ejecutar un cálculo en un simulador, simplemente lo invoca como una función de Python. Este entorno interpretado por defecto no está diseñado para un alto rendimiento, pero será suficiente para este tutorial; Esperamos proporcionar tiempos de ejecución de simulación de mayor rendimiento para facilitar la investigación a mayor escala en versiones futuras.

Comencemos con el cálculo de initialize . Como es el caso de todos los cálculos federados, puede considerarlo como una función. El cálculo no toma argumentos y devuelve un resultado: la representación del estado del proceso de Promedio federado en el servidor. Si bien no queremos profundizar en los detalles de TFF, puede ser instructivo ver cómo se ve este estado. Puedes visualizarlo de la siguiente manera.

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<>,model_broadcast_state=<>>@SERVER)'

Si bien la firma de tipo anterior puede parecer un poco críptica al principio, puede reconocer que el estado del servidor consiste en un model (los parámetros del modelo inicial para MNIST que se distribuirán a todos los dispositivos) y optimizer_state (información adicional mantenida por el servidor, como el número de rondas que se utilizarán para los programas de hiperparámetros, etc.).

Invoquemos el cálculo de initialize para construir el estado del servidor.

state = iterative_process.initialize()

El segundo del par de cálculos federados, a next , representa una ronda única de Promedio federado, que consiste en enviar el estado del servidor (incluidos los parámetros del modelo) a los clientes, entrenamiento en el dispositivo en sus datos locales, recopilación y promediado de actualizaciones del modelo. y producir un nuevo modelo actualizado en el servidor.

Conceptualmente, puede pensar en el next como si tuviera una firma de tipo funcional con el siguiente aspecto.

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

En particular, uno debería pensar en next() no como una función que se ejecuta en un servidor, sino más bien como una representación funcional declarativa de todo el cálculo descentralizado: algunas de las entradas son proporcionadas por el servidor ( SERVER_STATE ), pero cada participante dispositivo aporta su propio conjunto de datos local.

Realicemos una sola ronda de entrenamiento y visualicemos los resultados. Podemos utilizar los datos federados que ya hemos generado anteriormente para una muestra de usuarios.

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.12037037312984467,loss=3.0108425617218018>>

Hagamos algunas rondas más. Como se mencionó anteriormente, normalmente en este punto usted escogería un subconjunto de sus datos de simulación de una nueva muestra de usuarios seleccionada al azar para cada ronda con el fin de simular una implementación realista en la que los usuarios van y vienen continuamente, pero en este cuaderno interactivo, por En aras de la demostración, simplemente reutilizaremos a los mismos usuarios, de modo que el sistema converja rápidamente.

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.14814814925193787,loss=2.8865506649017334>>
round  3, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.148765429854393,loss=2.9079062938690186>>
round  4, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.17633745074272156,loss=2.724686622619629>>
round  5, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.20226337015628815,loss=2.6334855556488037>>
round  6, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.22427983582019806,loss=2.5482592582702637>>
round  7, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.24094650149345398,loss=2.4472343921661377>>
round  8, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.259876549243927,loss=2.3809611797332764>>
round  9, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.29814815521240234,loss=2.156442403793335>>
round 10, metrics=<broadcast=<>,aggregation=<>,train=<sparse_categorical_accuracy=0.31687241792678833,loss=2.122845411300659>>

La pérdida de entrenamiento está disminuyendo después de cada ronda de entrenamiento federado, lo que indica que el modelo está convergiendo. Hay algunas advertencias importantes con estas métricas de capacitación; sin embargo, consulte la sección sobre Evaluación más adelante en este tutorial.

Visualización de métricas de modelo en TensorBoard

A continuación, visualicemos las métricas de estos cálculos federados usando Tensorboard.

Comencemos por crear el directorio y el escritor de resumen correspondiente para escribir las métricas.


logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

Trace las métricas escalares relevantes con el mismo escritor de resumen.


with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics.train._asdict().items():
      tf.summary.scalar(name, value, step=round_num)

Inicie TensorBoard con el directorio de registro raíz especificado anteriormente. Los datos pueden tardar unos segundos en cargarse.


%tensorboard --logdir /tmp/logs/scalars/ --port=0

# Run this this cell to clean your directory of old output for future graphs from this directory.
!rm -R /tmp/logs/scalars/*

Para ver las métricas de evaluación de la misma manera, puede crear una carpeta de evaluación separada, como "logs / scalars / eval", para escribir en TensorBoard.

Personalizar la implementación del modelo

Keras es la API de modelo de alto nivel recomendada para TensorFlow , y recomendamos el uso de modelos de Keras (a través de tff.learning.from_keras_model ) en TFF siempre que sea posible.

Sin embargo, tff.learning proporciona una interfaz de modelo de nivel inferior, tff.learning.Model , que expone la funcionalidad mínima necesaria para usar un modelo de aprendizaje federado. La implementación directa de esta interfaz (posiblemente todavía usando bloques de construcción como tf.keras.layers ) permite la máxima personalización sin modificar los tf.keras.layers internos de los algoritmos de aprendizaje federados.

Así que hagámoslo todo de nuevo desde cero.

Definición de variables de modelo, pase directo y métricas

El primer paso es identificar las variables de TensorFlow con las que vamos a trabajar. Para que el siguiente código sea más legible, definamos una estructura de datos para representar el conjunto completo. Esto incluirá variables como weights y bias que entrenaremos, así como variables que contendrán varias estadísticas acumulativas y contadores que actualizaremos durante el entrenamiento, como loss_sum , accuracy_sum y num_examples .

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

Aquí hay un método que crea las variables. En aras de la simplicidad, representamos todas las estadísticas como tf.float32 , ya que eso eliminará la necesidad de conversiones de tipos en una etapa posterior. Envolver los inicializadores de variables como lambdas es un requisito impuesto por las variables de recursos .

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

Con las variables para los parámetros del modelo y las estadísticas acumuladas en su lugar, ahora podemos definir el método de paso hacia adelante que calcula la pérdida, emite predicciones y actualiza las estadísticas acumuladas para un solo lote de datos de entrada, como se indica a continuación.

def mnist_forward_pass(variables, batch):
  y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

A continuación, definimos una función que devuelve un conjunto de métricas locales, nuevamente usando TensorFlow. Estos son los valores (además de las actualizaciones del modelo, que se manejan automáticamente) que son elegibles para agregarse al servidor en un proceso de evaluación o aprendizaje federado.

Aquí, simplemente devolvemos la loss y la accuracy promedio, así como los num_examples , que necesitaremos para ponderar correctamente las contribuciones de diferentes usuarios al calcular agregados federados.

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

Finalmente, necesitamos determinar cómo agregar las métricas locales emitidas por cada dispositivo a través de get_local_mnist_metrics . Esta es la única parte del código que no está escrita en TensorFlow; es un cálculo federado expresado en TFF. Si desea profundizar más, lea el tutorial de algoritmos personalizados , pero en la mayoría de las aplicaciones, realmente no es necesario; las variantes del patrón que se muestra a continuación deberían ser suficientes. Así es como se ve:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))
  

El argumento de metrics entrada corresponde al OrderedDict devuelto por get_local_mnist_metrics arriba, pero críticamente los valores ya no son tf.Tensors : están " tff.Value " como tff.Value s, para dejar en claro que ya no puede manipularlos con TensorFlow, pero solo utilizando los operadores federados de TFF como tff.federated_mean y tff.federated_sum . El diccionario devuelto de agregados globales define el conjunto de métricas que estarán disponibles en el servidor.

Construyendo una instancia de tff.learning.Model

Con todo lo anterior en su lugar, estamos listos para construir una representación de modelo para usar con TFF similar a la que se genera para usted cuando deja que TFF ingiera un modelo de Keras.

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

Como puede ver, los métodos abstractos y las propiedades definidas por tff.learning.Model corresponden a los fragmentos de código en la sección anterior que introdujeron las variables y definieron la pérdida y las estadísticas.

Aquí hay algunos puntos que vale la pena destacar:

  • Todos los estados que usará su modelo deben capturarse como variables de TensorFlow, ya que TFF no usa Python en tiempo de ejecución (recuerde que su código debe escribirse de manera que pueda implementarse en dispositivos móviles; consulte el tutorial de algoritmos personalizados para obtener información más detallada comentario sobre las razones).
  • Su modelo debe describir qué forma de datos acepta ( input_spec ), ya que, en general, TFF es un entorno fuertemente tipado y quiere determinar las firmas de tipo para todos los componentes. Declarar el formato de la entrada de su modelo es una parte esencial.
  • Aunque técnicamente no es necesario, recomendamos encapsular toda la lógica de TensorFlow (pase directo, cálculos de métricas, etc.) como tf.function s, ya que esto ayuda a garantizar que TensorFlow se pueda serializar y elimina la necesidad de dependencias de control explícitas.

Lo anterior es suficiente para evaluación y algoritmos como Federated SGD. Sin embargo, para el promedio federado, debemos especificar cómo se debe entrenar el modelo localmente en cada lote. Especificaremos un optimizador local al crear el algoritmo de promediado federado.

Simulando entrenamiento federado con el nuevo modelo

Con todo lo anterior en su lugar, el resto del proceso se parece a lo que ya hemos visto: simplemente reemplace el constructor del modelo con el constructor de nuestra nueva clase de modelo y use los dos cálculos federados en el proceso iterativo que creó para recorrer el ciclo. rondas de entrenamiento.

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.9713594913482666,accuracy=0.13518518209457397>>

for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.975412607192993,accuracy=0.14032921195030212>>
round  3, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.9395227432250977,accuracy=0.1594650149345398>>
round  4, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.710164785385132,accuracy=0.17139917612075806>>
round  5, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.5891618728637695,accuracy=0.20267489552497864>>
round  6, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.5148487091064453,accuracy=0.21666666865348816>>
round  7, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.2816808223724365,accuracy=0.2580246925354004>>
round  8, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.3656885623931885,accuracy=0.25884774327278137>>
round  9, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=2.23549222946167,accuracy=0.28477364778518677>>
round 10, metrics=<broadcast=<>,aggregation=<>,train=<num_examples=4860.0,loss=1.974222183227539,accuracy=0.35329216718673706>>

Para ver estas métricas dentro de TensorBoard, consulte los pasos enumerados anteriormente en "Visualización de métricas del modelo en TensorBoard".

Evaluación

Todos nuestros experimentos hasta ahora presentaron solo métricas de entrenamiento federadas: las métricas promedio de todos los lotes de datos entrenados en todos los clientes de la ronda. Esto introduce las preocupaciones normales sobre el sobreajuste, especialmente porque usamos el mismo conjunto de clientes en cada ronda para simplificar, pero existe una noción adicional de sobreajuste en las métricas de entrenamiento específicas del algoritmo de promediado federado. Esto es más fácil de ver si imaginamos que cada cliente tiene un solo lote de datos y entrenamos en ese lote para muchas iteraciones (épocas). En este caso, el modelo local rápidamente se ajustará exactamente a ese lote, por lo que la métrica de precisión local que promediamos se acercará a 1.0. Por lo tanto, estas métricas de entrenamiento pueden tomarse como una señal de que el entrenamiento está progresando, pero no mucho más.

Para realizar la evaluación de datos federados, puede construir otro cálculo federado diseñado precisamente para este propósito, utilizando la función tff.learning.build_federated_evaluation y pasando el constructor de su modelo como argumento. Tenga en cuenta que, a diferencia de Federated Averaging, donde hemos usado MnistTrainableModel , es suficiente para pasar el MnistModel . La evaluación no realiza un descenso de gradiente y no es necesario construir optimizadores.

Para la experimentación y la investigación, cuando se dispone de un conjunto de datos de prueba centralizado, Federated Learning for Text Generation demuestra otra opción de evaluación: tomar los pesos entrenados del aprendizaje federado, aplicarlos a un modelo estándar de Keras y luego simplemente llamar a tf.keras.models.Model.evaluate() en un conjunto de datos centralizado.

evaluation = tff.learning.build_federated_evaluation(MnistModel)

Puede inspeccionar la firma de tipo abstracto de la función de evaluación de la siguiente manera.

str(evaluation.type_signature)
'(<<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,{<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <num_examples=float32@SERVER,loss=float32@SERVER,accuracy=float32@SERVER>)'

No hay necesidad de preocuparse por los detalles en este punto, solo tenga en cuenta que toma la siguiente forma general, similar a tff.templates.IterativeProcess.next pero con dos diferencias importantes. Primero, no estamos devolviendo el estado del servidor, ya que la evaluación no modifica el modelo ni ningún otro aspecto del estado; puede considerarlo sin estado. En segundo lugar, la evaluación solo necesita el modelo y no requiere ninguna otra parte del estado del servidor que pueda estar asociada con el entrenamiento, como las variables del optimizador.

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

Invoquemos la evaluación del último estado al que llegamos durante el entrenamiento. Para extraer el último modelo entrenado del estado del servidor, simplemente acceda al miembro .model , de la siguiente manera.

train_metrics = evaluation(state.model, federated_train_data)

Esto es lo que obtenemos. Tenga en cuenta que los números se ven ligeramente mejores que lo que se informó en la última ronda de capacitación anterior. Por convención, las métricas de capacitación informadas por el proceso de capacitación iterativo generalmente reflejan el rendimiento del modelo al comienzo de la ronda de capacitación, por lo que las métricas de evaluación siempre estarán un paso por delante.

str(train_metrics)
'<num_examples=4860.0,loss=1.7142657041549683,accuracy=0.38683128356933594>'

Ahora, compilemos una muestra de prueba de datos federados y volvamos a ejecutar la evaluación en los datos de prueba. Los datos provendrán de la misma muestra de usuarios reales, pero de un conjunto de datos diferente.

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
'<num_examples=580.0,loss=1.861915111541748,accuracy=0.3362068831920624>'

Con esto concluye el tutorial. Le recomendamos que juegue con los parámetros (por ejemplo, tamaños de lote, número de usuarios, épocas, tasas de aprendizaje, etc.), que modifique el código anterior para simular el entrenamiento en muestras aleatorias de usuarios en cada ronda y que explore los otros tutoriales. hemos desarrollado.