Aprendizaje federado para clasificación de imágenes

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

En este tutorial, se utiliza el ejemplo de entrenamiento MNIST clásico para introducir el aprendizaje Federados (FL) capa API de TFF, tff.learning - un conjunto de interfaces de alto nivel que se puede utilizar para realizar tipos comunes de las tareas de aprendizaje federados, tales como entrenamiento federado, contra modelos proporcionados por el usuario implementados en TensorFlow.

Este tutorial, y la API de aprendizaje federado, están destinados principalmente a usuarios que desean conectar sus propios modelos de TensorFlow en TFF, tratando este último principalmente como una caja negra. Para una comprensión más profunda de la TFF y cómo implementar sus propios algoritmos de aprendizaje federados, consulte los tutoriales en la API FC Core - Custom Federados Algoritmos Parte 1 y Parte 2 .

Para más información sobre tff.learning , continuar con el aprendizaje Federados de texto Generación , tutorial, que además de cubrir los modelos recurrentes, también demuestra la carga de un modelo de Keras serializado pre-formados para el refinamiento con el aprendizaje federados en combinación con la evaluación mediante Keras.

Antes que empecemos

Antes de comenzar, ejecute lo siguiente para asegurarse de que su entorno esté configurado correctamente. Si no ve un saludo, por favor refiérase a la instalación de guía para obtener instrucciones.

# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
!pip install --quiet --upgrade tb-nightly  # or tensorboard, but not both

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

Preparando los datos de entrada

Comencemos con los datos. El aprendizaje federado requiere un conjunto de datos federados, es decir, una colección de datos de múltiples usuarios. Datos federada es típicamente no iid , lo que plantea una serie de desafíos.

A fin de facilitar la experimentación, sembramos el repositorio de TFF con unos pocos conjuntos de datos, incluyendo una versión Federados de MNIST que contiene una versión de la base de datos NIST original, que ha sido re-procesado utilizando la hoja de modo que los datos se teclea por el escritor original de los dígitos. Dado que cada escritor tiene un estilo único, este conjunto de datos presenta el tipo de comportamiento no iid que se espera de los conjuntos de datos federados.

Así es como podemos cargarlo.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Los conjuntos de datos devueltos por load_data() son instancias de tff.simulation.ClientData , una interfaz que le permite enumerar el conjunto de usuarios, para construir un tf.data.Dataset que representa los datos de un usuario particular, y para consultar el estructura de elementos individuales. A continuación, le mostramos cómo puede utilizar esta interfaz para explorar el contenido del conjunto de datos. Tenga en cuenta que, si bien esta interfaz le permite iterar sobre los identificadores de clientes, esta es solo una característica de los datos de simulación. Como verá en breve, el marco de aprendizaje federado no utiliza las identidades de los clientes; su único propósito es permitirle seleccionar subconjuntos de datos para simulaciones.

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

Explorando la heterogeneidad en datos federados

Datos federada es típicamente no IID , los usuarios suelen tener diferentes distribuciones de los datos en función de los patrones de uso. Algunos clientes pueden tener menos ejemplos de capacitación en el dispositivo, debido a la escasez de datos a nivel local, mientras que algunos clientes tendrán ejemplos de capacitación más que suficientes. Exploremos este concepto de heterogeneidad de datos típico de un sistema federado con los datos EMNIST que tenemos disponibles. Es importante tener en cuenta que este análisis profundo de los datos de un cliente solo está disponible para nosotros porque se trata de un entorno de simulación en el que todos los datos están disponibles localmente. En un entorno federado de producción real, no podría inspeccionar los datos de un solo cliente.

Primero, tomemos una muestra de los datos de un cliente para tener una idea de los ejemplos en un dispositivo simulado. Debido a que el conjunto de datos que estamos usando ha sido codificado por un escritor único, los datos de un cliente representan la escritura a mano de una persona para una muestra de los dígitos del 0 al 9, simulando el "patrón de uso" único de un usuario.

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

Ahora visualicemos el número de ejemplos en cada cliente para cada etiqueta de dígito MNIST. En el entorno federado, la cantidad de ejemplos en cada cliente puede variar bastante, dependiendo del comportamiento del usuario.

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

