การจำลอง TFF ด้วยตัวเร่งความเร็ว

บทช่วยสอนนี้จะอธิบายวิธีตั้งค่าการจำลอง TFF ด้วยตัวเร่งความเร็ว ตอนนี้เราเน้นที่ GPU (ตัวเดียวหรือหลายตัว) บนเครื่องเดียว และจะอัปเดตบทช่วยสอนนี้ด้วยการตั้งค่าแบบหลายเครื่องและ TPU

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

ก่อนที่เราจะเริ่มต้น

อันดับแรก ให้เราตรวจสอบให้แน่ใจว่าโน้ตบุ๊กเชื่อมต่อกับแบ็กเอนด์ที่คอมไพล์ส่วนประกอบที่เกี่ยวข้องไว้แล้ว

!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio
!pip install -U tensorboard_plugin_profile

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
import collections
import time

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

ตรวจสอบว่า TF สามารถตรวจจับ GPU จริงและสร้างสภาพแวดล้อม multi-GPU เสมือนสำหรับการจำลอง TFF GPU ได้หรือไม่ GPU เสมือนทั้งสองจะมีหน่วยความจำที่จำกัดเพื่อสาธิตวิธีกำหนดค่ารันไทม์ TFF

# Ask TensorFlow for the physical GPUs it can see and stop immediately when
# there are none, since the rest of the tutorial needs at least one.
gpu_devices = tf.config.list_physical_devices('GPU')
if not gpu_devices:
  raise ValueError('Cannot detect physical GPU device in TF')

# Split the first physical GPU into two logical (virtual) GPUs, each created
# with memory_limit=1024.
virtual_gpu_configs = [
    tf.config.LogicalDeviceConfiguration(memory_limit=1024)
    for _ in range(2)
]
tf.config.set_logical_device_configuration(gpu_devices[0], virtual_gpu_configs)

# Final expression of the cell: displays the resulting logical devices.
tf.config.list_logical_devices()
[LogicalDevice(name='/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/device:GPU:0', device_type='GPU'),
 LogicalDevice(name='/device:GPU:1', device_type='GPU')]

เรียกใช้ตัวอย่าง "Hello World" ต่อไปนี้เพื่อให้แน่ใจว่าสภาพแวดล้อม TFF ได้รับการตั้งค่าอย่างถูกต้อง หากไม่ทำงาน โปรดดูคำแนะนำในคู่มือการติดตั้ง

@tff.federated_computation
def hello_world():
  """Return a fixed greeting; used as a smoke test of the TFF setup."""
  greeting = 'Hello, World!'
  return greeting


# Invoke the computation right away so the cell displays its result.
hello_world()
b'Hello, World!'

ตั้งค่าการทดลอง EMNIST

ในบทช่วยสอนนี้ เราฝึกตัวแยกประเภทรูปภาพ EMNIST ด้วยอัลกอริทึม Federated Averaging ให้เราเริ่มต้นด้วยการโหลดชุดข้อมูลตัวอย่าง EMNIST จาก TFF

# Load the federated EMNIST train/test splits from TFF's simulation datasets.
# NOTE(review): only_digits=True presumably restricts labels to the 10 digit
# classes, matching create_cnn's default num_classes=10 — confirm in TFF docs.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits=True)

เรากำหนดฟังก์ชันพรีโพรเซสซิงสำหรับตัวอย่าง EMNIST ตามแบบตัวอย่าง simple_fedavg โปรดทราบว่าอาร์กิวเมนต์ client_epochs_per_round ควบคุมจำนวน epoch ที่ฝึกในเครื่องของไคลเอ็นต์แต่ละรายในการเรียนรู้แบบ federated

def preprocess_emnist_dataset(client_epochs_per_round, batch_size, test_batch_size):
  """Build the preprocessed EMNIST training and test data.

  Reads the module-level `emnist_train` and `emnist_test` objects.

  Args:
    client_epochs_per_round: Repeat count applied to each client's training
      dataset, i.e. the number of local epochs per round.
    batch_size: Batch size used for the client training datasets.
    test_batch_size: Batch size used for the test dataset.

  Returns:
    A `(train_set, test_set)` tuple: `emnist_train` with the training
    preprocessing attached through `.preprocess(...)`, and one batched dataset
    built from all clients of `emnist_test`.
  """

  def _to_model_input(example):
    # Append a trailing channel axis to the pixels and expose the fields under
    # the keys `x` and `y`, in that order.
    image = tf.expand_dims(example['pixels'], -1)
    return collections.OrderedDict(x=image, y=example['label'])

  def _prepare_client_train_data(client_dataset):
    # The shuffle buffer equals the maximum client dataset size,
    # 418 for Federated EMNIST.
    examples = client_dataset.map(_to_model_input)
    shuffled = examples.shuffle(buffer_size=418)
    repeated = shuffled.repeat(count=client_epochs_per_round)
    return repeated.batch(batch_size, drop_remainder=False)

  def _prepare_test_data(test_dataset):
    examples = test_dataset.map(_to_model_input)
    return examples.batch(test_batch_size, drop_remainder=False)

  train_set = emnist_train.preprocess(_prepare_client_train_data)
  all_test_examples = emnist_test.create_tf_dataset_from_all_clients()
  test_set = _prepare_test_data(all_test_examples)
  return train_set, test_set

