הדמיות TFF עם מאיצים

מדריך זה יתאר כיצד להתקין הדמיות TFF עם מאיצים. אנו מתמקדים לעת עתה במכשיר יחיד (רב) של מכונה אחת ונעדכן הדרכה זו עם הגדרות ריבוי מכונות ו- TPU.

צפה ב- TensorFlow.org הפעל בגוגל קולאב צפה במקור ב- GitHub הורד מחברת

לפני שנתחיל

ראשית, בואו נוודא שהמחברת מחוברת ל- backend שמרכיבים את הרכיבים הרלוונטיים.

!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio
!pip install -U tensorboard_plugin_profile

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
import collections
import time

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

בדוק אם TF יכול לאתר GPUs פיזיים וליצור סביבה וירטואלית מרובת GPU עבור הדמיות TFF GPU. לשני המעבדים הגרפיים הווירטואליים יהיה זיכרון מוגבל כדי להדגים כיצד להגדיר את זמן הריצה של TFF.

gpu_devices = tf.config.list_physical_devices('GPU')
if not gpu_devices:
  raise ValueError('Cannot detect physical GPU device in TF')
tf.config.set_logical_device_configuration(
    gpu_devices[0], 
    [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
     tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
tf.config.list_logical_devices()
[LogicalDevice(name='/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/device:GPU:0', device_type='GPU'),
 LogicalDevice(name='/device:GPU:1', device_type='GPU')]

הפעל את הדוגמה הבאה "שלום עולם" כדי לוודא שסביבת TFF מוגדרת כהלכה. אם זה לא עובד, עיין במדריך ההתקנה לקבלת הוראות.

@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'

הגדרת ניסוי של EMNIST

במדריך זה אנו מכשירים מסווג תמונות EMNIST עם אלגוריתם ממוצע ממוצע. נתחיל בטעינה של דוגמת MNIST מאתר TFF.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits=True)

אנו מגדירים פונקציה שעובדת מראש על דוגמת EMNIST בעקבות הדוגמה simple_fedavg . שים לב שהארגומנט client_epochs_per_round שולט במספר התקופות המקומיות על לקוחות בלמידה מאוחדת.

def preprocess_emnist_dataset(client_epochs_per_round, batch_size, test_batch_size):

  def element_fn(element):
    return collections.OrderedDict(
        x=tf.expand_dims(element['pixels'], -1), y=element['label'])

  def preprocess_train_dataset(dataset):
    # Use buffer_size same as the maximum client dataset size,
    # 418 for Federated EMNIST
    return dataset.map(element_fn).shuffle(buffer_size=418).repeat(
        count=client_epochs_per_round).batch(batch_size, drop_remainder=False)

  def preprocess_test_dataset(dataset):
    return dataset.map(element_fn).batch(test_batch_size, drop_remainder=False)

  train_set = emnist_train.preprocess(preprocess_train_dataset)
  test_set = preprocess_test_dataset(
      emnist_test.create_tf_dataset_from_all_clients())
  return train_set, test_set

אנו משתמשים במודל דמוי VGG, כלומר, לכל בלוק יש שני התפתחויות 3x3 ומספר המסננים מוכפל כאשר מדגמים את מפות התכונות.

def _conv_3x3(input_tensor, filters, strides):
  """2D Convolutional layer with kernel size 3x3."""

  x = tf.keras.layers.Conv2D(
      filters=filters,
      strides=strides,
      kernel_size=3,
      padding='same',
      kernel_initializer='he_normal',
      use_bias=False,
  )(input_tensor)
  return x


def _basic_block(input_tensor, filters, strides):
  """A block of two 3x3 conv layers."""

  x = input_tensor
  x = _conv_3x3(x, filters, strides)
  x = tf.keras.layers.Activation('relu')(x)

  x = _conv_3x3(x, filters, 1)
  x = tf.keras.layers.Activation('relu')(x)
  return x


def _vgg_block(input_tensor, size, filters, strides):
  """A stack of basic blocks."""
  x = _basic_block(input_tensor, filters, strides=strides)
  for _ in range(size - 1):
      x = _basic_block(x, filters, strides=1)
  return x


def create_cnn(num_blocks, conv_width_multiplier=1, num_classes=10):
  """Create a VGG-like CNN model. 

  The CNN has (6*num_blocks + 2) layers.
  """
  input_shape = (28, 28, 1)  # channels_last
  img_input = tf.keras.layers.Input(shape=input_shape)
  x = img_input
  x = tf.image.per_image_standardization(x)

  x = _conv_3x3(x, 16 * conv_width_multiplier, 1)
  x = _vgg_block(x, size=num_blocks, filters=16 * conv_width_multiplier, strides=1)
  x = _vgg_block(x, size=num_blocks, filters=32 * conv_width_multiplier, strides=2)
  x = _vgg_block(x, size=num_blocks, filters=64 * conv_width_multiplier, strides=2)

  x = tf.keras.layers.GlobalAveragePooling2D()(x)
  x = tf.keras.layers.Dense(num_classes)(x)

  model = tf.keras.models.Model(
      img_input,
      x,
      name='cnn-{}-{}'.format(6 * num_blocks + 2, conv_width_multiplier))
  return model

עכשיו נגדיר את לולאת האימונים ל- EMNIST. שים לב ש- use_experimental_simulation_loop=True ב- tff.learning.build_federated_averaging_process מוצע לסימולציה של TFF ביצועים, ונדרש לנצל את מרובי ה- GPU במכונה אחת. ראה דוגמה של simple_fedavg כיצד להגדיר אלגוריתם למידה מאוחד של פדרציה בעל ביצועים גבוהים ב- GPUs, אחת התכונות העיקריות היא להשתמש במפורש for ... iter(dataset) לאימון לולאות.

def keras_evaluate(model, test_data, metric):
  metric.reset_states()
  for batch in test_data:
    preds = model(batch['x'], training=False)
    metric.update_state(y_true=batch['y'], y_pred=preds)
  return metric.result()


def run_federated_training(client_epochs_per_round, 
                           train_batch_size, 
                           test_batch_size, 
                           cnn_num_blocks, 
                           conv_width_multiplier,
                           server_learning_rate, 
                           client_learning_rate, 
                           total_rounds, 
                           clients_per_round, 
                           rounds_per_eval,
                           logdir='logdir'):

  train_data, test_data = preprocess_emnist_dataset(
      client_epochs_per_round, train_batch_size, test_batch_size)
  data_spec = test_data.element_spec

  def _model_fn():
    keras_model = create_cnn(cnn_num_blocks, conv_width_multiplier)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return tff.learning.from_keras_model(
        keras_model, input_spec=data_spec, loss=loss)

  def _server_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=server_learning_rate)

  def _client_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=client_learning_rate)

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=_model_fn, 
      server_optimizer_fn=_server_optimizer_fn, 
      client_optimizer_fn=_client_optimizer_fn, 
      use_experimental_simulation_loop=True)

  metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  eval_model = create_cnn(cnn_num_blocks, conv_width_multiplier)
  logging.info(eval_model.summary())

  server_state = iterative_process.initialize()
  start_time = time.time()
  for round_num in range(total_rounds):
    sampled_clients = np.random.choice(
        train_data.client_ids,
        size=clients_per_round,
        replace=False)
    sampled_train_data = [
        train_data.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]
    if round_num == total_rounds-1:
      with tf.profiler.experimental.Profile(logdir):
        server_state, train_metrics = iterative_process.next(
            server_state, sampled_train_data)
    else:
      server_state, train_metrics = iterative_process.next(
            server_state, sampled_train_data)
    print(f'Round {round_num} training loss: {train_metrics["train"]["loss"]}, '
     f'time: {(time.time()-start_time)/(round_num+1.)} secs')
    if round_num % rounds_per_eval == 0 or round_num == total_rounds-1:
      server_state.model.assign_weights_to(eval_model)
      accuracy = keras_evaluate(eval_model, test_data, metric)
      print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')