Ahora visualicemos la imagen media por cliente para cada etiqueta MNIST. Este código producirá la media de cada valor de píxel para todos los ejemplos del usuario para una etiqueta. Veremos que la imagen media de un cliente para un dígito se verá diferente a la imagen media de otro cliente para el mismo dígito, debido al estilo de escritura único de cada persona. Podemos reflexionar sobre cómo cada ronda de capacitación local empujará el modelo en una dirección diferente en cada cliente, ya que estamos aprendiendo de los datos únicos de ese usuario en esa ronda local. Más adelante en el tutorial veremos cómo podemos tomar cada actualización del modelo de todos los clientes y agregarlas en nuestro nuevo modelo global, que ha aprendido de los datos únicos de cada uno de nuestros clientes.

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

Los datos del usuario pueden ser ruidosos y estar etiquetados de forma poco fiable. Por ejemplo, al observar los datos del Cliente n. ° 2 anteriores, podemos ver que para la etiqueta 2, es posible que haya algunos ejemplos mal etiquetados que hayan creado una imagen más ruidosa.

Preprocesar los datos de entrada

Dado que los datos ya es un tf.data.Dataset , preprocesamiento puede realizarse utilizando transformaciones conjunto de datos. En este sentido, aplanar el 28x28 imágenes en 784 matrices -elemento, baraja los ejemplos individuales, organizarlos en lotes, y renombrar las características de pixels y label de x e y para su uso con Keras. También tiramos en una repeat del juego de datos para ejecutar varias épocas.

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

Verifiquemos que esto funcionó.

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [5],
       [7],
       [1],
       [7],
       [7],
       [1],
       [4],
       [7],
       [4],
       [2],
       [2],
       [5],
       [4],
       [1],
       [1],
       [0],
       [0],
       [9]], dtype=int32))])

Tenemos casi todos los componentes básicos para construir conjuntos de datos federados.

Una de las formas de alimentar datos federados a TFF en una simulación es simplemente como una lista de Python, con cada elemento de la lista que contiene los datos de un usuario individual, ya sea como una lista o como una tf.data.Dataset . Como ya tenemos una interfaz que proporciona este último, usémosla.

Aquí hay una función auxiliar simple que construirá una lista de conjuntos de datos a partir del conjunto dado de usuarios como entrada para una ronda de capacitación o evaluación.

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

Ahora bien, ¿cómo elegimos a los clientes?

En un escenario típico de capacitación federado, estamos tratando con una población potencialmente muy grande de dispositivos de usuario, de los cuales solo una fracción puede estar disponible para capacitación en un momento dado. Este es el caso, por ejemplo, cuando los dispositivos del cliente son teléfonos móviles que participan en la capacitación solo cuando están enchufados a una fuente de alimentación, fuera de una red con medidor y, por lo demás, inactivos.

Por supuesto, estamos en un entorno de simulación y todos los datos están disponibles localmente. Por lo general, entonces, al ejecutar simulaciones, simplemente tomaríamos una muestra de un subconjunto aleatorio de los clientes que participarán en cada ronda de capacitación, generalmente diferentes en cada ronda.

Dicho esto, como se puede averiguar mediante el estudio del papel sobre el valor promedio Federados algoritmo, lograr la convergencia en un sistema con subconjuntos seleccionados al azar de los clientes en cada ronda puede tomar un tiempo, y no sería práctico tener que ejecutar cientos de rondas de este tutorial interactivo.

En cambio, lo que haremos es tomar una muestra del conjunto de clientes una vez y reutilizar el mismo conjunto en las rondas para acelerar la convergencia (sobreajuste intencionalmente a los datos de estos pocos usuarios). Lo dejamos como ejercicio para que el lector modifique este tutorial para simular un muestreo aleatorio; es bastante fácil de hacer (una vez que lo haga, tenga en cuenta que lograr que el modelo converja puede llevar un tiempo).

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

Creando un modelo con Keras

Si está utilizando Keras, es probable que ya tenga un código que construya un modelo de Keras. Aquí hay un ejemplo de un modelo simple que será suficiente para nuestras necesidades.

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

Para utilizar cualquier modelo con TFF, que necesita ser envuelto en una instancia de la tff.learning.Model interfaz, que expone métodos para estampar pase hacia adelante del modelo, las propiedades de metadatos, etc., de manera similar a Keras, pero también introduce adicional elementos, como las formas de controlar el proceso de cálculo de métricas federadas. No nos preocupemos por esto por ahora; si usted tiene un modelo Keras como el que acaba de definir anteriormente, puede tener TFF envolverlo para usted invocando tff.learning.from_keras_model , pasando por el modelo y un lote de datos de ejemplo como argumentos, como se muestra a continuación.

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Entrenamiento del modelo en datos federados