เราใช้โมเดลที่คล้าย VGG กล่าวคือ แต่ละบล็อกมีคอนโวลูชัน 3x3 สองชั้น และจำนวนตัวกรองจะเพิ่มเป็นสองเท่าเมื่อฟีเจอร์แมปถูกลดขนาด (subsample)

def _conv_3x3(input_tensor, filters, strides):
  """Apply a bias-free 3x3 `Conv2D` layer to `input_tensor`.

  Args:
    input_tensor: Tensor fed to the convolution.
    filters: Number of output filters.
    strides: Stride of the convolution.

  Returns:
    The output of the newly created layer ('same' padding, 'he_normal'
    kernel initializer, no bias).
  """
  conv_layer = tf.keras.layers.Conv2D(
      filters=filters,
      kernel_size=3,
      strides=strides,
      padding='same',
      kernel_initializer='he_normal',
      use_bias=False,
  )
  return conv_layer(input_tensor)


def _basic_block(input_tensor, filters, strides):
  """Apply two 3x3 convolutions, each followed by a ReLU activation.

  Args:
    input_tensor: Tensor fed to the first convolution.
    filters: Number of filters for both convolutions.
    strides: Stride of the first convolution; the second always uses 1.

  Returns:
    The output tensor of the second activation.
  """
  out = input_tensor
  for conv_strides in (strides, 1):
    out = _conv_3x3(out, filters, conv_strides)
    out = tf.keras.layers.Activation('relu')(out)
  return out


def _vgg_block(input_tensor, size, filters, strides):
  """Chain basic blocks: the first with `strides`, the rest with stride 1.

  Args:
    input_tensor: Tensor fed to the first basic block.
    size: Number of basic blocks; at least one block is always applied.
    filters: Number of filters used by every block.
    strides: Stride passed to the first block only.

  Returns:
    The output tensor of the last basic block.
  """
  # One entry per block; a non-positive multiplier yields no extra entries,
  # so the leading block is always present.
  strides_per_block = [strides] + [1] * (size - 1)
  out = input_tensor
  for block_strides in strides_per_block:
    out = _basic_block(out, filters, strides=block_strides)
  return out


def create_cnn(num_blocks, conv_width_multiplier=1, num_classes=10):
  """Build a VGG-like Keras CNN for 28x28 single-channel images.

  The network is one stem convolution, three stages of `num_blocks` basic
  blocks (two convolutions each), global average pooling and a dense layer,
  i.e. (6*num_blocks + 2) weighted layers.

  Args:
    num_blocks: Number of basic blocks in each of the three stages.
    conv_width_multiplier: Factor applied to the base filter counts
      (16, 32, 64).
    num_classes: Number of units in the final dense layer.

  Returns:
    A `tf.keras.models.Model` named 'cnn-<depth>-<conv_width_multiplier>'.
  """
  img_input = tf.keras.layers.Input(shape=(28, 28, 1))  # channels_last
  features = tf.image.per_image_standardization(img_input)

  features = _conv_3x3(features, 16 * conv_width_multiplier, 1)
  # (base filters, strides of the first block) for each stage.
  for base_filters, stage_strides in ((16, 1), (32, 2), (64, 2)):
    features = _vgg_block(
        features,
        size=num_blocks,
        filters=base_filters * conv_width_multiplier,
        strides=stage_strides)

  pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
  outputs = tf.keras.layers.Dense(num_classes)(pooled)

  model_name = 'cnn-{}-{}'.format(6 * num_blocks + 2, conv_width_multiplier)
  return tf.keras.models.Model(img_input, outputs, name=model_name)

ตอนนี้ ให้เรากำหนดลูปการฝึกสำหรับ EMNIST โปรดทราบว่าแนะนำให้ตั้งค่า use_experimental_simulation_loop=True ใน tff.learning.build_federated_averaging_process เพื่อให้การจำลอง TFF มีประสิทธิภาพสูง และจำเป็นต้องตั้งค่านี้เพื่อใช้ประโยชน์จาก GPU หลายตัวในเครื่องเดียว ดูตัวอย่าง simple_fedavg สำหรับวิธีกำหนดอัลกอริทึมการเรียนรู้แบบ federated แบบกำหนดเองที่มีประสิทธิภาพสูงบน GPU ซึ่งหนึ่งในคุณสมบัติสำคัญคือการใช้ for ... iter(dataset) อย่างชัดเจนในลูปการฝึก

