Özel yinelemeli süreçte TFF iyileştiricilerini kullanın

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyin Not defterini indir

Bu bir alternatiftir Build'a Kendi Federe Öğrenme Algoritması öğretici ve simple_fedavg için özel bir iteratif süreç oluşturmak için örnek federe ortalama algoritması. Bu eğitimde kullanacağı TFF optimize yerine Keras optimize. TFF iyileştirici soyutlaması, bir TFF yinelemeli sürece dahil edilmesi daha kolay olması için durum içinde durum dışında olacak şekilde tasarlanmıştır. tff.learning API'ler de girdi argüman olarak TFF optimize kabul ediyoruz.

Başlamadan önce

Başlamadan önce, ortamınızın doğru şekilde kurulduğundan emin olmak için lütfen aşağıdakileri çalıştırın. Eğer bir selamlama görmüyorsanız, bakınız Kurulum talimatları için rehber.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import functools
import attr
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

Veri ve model hazırlama

EMNIST veri işleme ve model çok benzer simple_fedavg örneğin.

only_digits=True

# Load dataset.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits)

# Define preprocessing functions.
def preprocess_fn(dataset, batch_size=16):

  def batch_format_fn(element):
    return (tf.expand_dims(element['pixels'], -1), element['label'])

  return dataset.batch(batch_size).map(batch_format_fn)

# Preprocess and sample clients for prototyping.
train_client_ids = sorted(emnist_train.client_ids)
train_data = emnist_train.preprocess(preprocess_fn)
central_test_data = preprocess_fn(
    emnist_train.create_tf_dataset_for_client(train_client_ids[0]))

# Define model.
def create_keras_model():
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'
  input_shape = [28, 28, 1]

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      conv2d(filters=32, input_shape=input_shape),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
  ])

  return model

# Wrap as `tff.learning.Model`.
def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=central_test_data.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

Özel yinelemeli süreç

Çoğu durumda, birleşik algoritmaların 4 ana bileşeni vardır:

  1. Sunucudan istemciye bir yayın adımı.
  2. Yerel bir istemci güncelleme adımı.
  3. İstemciden sunucuya yükleme adımı.
  4. Bir sunucu güncelleme adımı.

TFF, biz genellikle olarak federe algoritmaları temsil tff.templates.IterativeProcess (biz sadece bir dediğimiz IterativeProcess boyunca). Bu içeren bir sınıftır initialize ve next işlevleri. Burada, initialize sunucuyu başlatmak için kullanılır ve next federe algoritmanın bir haberleşme yuvarlak gerçekleştirecektir.

İstemci güncelleme adımında bir optimize edici ve sunucu güncelleme adımında başka bir optimize edici kullanacak olan birleşik ortalama (FedAvg) algoritmasını oluşturmak için farklı bileşenleri tanıtacağız. İstemci ve sunucu güncellemelerinin temel mantığı saf TF blokları olarak ifade edilebilir.

TF blokları: istemci ve sunucu güncellemesi

Her istemci üzerinde yerel client_optimizer başlatıldı ve istemci modeli ağırlıkları güncellemek için kullanılır. Sunucuda, server_optimizer önceki turdan durumunu kullanacak ve sonraki turda devlet güncelleyin.

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs local training on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)
  # Initialize the client optimizer.
  trainable_tensor_specs = tf.nest.map_structure(
          lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
  optimizer_state = client_optimizer.initialize(trainable_tensor_specs)
  # Use the client_optimizer to update the local model.
  for batch in iter(dataset):
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data.
      outputs = model.forward_pass(batch)
    # Compute the corresponding gradient.
    grads = tape.gradient(outputs.loss, client_weights)
    # Apply the gradient using a client optimizer.
    optimizer_state, updated_weights = client_optimizer.next(
        optimizer_state, client_weights, grads)
    tf.nest.map_structure(lambda a, b: a.assign(b), 
                          client_weights, updated_weights)
  # Return model deltas.
  return tf.nest.map_structure(tf.subtract, client_weights, server_weights)
@attr.s(eq=False, frozen=True, slots=True)
class ServerState(object):
  trainable_weights = attr.ib()
  optimizer_state = attr.ib()

@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
  """Updates the server model weights."""
  # Use aggregated negative model delta as pseudo gradient. 
  negative_weights_delta = tf.nest.map_structure(
      lambda w: -1.0 * w, mean_model_delta)
  new_optimizer_state, updated_weights = server_optimizer.next(
      server_state.optimizer_state, server_state.trainable_weights, 
      negative_weights_delta)
  return tff.structure.update_struct(
      server_state,
      trainable_weights=updated_weights,
      optimizer_state=new_optimizer_state)

TFF blokları: tff.tf_computation ve tff.federated_computation

Artık düzenleme için TFF kullanıyoruz ve FedAvg için yinelemeli süreci oluşturuyoruz. Biz yukarıda tanımlanan TF blokları kaydırmak zorunda tff.tf_computation ve kullanım TFF yöntemleri tff.federated_broadcast , tff.federated_map , tff.federated_mean bir de tff.federated_computation fonksiyonu. Kullanımı kolaydır tff.learning.optimizers.Optimizer ile API'leri initialize ve next özel iteratif süreç tanımlanırken görev.

# 1. Server and client optimizer to be used.
server_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.05, momentum=0.9)
client_optimizer = tff.learning.optimizers.build_sgdm(
    learning_rate=0.01)