Ahora que tenemos un modelo envuelto como tff.learning.Model para su uso con TFF, podemos dejar TFF construir un algoritmo de promedio federados mediante la invocación de la función auxiliar tff.learning.build_federated_averaging_process , como sigue.

Tenga en cuenta que el argumento tiene que ser un constructor (como model_fn arriba), no una instancia ya construido, por lo que la construcción de su modelo puede suceder en un contexto controlado por TFF (si tienes curiosidad sobre las razones de esto, os animo a leer el tutorial de seguimiento de algoritmos personalizados ).

Una nota crítica sobre el algoritmo de promedio Federados de abajo, hay 2 optimizadores: un optimizador _client y un optimizador _SERVER. El optimizador _client sólo se utiliza para calcular las actualizaciones modelo local en cada cliente. El optimizador _SERVER aplica la actualización promediado al modelo global en el servidor. En particular, esto significa que la elección del optimizador y la tasa de aprendizaje utilizados pueden tener que ser diferentes a las que ha utilizado para entrenar el modelo en un conjunto de datos iid estándar. Recomendamos comenzar con SGD regular, posiblemente con una tasa de aprendizaje menor de lo habitual. La tasa de aprendizaje que usamos no se ha ajustado cuidadosamente, no dude en experimentar.

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

¿Lo que acaba de suceder? TFF ha construido un par de cálculos federados y los empaquetados en un tff.templates.IterativeProcess en el que estos cálculos están disponibles como un par de propiedades initialize y next .

En pocas palabras, los cálculos federados son programas en lenguaje interno de TFF que pueden expresar diferentes algoritmos federados (se puede encontrar más información sobre esto en el algoritmos personalizados tutorial). En este caso, los dos cálculos generan y se empaquetan en iterative_process implemento Federados de promedio .

El objetivo de TFF es definir los cálculos de manera que puedan ejecutarse en entornos reales de aprendizaje federado, pero actualmente solo se implementa el tiempo de ejecución de simulación de ejecución local. Para ejecutar un cálculo en un simulador, simplemente lo invoca como una función de Python. Este entorno interpretado por defecto no está diseñado para un alto rendimiento, pero será suficiente para este tutorial; Esperamos proporcionar tiempos de ejecución de simulación de mayor rendimiento para facilitar la investigación a mayor escala en versiones futuras.

Vamos a empezar con la initialize de cálculo. Como es el caso de todos los cálculos federados, puede considerarlo como una función. El cálculo no toma argumentos y devuelve un resultado: la representación del estado del proceso de Promedio federado en el servidor. Si bien no queremos profundizar en los detalles de TFF, puede ser instructivo ver cómo se ve este estado. Puedes visualizarlo de la siguiente manera.

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER)'

Mientras que el tipo de firma anterior puede parecer a primera vista un críptico bits, puede reconocer que el estado del servidor consiste en un model (los parámetros iniciales del modelo de MNIST que serán distribuidos a todos los dispositivos), y optimizer_state (más información mantenida por el servidor, como el número de rondas que se utilizarán para las programaciones de hiperparámetros, etc.).

Invoquemos la initialize de cálculo para construir el estado del servidor.

state = iterative_process.initialize()

El segundo del par de cálculos federados, next , representa una única ronda de Federated de promedio, que consiste en empujar el estado del servidor (incluyendo los parámetros del modelo) a los clientes, en el dispositivo de formación sobre sus datos locales, la recolección y actualizaciones del modelo de promediado y producir un nuevo modelo actualizado en el servidor.

Conceptualmente, se puede pensar en next como tener una firma de tipo funcional que tiene la siguiente apariencia.

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

En particular, hay que pensar en next() no como una función que se ejecuta en un servidor, sino más bien ser una representación funcional declarativa de toda la computación descentralizada - algunos de los insumos son proporcionados por el servidor ( SERVER_STATE ), pero cada participante dispositivo aporta su propio conjunto de datos local.

Realicemos una sola ronda de entrenamiento y visualicemos los resultados. Podemos utilizar los datos federados que ya hemos generado anteriormente para una muestra de usuarios.

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12345679), ('loss', 3.1193738)])), ('stat', OrderedDict([('num_examples', 4860)]))])