def keras_evaluate(model, test_data, metric):
  """Evaluate `model` on `test_data` and return the value of `metric`.

  Args:
    model: Callable invoked as `model(inputs, training=False)`.
    test_data: Iterable of batches, each indexable by the keys 'x' and 'y'.
    metric: Object providing `reset_states()`, `update_state(y_true, y_pred)`
      and `result()`; it is reset before any batch is processed.

  Returns:
    Whatever `metric.result()` returns after all batches were accumulated.
  """
  metric.reset_states()
  for example in test_data:
    predictions = model(example['x'], training=False)
    labels = example['y']
    metric.update_state(y_true=labels, y_pred=predictions)
  return metric.result()


def run_federated_training(client_epochs_per_round,
                           train_batch_size,
                           test_batch_size,
                           cnn_num_blocks,
                           conv_width_multiplier,
                           server_learning_rate,
                           client_learning_rate,
                           total_rounds,
                           clients_per_round,
                           rounds_per_eval,
                           logdir='logdir'):
  """Train the EMNIST CNN with Federated Averaging and print progress.

  Each round samples `clients_per_round` distinct client ids at random, runs
  one step of the federated averaging process on their datasets, and prints
  the training loss together with the average wall-clock time per round so
  far (evaluation time included). Every `rounds_per_eval` rounds, and on the
  final round, the server weights are copied into a separate Keras model that
  is evaluated on the test set. Only the final round runs under
  `tf.profiler.experimental.Profile(logdir)`.

  Args:
    client_epochs_per_round: Local epochs per client and round; forwarded to
      `preprocess_emnist_dataset`.
    train_batch_size: Batch size of the client training datasets.
    test_batch_size: Batch size of the test dataset.
    cnn_num_blocks: `num_blocks` argument for `create_cnn`.
    conv_width_multiplier: `conv_width_multiplier` argument for `create_cnn`.
    server_learning_rate: Learning rate of the server SGD optimizer.
    client_learning_rate: Learning rate of the client SGD optimizer.
    total_rounds: Number of federated rounds to run.
    clients_per_round: Number of clients sampled (without replacement) per
      round.
    rounds_per_eval: Evaluation period in rounds.
    logdir: Directory handed to the profiler for the final round.

  Returns:
    None. Results are reported through `print` only.
  """
  train_data, test_data = preprocess_emnist_dataset(
      client_epochs_per_round, train_batch_size, test_batch_size)
  data_spec = test_data.element_spec

  def _model_fn():
    keras_model = create_cnn(cnn_num_blocks, conv_width_multiplier)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return tff.learning.from_keras_model(
        keras_model, input_spec=data_spec, loss=loss)

  def _server_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=server_learning_rate)

  def _client_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=client_learning_rate)

  iterative_process = tff.learning.build_federated_averaging_process(
      model_fn=_model_fn,
      server_optimizer_fn=_server_optimizer_fn,
      client_optimizer_fn=_client_optimizer_fn,
      use_experimental_simulation_loop=True)

  metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  eval_model = create_cnn(cnn_num_blocks, conv_width_multiplier)
  # `summary()` prints the layer table itself and returns None, so it is
  # called directly. Wrapping it in `logging.info(...)` only logged None and
  # raised NameError because `logging` is not among the file's imports.
  eval_model.summary()

  server_state = iterative_process.initialize()
  start_time = time.time()
  for round_num in range(total_rounds):
    is_last_round = round_num == total_rounds - 1
    sampled_clients = np.random.choice(
        train_data.client_ids,
        size=clients_per_round,
        replace=False)
    sampled_train_data = [
        train_data.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]
    if is_last_round:
      # Profile only the final round; the trace is written under `logdir`.
      with tf.profiler.experimental.Profile(logdir):
        server_state, train_metrics = iterative_process.next(
            server_state, sampled_train_data)
    else:
      server_state, train_metrics = iterative_process.next(
          server_state, sampled_train_data)
    print(f'Round {round_num} training loss: {train_metrics["train"]["loss"]}, '
          f'time: {(time.time()-start_time)/(round_num+1.)} secs')
    if round_num % rounds_per_eval == 0 or is_last_round:
      # Copy the current server weights into the standalone Keras model
      # before evaluating it on the centralized test set.
      server_state.model.assign_weights_to(eval_model)
      accuracy = keras_evaluate(eval_model, test_data, metric)
      print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')