ביצוע GPU יחיד

זמן הריצה המוגדר כברירת מחדל של TFF זהה ל- TF: כאשר מסופקים GPUs, ה- GPU הראשון ייבחר לביצוע. אנו עורכים את האימונים המאוחדים שהוגדרו בעבר למספר סיבובים עם מודל קטן יחסית. סבב הביצוע האחרון tf.profiler עם tf.profiler tensorboard ידי tensorboard . פרופיל אימת כי נעשה שימוש ב- GPU הראשון.

run_federated_training(
    client_epochs_per_round=1, 
    train_batch_size=16, 
    test_batch_size=128, 
    cnn_num_blocks=2, 
    conv_width_multiplier=4,
    server_learning_rate=1.0, 
    client_learning_rate=0.01,
    total_rounds=10,
    clients_per_round=16, 
    rounds_per_eval=2,
    )
Model: "cnn-14-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 64)        576       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation (Activation)      (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_1 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_2 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_3 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 14, 14, 128)       73728     
_________________________________________________________________
activation_4 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_5 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_6 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_7 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 7, 7, 256)         294912    
_________________________________________________________________
activation_8 (Activation)    (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_9 (Activation)    (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_10 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_11 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
global_average_pooling2d (Gl (None, 256)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                2570      
=================================================================
Total params: 2,731,082
Trainable params: 2,731,082
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4688243865966797, time: 13.382015466690063 secs
Round 0 validation accuracy: 15.240497589111328
Round 1 training loss: 2.3217368125915527, time: 9.311999917030334 secs
Round 2 training loss: 2.3100595474243164, time: 6.972411632537842 secs
Round 2 validation accuracy: 11.226489067077637
Round 3 training loss: 2.303222417831421, time: 6.467299699783325 secs
Round 4 training loss: 2.2976326942443848, time: 5.526083135604859 secs
Round 4 validation accuracy: 11.224040031433105
Round 5 training loss: 2.2919719219207764, time: 5.468692660331726 secs
Round 6 training loss: 2.2911534309387207, time: 4.935825347900391 secs
Round 6 validation accuracy: 11.833855628967285
Round 7 training loss: 2.2871201038360596, time: 4.918408691883087 secs
Round 8 training loss: 2.2818832397460938, time: 4.602836343977186 secs
Round 8 validation accuracy: 11.385677337646484
Round 9 training loss: 2.2790346145629883, time: 4.99558527469635 secs
Round 9 validation accuracy: 11.226489067077637
%tensorboard --logdir=logdir --port=0

ביצוע מעבד

לשם השוואה, תן לנו להגדיר זמן ריצה של TFF לביצוע מעבד. ביצוע המעבד איטי רק במודל קטן יחסית זה.

cpu_device = tf.config.list_logical_devices('CPU')[0]
tff.backends.native.set_local_execution_context(
    server_tf_device=cpu_device, client_tf_devices=[cpu_device])

run_federated_training(
    client_epochs_per_round=1, 
    train_batch_size=16, 
    test_batch_size=128, 
    cnn_num_blocks=2, 
    conv_width_multiplier=4,
    server_learning_rate=1.0, 
    client_learning_rate=0.01,
    total_rounds=10,
    clients_per_round=16, 
    rounds_per_eval=2,
    )
Model: "cnn-14-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 28, 28, 64)        576       
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_12 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_13 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_14 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_15 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 14, 14, 128)       73728     
_________________________________________________________________
activation_16 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_17 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_18 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_19 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 7, 7, 256)         294912    
_________________________________________________________________
activation_20 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_21 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_22 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_25 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_23 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
global_average_pooling2d_1 ( (None, 256)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                2570      
=================================================================
Total params: 2,731,082
Trainable params: 2,731,082
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4787657260894775, time: 15.264191627502441 secs
Round 0 validation accuracy: 12.191418647766113
Round 1 training loss: 2.3097336292266846, time: 11.785032272338867 secs
Round 2 training loss: 2.3062121868133545, time: 9.677561124165853 secs
Round 2 validation accuracy: 11.415066719055176
Round 3 training loss: 2.2982261180877686, time: 9.301376760005951 secs
Round 4 training loss: 2.2953946590423584, time: 8.377780866622924 secs
Round 4 validation accuracy: 20.537813186645508
Round 5 training loss: 2.290337324142456, time: 8.385509928067526 secs
Round 6 training loss: 2.2842795848846436, time: 7.809031554630825 secs
Round 6 validation accuracy: 11.934267044067383
Round 7 training loss: 2.2752432823181152, time: 7.8433578312397 secs
Round 8 training loss: 2.2698657512664795, time: 7.478067080179851 secs
Round 8 validation accuracy: 26.16330337524414
Round 9 training loss: 2.2609798908233643, time: 7.632814192771912 secs
Round 9 validation accuracy: 23.079936981201172

ביצוע רב-GPU

קל להגדיר את TFF לביצוע רב-GPU. שים לב שהדרכת לקוחות מקבילה ב- TFF. בהגדרת ריבוי GPU, לקוחות יוקצו לריבוי GPU בצורה סביבה. שני ביצועי ה- GPU הבאים אינם מהירים מביצוע GPU יחיד, מכיוון שאימון הלקוח מקביל הן בהגדרות GPU יחיד והן במספר רב של הגדרות GPU, ובהגדרת multiGPU יש שני GPUs וירטואליים שנוצרו מ- GPU פיזי יחיד.

gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_execution_context(client_tf_devices=gpu_devices)

run_federated_training(
    client_epochs_per_round=1, 
    train_batch_size=16, 
    test_batch_size=128, 
    cnn_num_blocks=2, 
    conv_width_multiplier=4,
    server_learning_rate=1.0, 
    client_learning_rate=0.01,
    total_rounds=10,
    clients_per_round=16, 
    rounds_per_eval=2,
    logdir='multigpu'
    )
Model: "cnn-14-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_26 (Conv2D)           (None, 28, 28, 64)        576       
_________________________________________________________________
conv2d_27 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_24 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_28 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_25 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_29 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_26 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_30 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_27 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_31 (Conv2D)           (None, 14, 14, 128)       73728     
_________________________________________________________________
activation_28 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_32 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_29 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_33 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_30 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_34 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_31 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_35 (Conv2D)           (None, 7, 7, 256)         294912    
_________________________________________________________________
activation_32 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_36 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_33 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_37 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_34 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_38 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_35 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
global_average_pooling2d_2 ( (None, 256)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 10)                2570      
=================================================================
Total params: 2,731,082
Trainable params: 2,731,082
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.911365270614624, time: 12.759389877319336 secs
Round 0 validation accuracy: 9.541536331176758
Round 1 training loss: 2.3175694942474365, time: 9.202919721603394 secs
Round 2 training loss: 2.311001777648926, time: 6.802880525588989 secs
Round 2 validation accuracy: 9.911344528198242
Round 3 training loss: 2.3105244636535645, time: 6.611470937728882 secs
Round 4 training loss: 2.3082072734832764, time: 5.678833389282227 secs
Round 4 validation accuracy: 10.212578773498535
Round 5 training loss: 2.304673671722412, time: 5.5404335260391235 secs
Round 6 training loss: 2.3035168647766113, time: 5.008027451378958 secs
Round 6 validation accuracy: 9.935834884643555
Round 7 training loss: 2.3052737712860107, time: 5.1173741817474365 secs
Round 8 training loss: 2.3007171154022217, time: 4.745321141348945 secs
Round 8 validation accuracy: 10.768514633178711
Round 9 training loss: 2.302018404006958, time: 5.0809732437133786 secs
Round 9 validation accuracy: 12.311422348022461

דגם גדול יותר ו- OOM

הבה נפעיל מודל גדול יותר על מעבד עם פחות סבבים מאוחדים.

cpu_device = tf.config.list_logical_devices('CPU')[0]
tff.backends.native.set_local_execution_context(
    server_tf_device=cpu_device, client_tf_devices=[cpu_device])

run_federated_training(
    client_epochs_per_round=1, 
    train_batch_size=16, 
    test_batch_size=128, 
    cnn_num_blocks=4, 
    conv_width_multiplier=4,
    server_learning_rate=1.0, 
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16, 
    rounds_per_eval=2,
    )
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_39 (Conv2D)           (None, 28, 28, 64)        576       
_________________________________________________________________
conv2d_40 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_36 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_41 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_37 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_42 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_38 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_43 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_39 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_44 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_40 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_45 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_41 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_46 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_42 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_47 (Conv2D)           (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_43 (Activation)   (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_48 (Conv2D)           (None, 14, 14, 128)       73728     
_________________________________________________________________
activation_44 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_49 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_45 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_50 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_46 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_51 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_47 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_52 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_48 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_53 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_49 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_54 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_50 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_55 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_51 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_56 (Conv2D)           (None, 7, 7, 256)         294912    
_________________________________________________________________
activation_52 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_57 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_53 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_58 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_54 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_59 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_55 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_60 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_56 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_61 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_57 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_62 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_58 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_63 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_59 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
global_average_pooling2d_3 ( (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                2570      
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.437223434448242, time: 24.121686458587646 secs
Round 0 validation accuracy: 9.024785041809082
Round 1 training loss: 2.3081459999084473, time: 19.48685622215271 secs
Round 2 training loss: 2.305305242538452, time: 15.73950457572937 secs
Round 2 validation accuracy: 9.791339874267578
Round 3 training loss: 2.303149700164795, time: 15.194068729877472 secs
Round 4 training loss: 2.3026506900787354, time: 14.036769819259643 secs
Round 4 validation accuracy: 12.193867683410645

מודל זה עשוי לפגוע בבעיית זיכרון ב- GPU יחיד. המעבר מניסויי מעבד בקנה מידה גדול לסימולציית GPU יכול להיות מוגבל על ידי שימוש בזיכרון מכיוון של- GPUs יש לעתים קרובות זיכרונות מוגבלים. ישנם מספר פרמטרים שניתן לכוון בזמן הריצה של TFF כדי להקל על בעיית OOM

# Single GPU execution might hit OOM. 
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_execution_context(client_tf_devices=[gpu_devices[0]])

try:
  run_federated_training(
      client_epochs_per_round=1, 
      train_batch_size=16, 
      test_batch_size=128, 
      cnn_num_blocks=4, 
      conv_width_multiplier=4,
      server_learning_rate=1.0, 
      client_learning_rate=0.01,
      total_rounds=5,
      clients_per_round=16, 
      rounds_per_eval=2,
      )
except ResourceExhaustedError as e:
  print(e)
# Control concurrency by `clients_per_thread`.
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_execution_context(
    client_tf_devices=[gpu_devices[0]], clients_per_thread=2)

run_federated_training(
    client_epochs_per_round=1, 
    train_batch_size=16, 
    test_batch_size=128, 
    cnn_num_blocks=4, 
    conv_width_multiplier=4,
    server_learning_rate=1.0, 
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16, 
    rounds_per_eval=2,
    )
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 64)        576       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation (Activation)      (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_1 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_2 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_3 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_4 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_5 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_6 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_7 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 14, 14, 128)       73728     
_________________________________________________________________
activation_8 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_9 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_10 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_11 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_12 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_13 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_14 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_15 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 7, 7, 256)         294912    
_________________________________________________________________
activation_16 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_17 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_18 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_19 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_20 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_21 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_22 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_23 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
global_average_pooling2d (Gl (None, 256)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                2570      
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4990053176879883, time: 11.922378778457642 secs
Round 0 validation accuracy: 11.224040031433105
Round 1 training loss: 2.307560920715332, time: 9.916815996170044 secs
Round 2 training loss: 2.3032877445220947, time: 7.68927804629008 secs
Round 2 validation accuracy: 11.224040031433105
Round 3 training loss: 2.302366256713867, time: 7.681552231311798 secs
Round 4 training loss: 2.301671028137207, time: 7.613566827774048 secs
Round 4 validation accuracy: 11.224040031433105
# Multi-GPU execution with configuration to mitigate OOM. 
cpu_device = tf.config.list_logical_devices('CPU')[0]
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_execution_context(
    server_tf_device=cpu_device,
    client_tf_devices=gpu_devices, 
    clients_per_thread=1, 
    max_fanout=32)

run_federated_training(
    client_epochs_per_round=1, 
    train_batch_size=16, 
    test_batch_size=128, 
    cnn_num_blocks=4, 
    conv_width_multiplier=4,
    server_learning_rate=1.0, 
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16, 
    rounds_per_eval=2,
    )
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 64)        576       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation (Activation)      (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_1 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_2 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_3 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_4 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_5 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_6 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_7 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 14, 14, 128)       73728     
_________________________________________________________________
activation_8 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_9 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_10 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_11 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_12 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_13 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_14 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_15 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 7, 7, 256)         294912    
_________________________________________________________________
activation_16 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_17 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_18 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_19 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_20 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_21 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_22 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_23 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
global_average_pooling2d (Gl (None, 256)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                2570      
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4691953659057617, time: 17.81941556930542 secs
Round 0 validation accuracy: 10.817495346069336
Round 1 training loss: 2.3081436157226562, time: 12.986191034317017 secs
Round 2 training loss: 2.3028159141540527, time: 9.518500963846842 secs
Round 2 validation accuracy: 11.500783920288086
Round 3 training loss: 2.303886651992798, time: 8.989932537078857 secs
Round 4 training loss: 2.3030669689178467, time: 8.733866214752197 secs
Round 4 validation accuracy: 12.992260932922363

ייעל את הביצועים

בדרך כלל ניתן להשתמש בטכניקות ב- TF שיכולות להשיג ביצועים טובים יותר ב- TFF, למשל אימוני דיוק מעורבים ו- XLA . tf.profiler (ב- GPU כמו V100) וחיסכון בזיכרון בדיוק מעורב יכול להיות משמעותי, דבר שניתן לבחון על ידי tf.profiler .

# Mixed precision training. 
cpu_device = tf.config.list_logical_devices('CPU')[0]
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_execution_context(
    server_tf_device=cpu_device,
    client_tf_devices=gpu_devices, 
    clients_per_thread=1, 
    max_fanout=32)
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)


run_federated_training(
    client_epochs_per_round=1, 
    train_batch_size=16, 
    test_batch_size=128, 
    cnn_num_blocks=4, 
    conv_width_multiplier=4,
    server_learning_rate=1.0, 
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16, 
    rounds_per_eval=2,
    logdir='mixed'
    )
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 64)        576       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation (Activation)      (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_1 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_2 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_3 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_4 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_5 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_6 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 28, 28, 64)        36864     
_________________________________________________________________
activation_7 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 14, 14, 128)       73728     
_________________________________________________________________
activation_8 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_9 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_10 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_11 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_12 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_13 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_14 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 14, 14, 128)       147456    
_________________________________________________________________
activation_15 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 7, 7, 256)         294912    
_________________________________________________________________
activation_16 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_17 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_18 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_19 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_20 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_21 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_22 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 7, 7, 256)         589824    
_________________________________________________________________
activation_23 (Activation)   (None, 7, 7, 256)         0         
_________________________________________________________________
global_average_pooling2d (Gl (None, 256)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                2570      
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4187185764312744, time: 18.763780117034912 secs
Round 0 validation accuracy: 9.977468490600586
Round 1 training loss: 2.305102825164795, time: 13.712820529937744 secs
Round 2 training loss: 2.304737091064453, time: 9.993690172831217 secs
Round 2 validation accuracy: 11.779976844787598
Round 3 training loss: 2.2996833324432373, time: 9.29404467344284 secs
Round 4 training loss: 2.299349308013916, time: 9.195427560806275 secs
Round 4 validation accuracy: 11.224040031433105
%tensorboard --logdir=mixed --port=0