Hagamos algunas rondas más. Como se mencionó anteriormente, normalmente en este punto usted escogería un subconjunto de sus datos de simulación de una nueva muestra de usuarios seleccionada al azar para cada ronda con el fin de simular una implementación realista en la que los usuarios van y vienen continuamente, pero en este cuaderno interactivo, por En aras de la demostración, simplemente reutilizaremos a los mismos usuarios, de modo que el sistema converja rápidamente.

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13518518), ('loss', 2.9834728)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14382716), ('loss', 2.861665)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.17407407), ('loss', 2.7957022)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.19917695), ('loss', 2.6146567)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.21975309), ('loss', 2.529761)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2409465), ('loss', 2.4053504)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2611111), ('loss', 2.315389)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.30823046), ('loss', 2.1240263)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.33312756), ('loss', 2.1164262)])), ('stat', OrderedDict([('num_examples', 4860)]))])

La pérdida de entrenamiento está disminuyendo después de cada ronda de entrenamiento federado, lo que indica que el modelo está convergiendo. Hay algunas advertencias importantes con estas métricas de formación, sin embargo, ver la sección de evaluación más adelante en este tutorial.

Visualización de métricas del modelo en TensorBoard

A continuación, visualicemos las métricas de estos cálculos federados usando Tensorboard.

Comencemos por crear el directorio y el escritor de resumen correspondiente para escribir las métricas.

logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

Trace las métricas escalares relevantes con el mismo escritor de resumen.

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics['train'].items():
      tf.summary.scalar(name, value, step=round_num)

Inicie TensorBoard con el directorio de registro raíz especificado anteriormente. Los datos pueden tardar unos segundos en cargarse.

!ls {logdir}
%tensorboard --logdir {logdir} --port=0
events.out.tfevents.1629557449.ebe6e776479e64ea-4903924a278.borgtask.google.com.458912.1.v2
Launching TensorBoard...
Reusing TensorBoard on port 50681 (pid 292785), started 0:30:30 ago. (Use '!kill 292785' to kill it.)
<IPython.core.display.Javascript at 0x7fd6617e02d0>
# Uncomment and run this this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if 
# you do a "Runtime > Run all" you don't lose your results.

# !rm -R /tmp/logs/scalars/*

Para ver las métricas de evaluación de la misma manera, puede crear una carpeta de evaluación separada, como "logs / scalars / eval", para escribir en TensorBoard.

Personalización de la implementación del modelo

Keras es la API de modelo de alto nivel recomendado para TensorFlow , y que fomentan el uso de modelos Keras (a través de tff.learning.from_keras_model ) en TFF siempre que sea posible.

Sin embargo, tff.learning proporciona una interfaz de modelo de nivel inferior, tff.learning.Model , que expone la funcionalidad mínima necesaria para el uso de un modelo de aprendizaje federado. Directamente la implementación de esta interfaz (posiblemente sigue utilizando bloques de construcción como tf.keras.layers ) permite la máxima personalización sin modificar el funcionamiento interno de los algoritmos de aprendizaje federados.

Así que hagámoslo de nuevo desde cero.

Definición de variables de modelo, pase directo y métricas

El primer paso es identificar las variables de TensorFlow con las que vamos a trabajar. Para que el siguiente código sea más legible, definamos una estructura de datos para representar el conjunto completo. Esto incluye variables como weights y bias que vamos a entrenar, así como las variables que contendrán diversas estadísticas acumulativas y contadores vamos a actualizar durante el entrenamiento, como loss_sum , accuracy_sum y num_examples .

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

Aquí hay un método que crea las variables. En aras de la simplicidad, representamos a todas las estadísticas como tf.float32 , ya que esto eliminaría la necesidad de conversiones de tipos en una etapa posterior. Envolviendo inicializadores variables como lambdas es un requisito impuesto por variables de recurso .

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

Con las variables para los parámetros del modelo y las estadísticas acumuladas en su lugar, ahora podemos definir el método de paso hacia adelante que calcula la pérdida, emite predicciones y actualiza las estadísticas acumuladas para un solo lote de datos de entrada, de la siguiente manera.

def predict_on_batch(variables, x):
  return tf.nn.softmax(tf.matmul(x, variables.weights) + variables.bias)