การประมวลผล GPU เดียว

รันไทม์เริ่มต้นของ TFF จะเหมือนกับ TF: เมื่อมี GPU ให้ใช้งาน จะเลือก GPU ตัวแรกสำหรับการประมวลผล เรารันการฝึกแบบ federated ที่กำหนดไว้ก่อนหน้านี้หลายรอบด้วยโมเดลที่ค่อนข้างเล็ก รอบสุดท้ายของการประมวลผลจะถูกทำโปรไฟล์ด้วย tf.profiler และแสดงผลด้วย tensorboard ผลการทำโปรไฟล์ยืนยันว่ามีการใช้ GPU ตัวแรก

# Baseline run: 10 rounds of 16 clients with the small model
# (cnn_num_blocks=2, conv_width_multiplier=4, named "cnn-14-4" by create_cnn).
# No TFF execution context is configured before this call in the visible file.
# `logdir` is left at its default 'logdir', so the final-round profile is
# written there.
run_federated_training(
  client_epochs_per_round=1, 
  train_batch_size=16, 
  test_batch_size=128, 
  cnn_num_blocks=2, 
  conv_width_multiplier=4,
  server_learning_rate=1.0, 
  client_learning_rate=0.01,
  total_rounds=10,
  clients_per_round=16, 
  rounds_per_eval=2,
  )
