日付を保存! Google I / Oが5月18日から20日に戻ってきます今すぐ登録

High-performance Simulation with Kubernetes

This tutorial will describe how to set up high-performance simulation using a TFF runtime running on Kubernetes. The model is the same as in the previous tutorial, High-performance simulations with TFF. The only difference is that here we use a worker pool instead of a local executor.

This tutorial refers to Google Cloud's GKE to create the Kubernetes cluster, but all the steps after the cluster is created can be used with any Kubernetes installation.

View on TensorFlow.org Run in Google Colab View source on GitHub

GKE で TFF ワーカーを起動する

注: このチュートリアルは、ユーザーが既存の GCP プロジェクトを持っていることを前提としています。

Kubernetes クラスタの作成

次の手順は一度だけ実行する必要があります。クラスタは、今後のワークロードに再利用できます。

GKE の指示に従って、コンテナクラスタを作成します。このチュートリアルの以下の部分では、クラスタ名がtff-clusterであると想定していますが、実際の名前は重要ではありません。「ステップ 5: アプリケーションのデプロイ」に到達したら、指示に従うのをやめます。

TFF ワーカーアプリケーションをデプロイする

GCP とやり取りするコマンドは、ローカルまたは Google Cloud Shell で実行できます。Google Cloud Shell では、追加の設定は必要ないため、Google Cloud Shell の使用をお勧めします。

  1. 次のコマンドを実行して、Kubernetes アプリケーションを起動します。
$ kubectl create deployment tff-workers --image=gcr.io/tensorflow-federated/remote-executor-service:{ {version} }
  1. アプリケーションのロードバランサを追加します。
$ kubectl expose deployment tff-workers --type=LoadBalancer --port 80 --target-port 8000

注: これにより、デプロイメントがインターネットに公開されますが、これはデモのみのためです。運用環境では、ファイアウォールと認証を強くお勧めします。

Google Cloud Console でロードバランサの IP アドレスを検索します。後でトレーニングループをワーカーアプリに接続するために必要になります。

(または) Docker コンテナをローカルで起動する

$ docker run --rm -p 8000:8000 gcr.io/tensorflow-federated/remote_executor_service:{ {version} }

TFF 環境の設定

pip install --upgrade tensorflow_federated

トレーニングするモデルの定義

import collections
import time

import tensorflow as tf
import tensorflow_federated as tff

source, _ = tff.simulation.datasets.emnist.load_data()


def map_fn(example):
  return collections.OrderedDict(
      x=tf.reshape(example['pixels'], [-1, 784]), y=example['label'])


def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).batch(20).map(map_fn)


train_data = [client_data(n) for n in range(10)]
input_spec = train_data[0].element_spec


def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  return tff.learning.from_keras_model(
      model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])


trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02))


def evaluate(num_rounds=10):
  state = trainer.initialize()
  for round in range(num_rounds):
    t1 = time.time()
    state, metrics = trainer.next(state, train_data)
    t2 = time.time()
    print('Round {}: loss {}, round time {}'.format(round, metrics.loss, t2 - t1))

リモートエグゼキュータのセットアップ

デフォルトでは、TFF はすべての計算をローカルで実行します。このステップでは、上で設定した Kubernetes サービスに接続するよう TFF に指示します。サービスの IP アドレスは必ずここにコピーします。

import grpc

ip_address = '0.0.0.0' 
port = 80 

client_ex = []
for i in range(10):
  channel = grpc.insecure_channel('{}:{}'.format(ip_address, port))
  client_ex.append(tff.framework.RemoteExecutor(channel, rpc_mode='STREAMING'))

factory = tff.framework.worker_pool_executor_factory(client_ex)
context = tff.framework.ExecutionContext(factory)
tff.framework.set_default_context(context)

トレーニングの実行

evaluate()
Round 0: loss 4.370407581329346, round time 4.201097726821899
Round 1: loss 4.1407670974731445, round time 3.3283166885375977
Round 2: loss 3.865147590637207, round time 3.098310947418213
Round 3: loss 3.534019708633423, round time 3.1565616130828857
Round 4: loss 3.272688388824463, round time 3.175067663192749
Round 5: loss 2.935391664505005, round time 3.008434534072876
Round 6: loss 2.7399251461029053, round time 3.31435227394104
Round 7: loss 2.5054931640625, round time 3.4411356449127197
Round 8: loss 2.290508985519409, round time 3.158798933029175
Round 9: loss 2.1194536685943604, round time 3.1348156929016113