def mnist_forward_pass(variables, batch):
  y = predict_on_batch(variables, batch['x'])
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

A continuación, definimos una función que devuelve un conjunto de métricas locales, nuevamente usando TensorFlow. Estos son los valores (además de las actualizaciones del modelo, que se manejan automáticamente) que son elegibles para agregarse al servidor en un proceso de evaluación o aprendizaje federado.

Aquí, nos devuelva el promedio de loss y accuracy , así como los num_examples , que vamos a necesitar para ponderar correctamente las contribuciones de los diferentes usuarios cuando se calculan los agregados federados.

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

Por último, es necesario determinar cómo agregar los indicadores locales emitidos por cada dispositivo a través de get_local_mnist_metrics . Esta es la única parte del código que no está escrita en TensorFlow - es un cálculo federada expresada en TFF. Si desea conocer en profundidad, deslizarse sobre la algoritmos personalizados tutorial, pero en la mayoría de las aplicaciones, realmente no se necesita; las variantes del patrón que se muestra a continuación deberían ser suficientes. Así es como se ve:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))

La entrada metrics argumento corresponde a la OrderedDict devuelto por get_local_mnist_metrics anteriores, pero críticamente los valores ya no son tf.Tensors - que están "en caja", como tff.Value s, para que quede claro que ya no puede manipularlos utilizando TensorFlow, pero sólo utilizando operadores federados de TFF como tff.federated_mean y tff.federated_sum . El diccionario devuelto de agregados globales define el conjunto de métricas que estarán disponibles en el servidor.

La construcción de una instancia de tff.learning.Model

Con todo lo anterior en su lugar, estamos listos para construir una representación de modelo para usar con TFF similar a la que se genera para usted cuando deja que TFF ingiera un modelo de Keras.

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def predict_on_batch(self, x, training=True):
    del training
    return predict_on_batch(self._variables, x)

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

Como se puede ver, los métodos abstractos y propiedades definidas por tff.learning.Model corresponde a los fragmentos de código en la sección anterior que introdujo las variables y define la pérdida y la estadística.

Aquí hay algunos puntos que vale la pena destacar:

  • Todo estado que su modelo se utilice debe ser capturado como variables TensorFlow, como TFF no utiliza Python en tiempo de ejecución (recuerda que el código debe ser escrita de tal manera que se puede implementar en dispositivos móviles; ver el encargo algoritmos tutorial para una mayor profundidad comentario sobre las razones).
  • Su modelo debe describir qué tipo de datos se acepta ( input_spec ), ya que en general, TFF es un entorno fuertemente tipado y quiere determinar las firmas de tipos de todos los componentes. Declarar el formato de la entrada de su modelo es una parte esencial.
  • Aunque técnicamente no es obligatorio, se recomienda envolver toda lógica TensorFlow (pase hacia adelante, cómputos métricos, etc.) como tf.function s, ya que esto ayuda a garantizar la TensorFlow se puede serializar, y elimina la necesidad de que las dependencias de control explícitos.

Lo anterior es suficiente para la evaluación y los algoritmos como Federated SGD. Sin embargo, para el promedio federado, debemos especificar cómo se debe entrenar el modelo localmente en cada lote. Especificaremos un optimizador local al crear el algoritmo de promediado federado.

Simulando entrenamiento federado con el nuevo modelo

Con todo lo anterior en su lugar, el resto del proceso se parece a lo que ya hemos visto: simplemente reemplace el constructor del modelo con el constructor de nuestra nueva clase de modelo y use los dos cálculos federados en el proceso iterativo que creó para recorrer el ciclo. rondas de entrenamiento.

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.0708053), ('accuracy', 0.12777779)])), ('stat', OrderedDict([('num_examples', 4860)]))])
for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.011699), ('accuracy', 0.13024691)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.7408307), ('accuracy', 0.15576132)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.6761012), ('accuracy', 0.17921811)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.675567), ('accuracy', 0.1855967)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.5664043), ('accuracy', 0.20329218)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4179392), ('accuracy', 0.24382716)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.3237286), ('accuracy', 0.26687244)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1861682), ('accuracy', 0.28209877)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.046388), ('accuracy', 0.32037038)])), ('stat', OrderedDict([('num_examples', 4860)]))])