# 2. Functions return initial state on server. 
@tff.tf_computation
def server_init():
  model = model_fn()
  trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)
  optimizer_state = server_optimizer.initialize(trainable_tensor_specs)
  return ServerState(
      trainable_weights=model.trainable_variables,
      optimizer_state=optimizer_state)

@tff.federated_computation
def server_init_tff():
  return tff.federated_value(server_init(), tff.SERVER)

# 3. One round of computation and communication.
server_state_type = server_init.type_signature.result
print('server_state_type:\n', 
      server_state_type.formatted_representation())
trainable_weights_type = server_state_type.trainable_weights
print('trainable_weights_type:\n', 
      trainable_weights_type.formatted_representation())

# 3-1. Wrap server and client TF blocks with `tff.tf_computation`.
@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
  return server_update(server_state, model_delta, server_optimizer)

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
print('tf_dataset_type:\n', 
      tf_dataset_type.formatted_representation())
@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
  model = model_fn()
  return client_update(model, dataset, server_weights, client_optimizer)

# 3-2. Orchestration with `tff.federated_computation`.
federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
  # Server-to-client broadcast.
  server_weights_at_client = tff.federated_broadcast(
      server_state.trainable_weights)
  # Local client update.
  model_deltas = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))
  # Client-to-server upload and aggregation.
  mean_model_delta = tff.federated_mean(model_deltas)
  # Server update.
  server_state = tff.federated_map(
      server_update_fn, (server_state, mean_model_delta))
  return server_state

# 4. Build the iterative process for FedAvg.
fedavg_process = tff.templates.IterativeProcess(
    initialize_fn=server_init_tff, next_fn=run_one_round)
print('type signature of `initialize`:\n', 
      fedavg_process.initialize.type_signature.formatted_representation())
print('type signature of `next`:\n', 
      fedavg_process.next.type_signature.formatted_representation())
server_state_type:
 <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>
trainable_weights_type:
 <
  float32[5,5,1,32],
  float32[32],
  float32[5,5,32,64],
  float32[64],
  float32[3136,512],
  float32[512],
  float32[512,10],
  float32[10]
>
tf_dataset_type:
 <
  float32[?,28,28,1],
  int32[?]
>*
type signature of `initialize`:
 ( -> <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>@SERVER)
type signature of `next`:
 (<
  server_state=<
    trainable_weights=<
      float32[5,5,1,32],
      float32[32],
      float32[5,5,32,64],
      float32[64],
      float32[3136,512],
      float32[512],
      float32[512,10],
      float32[10]
    >,
    optimizer_state=<
      float32[5,5,1,32],
      float32[32],
      float32[5,5,32,64],
      float32[64],
      float32[3136,512],
      float32[512],
      float32[512,10],
      float32[10]
    >
  >@SERVER,
  federated_dataset={<
    float32[?,28,28,1],
    int32[?]
  >*}@CLIENTS
> -> <
  trainable_weights=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >,
  optimizer_state=<
    float32[5,5,1,32],
    float32[32],
    float32[5,5,32,64],
    float32[64],
    float32[3136,512],
    float32[512],
    float32[512,10],
    float32[10]
  >
>@SERVER)

Algoritmayı değerlendirmek

Performansı merkezi bir değerlendirme veri setinde değerlendiririz.

def evaluate(server_state):
  keras_model = create_keras_model()
  tf.nest.map_structure(
      lambda var, t: var.assign(t),
      keras_model.trainable_weights, server_state.trainable_weights)
  metric = tf.keras.metrics.SparseCategoricalAccuracy()
  for batch in iter(central_test_data):
    preds = keras_model(batch[0], training=False)
    metric.update_state(y_true=batch[1], y_pred=preds)
  return metric.result().numpy()
server_state = fedavg_process.initialize()
acc = evaluate(server_state)
print('Initial test accuracy', acc)

# Evaluate after a few rounds
CLIENTS_PER_ROUND=2
sampled_clients = train_client_ids[:CLIENTS_PER_ROUND]
sampled_train_data = [
    train_data.create_tf_dataset_for_client(client)
    for client in sampled_clients]
for round in range(20):
  server_state = fedavg_process.next(server_state, sampled_train_data)
acc = evaluate(server_state)
print('Test accuracy', acc)
Initial test accuracy 0.09677419
Test accuracy 0.13978495