¡El Día de la Comunidad de ML es el 9 de noviembre! Únase a nosotros para recibir actualizaciones de TensorFlow, JAX, y más Más información

Trabajando con ClientData de tff.

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

La noción de un conjunto de datos codificado por clientes (por ejemplo, usuarios) es esencial para la computación federada como se modela en TFF. TFF proporciona la interfaz tff.simulation.datasets.ClientData abstraer sobre este concepto, y los conjuntos de datos donde se celebran TFF ( stackoverflow , Shakespeare , emnist , cifar100 y gldv2 ) toda implementar esta interfaz.

Si está trabajando en el aprendizaje federada con su propio conjunto de datos, TFF recomienda encarecidamente que implementar ya sea el ClientData interfaz o utilizar una de las funciones de ayudante de TFF para generar un ClientData que representa los datos en el disco, por ejemplo tff.simulation.datasets.ClientData.from_clients_and_fn .

Como la mayoría de los ejemplos de extremo a extremo de TFF empezar con ClientData objetos, la aplicación de la ClientData interfaz con el conjunto de datos a medida que hará más fácil a través de Spelunk código existente escrito con TFF. Además, los tf.data.Datasets que ClientData construcciones se pueden repiten a lo largo directamente para producir estructuras de numpy arrays, por lo ClientData objetos se pueden utilizar con cualquier marco ML basado en Python antes de pasar a TFF.

Hay varios patrones con los que puede hacer su vida más fácil si tiene la intención de ampliar sus simulaciones a muchas máquinas o implementarlas. A continuación vamos a caminar a través de algunas de las formas en que podemos utilizar ClientData y TFF para hacer nuestra pequeña escala iteración a la experimentación a gran escala a la producción de experiencia de implementación lo más suave posible.

¿Qué patrón debo usar para pasar ClientData a TFF?

Vamos a discutir dos usos de la TFF ClientData de profundidad; si encaja en cualquiera de las dos categorías siguientes, claramente preferirá una sobre la otra. De lo contrario, es posible que necesite una comprensión más detallada de los pros y los contras de cada uno para tomar una decisión más matizada.

  • Quiero iterar lo más rápido posible en una máquina local; No necesito poder aprovechar fácilmente el tiempo de ejecución distribuido de TFF.

    • Se debe pasar tf.data.Datasets a TFF directamente.
    • Esto le permite programar imperativamente con tf.data.Dataset objetos, y procesarlos de manera arbitraria.
    • Proporciona más flexibilidad que la siguiente opción; Enviar la lógica a los clientes requiere que esta lógica sea serializable.
  • Quiero ejecutar mi computación federada en el tiempo de ejecución remoto de TFF, o planeo hacerlo pronto.

    • En este caso, desea mapear la construcción y el preprocesamiento de conjuntos de datos a los clientes.
    • Esto resulta en que pasa simplemente una lista de client_ids directamente a su cómputo federado.
    • Impulsar la construcción y el preprocesamiento de conjuntos de datos a los clientes evita los cuellos de botella en la serialización y aumenta significativamente el rendimiento con cientos o miles de clientes.

Configurar un entorno de código abierto

Importar paquetes

Manipular un objeto ClientData

Vamos a empezar por la carga y la exploración de TFF EMNIST ClientData :

client_data, _ = tff.simulation.datasets.emnist.load_data()
Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:19<00:00, 8831921.67it/s]
2021-10-01 11:17:58.718735: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Inspeccionar el primer conjunto de datos nos puede decir qué tipo de ejemplos se encuentran en el ClientData .

first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

Nota que los rendimientos de conjunto de datos collections.OrderedDict objetos que tienen pixels y label llaves, donde píxeles es un tensor con forma de [28, 28] . Supongamos que deseamos para aplanar nuestras entradas a la forma [784] . Una posible manera de hacer esto sería aplicar una función de pre-procesamiento en nuestro ClientData objeto.

def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)


preprocessed_client_data = client_data.preprocess(preprocess_dataset)

# Notice that we have both reshaped and renamed the elements of the ordered dict.
first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

Es posible que queramos, además, realizar un preprocesamiento más complejo (y posiblemente con estado), por ejemplo, barajar.

def preprocess_and_shuffle(dataset):
  """Applies `preprocess_dataset` above and shuffles the result."""
  preprocessed = preprocess_dataset(dataset)
  return preprocessed.shuffle(buffer_size=5)

preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)

# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
    first_client_id)
print(first_client_dataset.element_spec)
OrderedDict([('x', TensorSpec(shape=(None, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(None, 1), dtype=tf.int64, name=None))])

Interfaz con un tff.Computation

Ahora que podemos realizar algunas operaciones básicas con ClientData objetos, estamos listos para los datos de alimentación a un tff.Computation . Definimos una tff.templates.IterativeProcess que implementa Federados de promedio , y explorar diferentes métodos de pasándole datos.

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
  ])
  return tff.learning.from_keras_model(
      model,
      # Note: input spec is the _batched_ shape, and includes the 
      # label tensor which will be passed to the loss function. This model is
      # therefore configured to accept data _after_ it has been preprocessed.
      input_spec=collections.OrderedDict(
          x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
          y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

trainer = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))

Antes de comenzar a trabajar con esta IterativeProcess , un comentario sobre la semántica de ClientData está en orden. Un ClientData objeto representa la totalidad de la población disponible para la formación federados, que en general es no disponible para el entorno de ejecución de un sistema de producción FL y es específica para la simulación. ClientData de hecho le da al usuario la capacidad de derivación de la computación federados por completo y simplemente entrenar un modelo del lado del servidor como de costumbre a través de ClientData.create_tf_dataset_from_all_clients .

El entorno de simulación de TFF pone al investigador en completo control del bucle exterior. En particular, esto implica consideraciones de disponibilidad del cliente, abandono del cliente, etc., que deben ser abordadas por el usuario o el script del controlador Python. Uno podría, por ejemplo modelo de deserción cliente mediante el ajuste de la distribución de muestreo sobre sus ClientData's client_ids tales que los usuarios con más datos (y correspondientemente más largo a ejecutar cálculos locales) sería seleccionado con menor probabilidad.

Sin embargo, en un sistema federado real, el entrenador modelo no puede seleccionar explícitamente a los clientes; la selección de clientes se delega al sistema que está ejecutando el cómputo federado.

Pasando tf.data.Datasets directamente a TFF

Una de las opciones que tenemos para la interconexión entre un ClientData y un IterativeProcess es el de construir tf.data.Datasets en Python, y pasando estos conjuntos de datos a TFF.

Tenga en cuenta que si usamos nuestros preprocesados ClientData los conjuntos de datos que dió son del tipo apropiado esperado por nuestro modelo definido anteriormente.

selected_client_ids = preprocessed_and_shuffled.client_ids[:10]

preprocessed_data_for_clients = [
    preprocessed_and_shuffled.create_tf_dataset_for_client(
        selected_client_ids[i]) for i in range(10)
]

state = trainer.initialize()
for _ in range(5):
  t1 = time.time()
  state, metrics = trainer.next(state, preprocessed_data_for_clients)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:62: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
loss 2.9005744457244873, round time 4.576513767242432
loss 3.113278388977051, round time 0.49641919136047363
loss 2.7581865787506104, round time 0.4904160499572754
loss 2.87259578704834, round time 0.48976993560791016
loss 3.1202380657196045, round time 0.6724586486816406

Si tomamos esta ruta, sin embargo, no seremos capaces de trivialmente para mover a la simulación multimáquina. Los conjuntos de datos que construimos en el tiempo de ejecución TensorFlow local puede capturar el estado del ambiente que rodea a pitón, y fallar en la serialización o deserialización cuando intentan estado de referencia, que ya no está disponible para ellos es. Esto puede manifestarse por ejemplo en el error de inescrutable de TensorFlow tensor_util.cc :

Check failed: DT_VARIANT == input.dtype() (21 vs. 20)

Mapeo de la construcción y preprocesamiento sobre los clientes.

Para evitar este problema, TFF recomienda a sus usuarios a tener en cuenta el conjunto de datos de instancias y el procesamiento previo como algo que ocurre localmente en cada cliente, ya utilizar los ayudantes de TFF o federated_map para funcionar de forma explícita este código procesamiento previo a cada cliente.

Conceptualmente, la razón para preferir esto es clara: en el tiempo de ejecución local de TFF, los clientes solo "accidentalmente" tienen acceso al entorno global de Python debido al hecho de que toda la orquestación federada ocurre en una sola máquina. Vale la pena señalar en este punto que un pensamiento similar da lugar a la filosofía funcional multiplataforma, siempre serializable de TFF.