Para ver estas métricas dentro de TensorBoard, consulte los pasos enumerados anteriormente en "Visualización de las métricas del modelo en TensorBoard".

Evaluación

Todos nuestros experimentos hasta ahora presentaron solo métricas de entrenamiento federadas: las métricas promedio de todos los lotes de datos entrenados en todos los clientes de la ronda. Esto introduce las preocupaciones normales sobre el sobreajuste, especialmente porque usamos el mismo conjunto de clientes en cada ronda para simplificar, pero existe una noción adicional de sobreajuste en las métricas de entrenamiento específicas del algoritmo de promediado federado. Esto es más fácil de ver si imaginamos que cada cliente tiene un solo lote de datos, y entrenamos en ese lote para muchas iteraciones (épocas). En este caso, el modelo local rápidamente se ajustará exactamente a ese lote, por lo que la métrica de precisión local que promediamos se acercará a 1.0. Por lo tanto, estas métricas de entrenamiento pueden tomarse como una señal de que el entrenamiento está progresando, pero no mucho más.

Para llevar a cabo la evaluación de los datos federados, se puede construir otro cálculo federada diseñada para este propósito, utilizando el tff.learning.build_federated_evaluation función, y que pasa en su constructor modelo como un argumento. Tenga en cuenta que a diferencia de los Federados de promedio, en donde hemos utilizado MnistTrainableModel , basta con pasar el MnistModel . La evaluación no realiza un descenso de gradiente y no es necesario construir optimizadores.

Para la experimentación e investigación, cuando un conjunto de datos de prueba centralizada está disponible, Federated aprendizaje de texto Generación demuestra otra opción de evaluación: tomando los pesos capacitados de aprendizaje federados, aplicándolos a un modelo estándar Keras, y luego simplemente llamando tf.keras.models.Model.evaluate() en un conjunto de datos centralizada.

evaluation = tff.learning.build_federated_evaluation(MnistModel)

Puede inspeccionar la firma de tipo abstracto de la función de evaluación de la siguiente manera.

str(evaluation.type_signature)
'(<server_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <eval=<num_examples=float32,loss=float32,accuracy=float32>,stat=<num_examples=int64>>@SERVER)'

No hay necesidad de preocuparse por los detalles en este momento, sólo ten en cuenta que tiene la siguiente forma general, similares a tff.templates.IterativeProcess.next pero con dos diferencias importantes. En primer lugar, no devolvemos el estado del servidor, ya que la evaluación no modifica el modelo ni ningún otro aspecto del estado; puede considerarlo sin estado. En segundo lugar, la evaluación solo necesita el modelo y no requiere ninguna otra parte del estado del servidor que pueda estar asociada con el entrenamiento, como las variables del optimizador.

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

Invoquemos la evaluación sobre el último estado al que llegamos durante el entrenamiento. Con el fin de extraer el último modelo entrenado desde el estado del servidor, sólo tiene que acceder a la .model miembro, de la siguiente manera.

train_metrics = evaluation(state.model, federated_train_data)

Esto es lo que obtenemos. Tenga en cuenta que los números se ven ligeramente mejores que los que se informó en la última ronda de capacitación anterior. Por convención, las métricas de capacitación informadas por el proceso de capacitación iterativo generalmente reflejan el rendimiento del modelo al comienzo de la ronda de capacitación, por lo que las métricas de evaluación siempre estarán un paso por delante.

str(train_metrics)
"OrderedDict([('eval', OrderedDict([('num_examples', 4860.0), ('loss', 1.7510437), ('accuracy', 0.2788066)])), ('stat', OrderedDict([('num_examples', 4860)]))])"

Ahora, compilemos una muestra de prueba de datos federados y volvamos a ejecutar la evaluación en los datos de prueba. Los datos provendrán de la misma muestra de usuarios reales, pero de un conjunto de datos distinto.

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
"OrderedDict([('eval', OrderedDict([('num_examples', 580.0), ('loss', 1.8361608), ('accuracy', 0.2413793)])), ('stat', OrderedDict([('num_examples', 580)]))])"

Con esto concluye el tutorial. Le recomendamos que juegue con los parámetros (p. Ej., Tamaños de lote, número de usuarios, épocas, tasas de aprendizaje, etc.), que modifique el código anterior para simular el entrenamiento en muestras aleatorias de usuarios en cada ronda y que explore los otros tutoriales. hemos desarrollado.