Model: "cnn-14-4"
_________________________________________________________________
Layer (type)         Output Shape       Param #  
=================================================================
input_1 (InputLayer)     [(None, 28, 28, 1)]    0     
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)     0     
_________________________________________________________________
conv2d (Conv2D)       (None, 28, 28, 64)    576    
_________________________________________________________________
conv2d_1 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation (Activation)   (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_2 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_1 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_3 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_2 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_4 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_3 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_5 (Conv2D)      (None, 14, 14, 128)    73728   
_________________________________________________________________
activation_4 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_6 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_5 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_7 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_6 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_8 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_7 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_9 (Conv2D)      (None, 7, 7, 256)     294912  
_________________________________________________________________
activation_8 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_10 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_9 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_11 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_10 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_12 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_11 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
global_average_pooling2d (Gl (None, 256)        0     
_________________________________________________________________
dense (Dense)        (None, 10)        2570   
=================================================================
Total params: 2,731,082
Trainable params: 2,731,082
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4688243865966797, time: 13.382015466690063 secs
Round 0 validation accuracy: 15.240497589111328
Round 1 training loss: 2.3217368125915527, time: 9.311999917030334 secs
Round 2 training loss: 2.3100595474243164, time: 6.972411632537842 secs
Round 2 validation accuracy: 11.226489067077637
Round 3 training loss: 2.303222417831421, time: 6.467299699783325 secs
Round 4 training loss: 2.2976326942443848, time: 5.526083135604859 secs
Round 4 validation accuracy: 11.224040031433105
Round 5 training loss: 2.2919719219207764, time: 5.468692660331726 secs
Round 6 training loss: 2.2911534309387207, time: 4.935825347900391 secs
Round 6 validation accuracy: 11.833855628967285
Round 7 training loss: 2.2871201038360596, time: 4.918408691883087 secs
Round 8 training loss: 2.2818832397460938, time: 4.602836343977186 secs
Round 8 validation accuracy: 11.385677337646484
Round 9 training loss: 2.2790346145629883, time: 4.99558527469635 secs
Round 9 validation accuracy: 11.226489067077637
%tensorboard --logdir=logdir --port=0

การทำงานของซีพียู

เพื่อการเปรียบเทียบ ให้เรากำหนดค่ารันไทม์ TFF สำหรับการประมวลผลบน CPU การรันบน CPU ช้ากว่าเพียงเล็กน้อยสำหรับโมเดลที่ค่อนข้างเล็กนี้

# Take the first logical CPU device and install a local TFF execution context
# that uses it for the server and as the only client device.
cpu_device = tf.config.list_logical_devices('CPU')[0]
tff.backends.native.set_local_python_execution_context(
  server_tf_device=cpu_device, client_tf_devices=[cpu_device])

# Same hyperparameters as the earlier small-model run, for comparison.
run_federated_training(
  client_epochs_per_round=1, 
  train_batch_size=16, 
  test_batch_size=128, 
  cnn_num_blocks=2, 
  conv_width_multiplier=4,
  server_learning_rate=1.0, 
  client_learning_rate=0.01,
  total_rounds=10,
  clients_per_round=16, 
  rounds_per_eval=2,
  )
Model: "cnn-14-4"
_________________________________________________________________
Layer (type)         Output Shape       Param #  
=================================================================
input_2 (InputLayer)     [(None, 28, 28, 1)]    0     
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)     0     
_________________________________________________________________
conv2d_13 (Conv2D)      (None, 28, 28, 64)    576    
_________________________________________________________________
conv2d_14 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_12 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_15 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_13 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_16 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_14 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_17 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_15 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_18 (Conv2D)      (None, 14, 14, 128)    73728   
_________________________________________________________________
activation_16 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_19 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_17 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_20 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_18 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_21 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_19 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_22 (Conv2D)      (None, 7, 7, 256)     294912  
_________________________________________________________________
activation_20 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_23 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_21 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_24 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_22 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_25 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_23 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
global_average_pooling2d_1 ( (None, 256)        0     
_________________________________________________________________
dense_1 (Dense)       (None, 10)        2570   
=================================================================
Total params: 2,731,082
Trainable params: 2,731,082
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4787657260894775, time: 15.264191627502441 secs
Round 0 validation accuracy: 12.191418647766113
Round 1 training loss: 2.3097336292266846, time: 11.785032272338867 secs
Round 2 training loss: 2.3062121868133545, time: 9.677561124165853 secs
Round 2 validation accuracy: 11.415066719055176
Round 3 training loss: 2.2982261180877686, time: 9.301376760005951 secs
Round 4 training loss: 2.2953946590423584, time: 8.377780866622924 secs
Round 4 validation accuracy: 20.537813186645508
Round 5 training loss: 2.290337324142456, time: 8.385509928067526 secs
Round 6 training loss: 2.2842795848846436, time: 7.809031554630825 secs
Round 6 validation accuracy: 11.934267044067383
Round 7 training loss: 2.2752432823181152, time: 7.8433578312397 secs
Round 8 training loss: 2.2698657512664795, time: 7.478067080179851 secs
Round 8 validation accuracy: 26.16330337524414
Round 9 training loss: 2.2609798908233643, time: 7.632814192771912 secs
Round 9 validation accuracy: 23.079936981201172

การประมวลผล GPU หลายตัว

การกำหนดค่า TFF สำหรับการประมวลผลด้วย GPU หลายตัวนั้นทำได้ง่าย โปรดทราบว่าการฝึกไคลเอ็นต์ใน TFF ทำงานแบบขนาน ในการตั้งค่า multi-GPU ไคลเอ็นต์จะถูกกำหนดให้กับ GPU แต่ละตัวแบบวนรอบ (round-robin) การรันด้วย GPU สองตัวต่อไปนี้ไม่เร็วกว่าการรันด้วย GPU ตัวเดียว เนื่องจากการฝึกไคลเอ็นต์ทำงานแบบขนานทั้งในการตั้งค่า GPU ตัวเดียวและหลายตัว และการตั้งค่า multi-GPU นี้มี GPU เสมือนสองตัวที่สร้างจาก GPU จริงเพียงตัวเดียว

# Collect every logical GPU (two virtual GPUs were configured on the first
# physical GPU earlier in this file) and register all of them as client
# devices in a new local TFF execution context.
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_python_execution_context(client_tf_devices=gpu_devices)

# Same hyperparameters as the earlier small-model runs; the final-round
# profile goes to 'multigpu' instead of the default 'logdir'.
run_federated_training(
  client_epochs_per_round=1, 
  train_batch_size=16, 
  test_batch_size=128, 
  cnn_num_blocks=2, 
  conv_width_multiplier=4,
  server_learning_rate=1.0, 
  client_learning_rate=0.01,
  total_rounds=10,
  clients_per_round=16, 
  rounds_per_eval=2,
  logdir='multigpu'
  )
Model: "cnn-14-4"
_________________________________________________________________
Layer (type)         Output Shape       Param #  
=================================================================
input_3 (InputLayer)     [(None, 28, 28, 1)]    0     
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)     0     
_________________________________________________________________
conv2d_26 (Conv2D)      (None, 28, 28, 64)    576    
_________________________________________________________________
conv2d_27 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_24 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_28 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_25 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_29 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_26 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_30 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_27 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_31 (Conv2D)      (None, 14, 14, 128)    73728   
_________________________________________________________________
activation_28 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_32 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_29 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_33 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_30 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_34 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_31 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_35 (Conv2D)      (None, 7, 7, 256)     294912  
_________________________________________________________________
activation_32 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_36 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_33 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_37 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_34 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_38 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_35 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
global_average_pooling2d_2 ( (None, 256)        0     
_________________________________________________________________
dense_2 (Dense)       (None, 10)        2570   
=================================================================
Total params: 2,731,082
Trainable params: 2,731,082
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.911365270614624, time: 12.759389877319336 secs
Round 0 validation accuracy: 9.541536331176758
Round 1 training loss: 2.3175694942474365, time: 9.202919721603394 secs
Round 2 training loss: 2.311001777648926, time: 6.802880525588989 secs
Round 2 validation accuracy: 9.911344528198242
Round 3 training loss: 2.3105244636535645, time: 6.611470937728882 secs
Round 4 training loss: 2.3082072734832764, time: 5.678833389282227 secs
Round 4 validation accuracy: 10.212578773498535
Round 5 training loss: 2.304673671722412, time: 5.5404335260391235 secs
Round 6 training loss: 2.3035168647766113, time: 5.008027451378958 secs
Round 6 validation accuracy: 9.935834884643555
Round 7 training loss: 2.3052737712860107, time: 5.1173741817474365 secs
Round 8 training loss: 2.3007171154022217, time: 4.745321141348945 secs
Round 8 validation accuracy: 10.768514633178711
Round 9 training loss: 2.302018404006958, time: 5.0809732437133786 secs
Round 9 validation accuracy: 12.311422348022461

โมเดลขนาดใหญ่ขึ้นและ OOM

ให้เรารันโมเดลที่ใหญ่ขึ้นบน CPU โดยใช้จำนวนรอบรวมน้อยลง

# Switch the TFF execution context back to the first logical CPU device for
# both the server and the clients.
cpu_device = tf.config.list_logical_devices('CPU')[0]
tff.backends.native.set_local_python_execution_context(
  server_tf_device=cpu_device, client_tf_devices=[cpu_device])

# Larger model (cnn_num_blocks=4, named "cnn-26-4" by create_cnn) trained for
# only 5 rounds.
run_federated_training(
  client_epochs_per_round=1, 
  train_batch_size=16, 
  test_batch_size=128, 
  cnn_num_blocks=4, 
  conv_width_multiplier=4,
  server_learning_rate=1.0, 
  client_learning_rate=0.01,
  total_rounds=5,
  clients_per_round=16, 
  rounds_per_eval=2,
  )
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)         Output Shape       Param #  
=================================================================
input_4 (InputLayer)     [(None, 28, 28, 1)]    0     
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)     0     
_________________________________________________________________
conv2d_39 (Conv2D)      (None, 28, 28, 64)    576    
_________________________________________________________________
conv2d_40 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_36 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_41 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_37 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_42 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_38 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_43 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_39 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_44 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_40 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_45 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_41 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_46 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_42 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_47 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_43 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_48 (Conv2D)      (None, 14, 14, 128)    73728   
_________________________________________________________________
activation_44 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_49 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_45 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_50 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_46 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_51 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_47 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_52 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_48 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_53 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_49 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_54 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_50 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_55 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_51 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_56 (Conv2D)      (None, 7, 7, 256)     294912  
_________________________________________________________________
activation_52 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_57 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_53 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_58 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_54 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_59 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_55 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_60 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_56 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_61 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_57 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_62 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_58 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_63 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_59 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
global_average_pooling2d_3 ( (None, 256)        0     
_________________________________________________________________
dense_3 (Dense)       (None, 10)        2570   
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.437223434448242, time: 24.121686458587646 secs
Round 0 validation accuracy: 9.024785041809082
Round 1 training loss: 2.3081459999084473, time: 19.48685622215271 secs
Round 2 training loss: 2.305305242538452, time: 15.73950457572937 secs
Round 2 validation accuracy: 9.791339874267578
Round 3 training loss: 2.303149700164795, time: 15.194068729877472 secs
Round 4 training loss: 2.3026506900787354, time: 14.036769819259643 secs
Round 4 validation accuracy: 12.193867683410645

โมเดลนี้อาจประสบปัญหาหน่วยความจำไม่เพียงพอ (out of memory หรือ OOM) เมื่อรันบน GPU ตัวเดียว การย้ายจากการทดลองขนาดใหญ่บน CPU มาเป็นการจำลองบน GPU อาจถูกจำกัดด้วยการใช้หน่วยความจำ เนื่องจาก GPU มักมีหน่วยความจำจำกัด รันไทม์ TFF มีพารามิเตอร์หลายตัวที่ปรับได้เพื่อบรรเทาปัญหา OOM

# Single GPU execution might hit OOM: every simulated client is placed on the
# first logical GPU, so the large model configured below may not fit in memory.
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_python_execution_context(
  client_tf_devices=[gpu_devices[0]])

try:
  run_federated_training(
    client_epochs_per_round=1,
    train_batch_size=16,
    test_batch_size=128,
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    server_learning_rate=1.0,
    client_learning_rate=0.01,
    total_rounds=5,
    clients_per_round=16,
    rounds_per_eval=2,
    )
# The exception class lives in `tf.errors`; the bare name
# `ResourceExhaustedError` is not defined in this notebook and would raise a
# NameError instead of handling the out-of-memory failure.
except tf.errors.ResourceExhaustedError as e:
  # Report the OOM and keep going so the following cells can demonstrate the
  # mitigations.
  print(e)
# Lower the runtime's concurrency through `clients_per_thread`: all clients
# still run on the first logical GPU, but the execution context is configured
# with two clients per thread.
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_python_execution_context(
    clients_per_thread=2,
    client_tf_devices=[gpu_devices[0]],
)

# Same large-model experiment as the single-GPU run.
run_federated_training(
    # Model size.
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    # Optimization settings.
    client_epochs_per_round=1,
    client_learning_rate=0.01,
    server_learning_rate=1.0,
    train_batch_size=16,
    test_batch_size=128,
    # Simulation schedule.
    total_rounds=5,
    clients_per_round=16,
    rounds_per_eval=2,
)
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)         Output Shape       Param #  
=================================================================
input_1 (InputLayer)     [(None, 28, 28, 1)]    0     
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)     0     
_________________________________________________________________
conv2d (Conv2D)       (None, 28, 28, 64)    576    
_________________________________________________________________
conv2d_1 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation (Activation)   (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_2 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_1 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_3 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_2 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_4 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_3 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_5 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_4 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_6 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_5 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_7 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_6 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_8 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_7 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_9 (Conv2D)      (None, 14, 14, 128)    73728   
_________________________________________________________________
activation_8 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_10 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_9 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_11 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_10 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_12 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_11 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_13 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_12 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_14 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_13 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_15 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_14 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_16 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_15 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_17 (Conv2D)      (None, 7, 7, 256)     294912  
_________________________________________________________________
activation_16 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_18 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_17 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_19 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_18 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_20 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_19 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_21 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_20 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_22 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_21 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_23 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_22 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_24 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_23 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
global_average_pooling2d (Gl (None, 256)        0     
_________________________________________________________________
dense (Dense)        (None, 10)        2570   
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4990053176879883, time: 11.922378778457642 secs
Round 0 validation accuracy: 11.224040031433105
Round 1 training loss: 2.307560920715332, time: 9.916815996170044 secs
Round 2 training loss: 2.3032877445220947, time: 7.68927804629008 secs
Round 2 validation accuracy: 11.224040031433105
Round 3 training loss: 2.302366256713867, time: 7.681552231311798 secs
Round 4 training loss: 2.301671028137207, time: 7.613566827774048 secs
Round 4 validation accuracy: 11.224040031433105
# Multi-GPU execution, configured to mitigate OOM: the server computation is
# placed on the CPU and client work is spread over every logical GPU.
gpu_devices = tf.config.list_logical_devices('GPU')
cpu_device = tf.config.list_logical_devices('CPU')[0]
tff.backends.native.set_local_python_execution_context(
    client_tf_devices=gpu_devices,
    server_tf_device=cpu_device,
    max_fanout=32,
    clients_per_thread=1,
)

# Same large-model experiment as the single-GPU run.
run_federated_training(
    # Model size.
    cnn_num_blocks=4,
    conv_width_multiplier=4,
    # Optimization settings.
    client_epochs_per_round=1,
    client_learning_rate=0.01,
    server_learning_rate=1.0,
    train_batch_size=16,
    test_batch_size=128,
    # Simulation schedule.
    total_rounds=5,
    clients_per_round=16,
    rounds_per_eval=2,
)
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)         Output Shape       Param #  
=================================================================
input_1 (InputLayer)     [(None, 28, 28, 1)]    0     
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)     0     
_________________________________________________________________
conv2d (Conv2D)       (None, 28, 28, 64)    576    
_________________________________________________________________
conv2d_1 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation (Activation)   (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_2 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_1 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_3 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_2 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_4 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_3 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_5 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_4 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_6 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_5 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_7 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_6 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_8 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_7 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_9 (Conv2D)      (None, 14, 14, 128)    73728   
_________________________________________________________________
activation_8 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_10 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_9 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_11 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_10 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_12 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_11 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_13 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_12 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_14 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_13 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_15 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_14 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_16 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_15 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_17 (Conv2D)      (None, 7, 7, 256)     294912  
_________________________________________________________________
activation_16 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_18 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_17 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_19 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_18 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_20 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_19 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_21 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_20 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_22 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_21 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_23 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_22 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_24 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_23 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
global_average_pooling2d (Gl (None, 256)        0     
_________________________________________________________________
dense (Dense)        (None, 10)        2570   
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4691953659057617, time: 17.81941556930542 secs
Round 0 validation accuracy: 10.817495346069336
Round 1 training loss: 2.3081436157226562, time: 12.986191034317017 secs
Round 2 training loss: 2.3028159141540527, time: 9.518500963846842 secs
Round 2 validation accuracy: 11.500783920288086
Round 3 training loss: 2.303886651992798, time: 8.989932537078857 secs
Round 4 training loss: 2.3030669689178467, time: 8.733866214752197 secs
Round 4 validation accuracy: 12.992260932922363

เพิ่มประสิทธิภาพการทำงาน

เทคนิคใน TF ที่ช่วยให้ได้ประสิทธิภาพดีขึ้นโดยทั่วไปสามารถนำมาใช้ใน TFF ได้เช่นกัน เช่น การฝึกแบบความแม่นยำผสม (mixed precision) และ XLA การเพิ่มความเร็ว (บน GPU เช่น V100) และการประหยัดหน่วยความจำจากความแม่นยำผสมมักมีนัยสำคัญ ซึ่งสามารถตรวจสอบได้ด้วย tf.profiler

# Mixed precision training.
# Server computation runs on the CPU; client work is spread over every
# logical GPU.
cpu_device = tf.config.list_logical_devices('CPU')[0]
gpu_devices = tf.config.list_logical_devices('GPU')
tff.backends.native.set_local_python_execution_context(
  server_tf_device=cpu_device,
  client_tf_devices=gpu_devices,
  clients_per_thread=1,
  max_fanout=32)

# Use the stable mixed-precision API. The `experimental` namespace
# (`experimental.Policy` / `experimental.set_policy`) is deprecated and has
# been removed from newer TensorFlow releases.
# NOTE: this sets the *global* Keras dtype policy, so it also applies to
# models built after this cell.
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)


# Same large-model experiment as before; `logdir='mixed'` is the directory the
# TensorBoard cell below reads from.
run_federated_training(
  client_epochs_per_round=1,
  train_batch_size=16,
  test_batch_size=128,
  cnn_num_blocks=4,
  conv_width_multiplier=4,
  server_learning_rate=1.0,
  client_learning_rate=0.01,
  total_rounds=5,
  clients_per_round=16,
  rounds_per_eval=2,
  logdir='mixed'
  )
Model: "cnn-26-4"
_________________________________________________________________
Layer (type)         Output Shape       Param #  
=================================================================
input_1 (InputLayer)     [(None, 28, 28, 1)]    0     
_________________________________________________________________
tf.image.per_image_standardi (None, 28, 28, 1)     0     
_________________________________________________________________
conv2d (Conv2D)       (None, 28, 28, 64)    576    
_________________________________________________________________
conv2d_1 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation (Activation)   (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_2 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_1 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_3 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_2 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_4 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_3 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_5 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_4 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_6 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_5 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_7 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_6 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_8 (Conv2D)      (None, 28, 28, 64)    36864   
_________________________________________________________________
activation_7 (Activation)  (None, 28, 28, 64)    0     
_________________________________________________________________
conv2d_9 (Conv2D)      (None, 14, 14, 128)    73728   
_________________________________________________________________
activation_8 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_10 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_9 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_11 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_10 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_12 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_11 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_13 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_12 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_14 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_13 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_15 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_14 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_16 (Conv2D)      (None, 14, 14, 128)    147456  
_________________________________________________________________
activation_15 (Activation)  (None, 14, 14, 128)    0     
_________________________________________________________________
conv2d_17 (Conv2D)      (None, 7, 7, 256)     294912  
_________________________________________________________________
activation_16 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_18 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_17 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_19 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_18 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_20 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_19 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_21 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_20 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_22 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_21 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_23 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_22 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
conv2d_24 (Conv2D)      (None, 7, 7, 256)     589824  
_________________________________________________________________
activation_23 (Activation)  (None, 7, 7, 256)     0     
_________________________________________________________________
global_average_pooling2d (Gl (None, 256)        0     
_________________________________________________________________
dense (Dense)        (None, 10)        2570   
=================================================================
Total params: 5,827,658
Trainable params: 5,827,658
Non-trainable params: 0
_________________________________________________________________
Round 0 training loss: 2.4187185764312744, time: 18.763780117034912 secs
Round 0 validation accuracy: 9.977468490600586
Round 1 training loss: 2.305102825164795, time: 13.712820529937744 secs
Round 2 training loss: 2.304737091064453, time: 9.993690172831217 secs
Round 2 validation accuracy: 11.779976844787598
Round 3 training loss: 2.2996833324432373, time: 9.29404467344284 secs
Round 4 training loss: 2.299349308013916, time: 9.195427560806275 secs
Round 4 validation accuracy: 11.224040031433105
%tensorboard --logdir=mixed --port=0