TFF hace que un cambio tan sencilla a través ClientData's atributo dataset_computation , un tff.Computation que toma un client_id y devuelve el asociado tf.data.Dataset .

Tenga en cuenta que preprocess simplemente trabaja con dataset_computation ; la dataset_computation atributo del preprocesado ClientData incorpora toda la tubería de procesamiento previo que acabamos de definir:

print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
dataset computation without preprocessing:
(string -> <label=int32,pixels=float32[28,28]>*)


dataset computation with preprocessing:
(string -> <x=float32[?,784],y=int64[?,1]>*)

Podríamos invocar dataset_computation y recibir un conjunto de datos con ganas en el tiempo de ejecución de Python, pero el poder real de este enfoque se ejerce cuando componemos con un proceso iterativo de cálculo u otra para evitar la materialización de estos conjuntos de datos en el tiempo de ejecución ansiosos mundial en absoluto. TFF proporciona una función de ayuda tff.simulation.compose_dataset_computation_with_iterative_process que puede ser utilizado para hacer exactamente esto.

trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
    preprocessed_and_shuffled.dataset_computation, trainer)

Tanto este tff.templates.IterativeProcesses y la de arriba funcionar de la misma manera; pero el ex acepta preprocesados conjuntos de datos de cliente, y el último acepta cadenas que representan los ID de cliente, la manipulación tanto en la construcción y el conjunto de datos de preprocesamiento en su cuerpo - de hecho state se puede transmitir entre los dos.

for _ in range(5):
  t1 = time.time()
  state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
  t2 = time.time()
  print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
loss 2.8417396545410156, round time 1.6707067489624023
loss 2.7670371532440186, round time 0.5207102298736572
loss 2.665048122406006, round time 0.5302855968475342
loss 2.7213189601898193, round time 0.5313887596130371
loss 2.580148935317993, round time 0.5283482074737549

Escalando a un gran número de clientes

trainer_accepting_ids de inmediato se pueden utilizar en tiempo de ejecución multimáquina de TFF, y evita materializar tf.data.Datasets y el controlador (y por lo tanto la serialización de ellos y enviarlos a los trabajadores).

Esto acelera significativamente las simulaciones distribuidas, especialmente con una gran cantidad de clientes, y permite la agregación intermedia para evitar una sobrecarga similar de serialización / deserialización.

Deepdive opcional: componer manualmente la lógica de preprocesamiento en TFF

TFF está diseñado para la composicionalidad desde cero; el tipo de composición que acaba de realizar el ayudante de TFF está totalmente bajo nuestro control como usuarios. Podríamos tener manualmente componer el cálculo de procesamiento previo que acabamos de definir con la del entrenador propia next sencillamente:

selected_clients_type = tff.FederatedType(preprocessed_and_shuffled.dataset_computation.type_signature.parameter, tff.CLIENTS)

@tff.federated_computation(trainer.next.type_signature.parameter[0], selected_clients_type)
def new_next(server_state, selected_clients):
  preprocessed_data = tff.federated_map(preprocessed_and_shuffled.dataset_computation, selected_clients)
  return trainer.next(server_state, preprocessed_data)

manual_trainer_with_preprocessing = tff.templates.IterativeProcess(initialize_fn=trainer.initialize, next_fn=new_next)

De hecho, esto es efectivamente lo que el ayudante que usamos está haciendo bajo el capó (además de realizar una verificación y manipulación de tipo adecuada). Incluso podríamos haber expresado la misma lógica ligeramente diferente, serializando preprocess_and_shuffle en un tff.Computation , y descomponer el federated_map en un paso que construye conjuntos de datos de la ONU-preprocesado y otra que corre preprocess_and_shuffle a cada cliente.

Podemos verificar que esta ruta más manual da como resultado cálculos con la misma firma de tipo que el ayudante de TFF (nombres de parámetros de módulo):

print(trainer_accepting_ids.next.type_signature)
print(manual_trainer_with_preprocessing.next.type_signature)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,federated_dataset={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)
(<server_state=<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,selected_clients={string}@CLIENTS> -> <<model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER,<broadcast=<>,aggregation=<mean_value=<>,mean_weight=<>>,train=<sparse_categorical_accuracy=float32,loss=float32>,stat=<num_examples=int64>>@SERVER>)