テキスト生成のフェデレーテッドラーニング

TensorFlow.orgで表示 Google Colab で実行 GitHub でソースを表示{

注意: この Colab は tensorflow_federated pip パッケージの最新リリースバージョンでの動作が確認されていますが、Tensorflow Federated プロジェクトは現在もプレリリース開発の段階にあるため、master では動作しない可能性があります。

このチュートリアルは、画像分類のフェデレーテッドラーニングチュートリアルの概念に基づいて構成されており、フェデレーテッドラーニングの便利なアプローチをいくつか実演します。

具体的には、以前にトレーニングした Keras モデルを読み込み、(シミュレーションされた)分散データセットでフェデレーテッドラーニングを使ってそのモデルをさらに洗練します。これはいくつかの理由により特に重要な作業です。シリアル化されたモデルを使用できることで、フェデレーテッドラーニングをほかの機械学習アプローチに簡単に混ぜることができるようになります。さらに、広範なトレーニング済みのモデルを使用することも可能です。たとえば、トレーニング済みの言語モデルは広く提供されてるため(TF Hub など)、モデルをゼロからトレーニングする必要はほとんどありません。そのため、トレーニング済みのモデルを開始点に、フェデレーテッドラーニングを使って洗練させ、特定のアプリケーションに使用する分散データセットの特性に合わせて調整する方が合理的と言えます。

このチュートリアルでは、ASCII 文字を生成する RNN より開始し、フェデレーテッドラーニングを通じて精緻化します。また、最終的な重みを元の Keras モデルにフィードし直し、評価とテキスト生成を標準のツールを使って簡単に行う方法も紹介します。

pip install --quiet --upgrade tensorflow_federated
import collections
import functools
import os
import time

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

# Test the TFF is working:
tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

トレーニング済みモデルを読み込む

TensorFlow チュートリアル「Eager execution を使った RNN によるテキスト生成」に従ってトレーニングされたモデルを使用しますが、The Complete Works of Shakespeare を使用する代わりに、チャールズ・ディケンズの「A Tale of Two Cities」と「A Christmas Carol」のテキストでモデルを事前トレーニングしています。

語彙を拡大する以外は元のチュートリアルを変更していないため、初期モデルは最新の状態ではありませんが、合理的な予測を生成するものであり、このチュートリアルの目的には十分と言えます。最終モデルは tf.keras.models.save_model(include_optimizer=False) を使って保存されています。

このチュートリアルでは、フェデレーテッドラーニングを使用して、このシェイクスピアのモデルを精緻化します。TFF が提供するフェデレーテッドバージョンのデータを使用します。

vocab ルックアップテーブルの生成

# A fixed vocabularly of ASCII chars that occur in the works of Shakespeare and Dickens:
vocab = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\naeimquyAEIMQUY]!%)-159\r')

# Creating a mapping from unique characters to indices
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

トレーニング済みモデルの読み込みとテキストの生成

def load_model(batch_size):
  urls = {
      1: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel',
      8: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel'}
  assert batch_size in urls, 'batch_size must be in ' + str(urls.keys())
  url = urls[batch_size]
  local_file = tf.keras.utils.get_file(os.path.basename(url), origin=url)  
  return tf.keras.models.load_model(local_file, compile=False)
def generate_text(model, start_string):
  # From https://www.tensorflow.org/tutorials/sequences/text_generation
  num_generate = 200
  input_eval = [char2idx[s] for s in start_string]
  input_eval = tf.expand_dims(input_eval, 0)
  text_generated = []
  temperature = 1.0

  model.reset_states()
  for i in range(num_generate):
    predictions = model(input_eval)
    predictions = tf.squeeze(predictions, 0)
    predictions = predictions / temperature
    predicted_id = tf.random.categorical(
        predictions, num_samples=1)[-1, 0].numpy()
    input_eval = tf.expand_dims([predicted_id], 0)
    text_generated.append(idx2char[predicted_id])

  return (start_string + ''.join(text_generated))
# Text generation requires a batch_size=1 model.
keras_model_batch1 = load_model(batch_size=1)
print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))
Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel
16195584/16193984 [==============================] - 0s 0us/step
16203776/16193984 [==============================] - 0s 0us/step
What of TensorFlow Federated, you ask? Sall
yesterday. Received the Bailey."

"Mr. Lorry, grimmering himself, or low varked thends the winter, and the eyes of Monsieur
Defarge. "Let his mind, hon in his
life and message; four declare

Shakespere のフェデレーテッドデータを読み込んで事前処理する

tff.simulation.datasets パッケージには、"clients" に分割されたさまざまなデータセットが含まれます。各 client はフェデレーテッドラーニングに含まれる可能性のある特定のデバイス上のデータセットに対応しています。

これらのデータセットは、実際の分散データでのトレーニングの課題をシミュレーションで再現する現実的な非 IID データ分布を示します。このデータの事前処理は、Leaf projectgithub)のツールを使用して行われています。

train_data, test_data = tff.simulation.datasets.shakespeare.load_data()

shakespeare.load_data() が提供するデータセットは、文字列 Tensors で構成されています。各行はシェイクスピア劇の登場人物のセリフです。client キーは、劇の名前と登場人物の名前を結合したもので、たとえば
MUCH_ADO_ABOUT_NOTHING_OTHELLO は「Much Ado About Nothing」という劇の登場人物オセロのセリフに対応しています。実勢のフェデレーテッドラーニングシナリオでは、client は ID で識別または追跡されることはありませんが、シミュレーションでは、キー付きのデータセットを使用する方が役に立ちます。

ここでは、King Lear のデータを例とします。

# Here the play is "The Tragedy of King Lear" and the character is "King".
raw_example_dataset = train_data.create_tf_dataset_for_client(
    'THE_TRAGEDY_OF_KING_LEAR_KING')
# To allow for future extensions, each entry x
# is an OrderedDict with a single key 'snippets' which contains the text.
for x in raw_example_dataset.take(2):
  print(x['snippets'])
tf.Tensor(b'', shape=(), dtype=string)
tf.Tensor(b'What?', shape=(), dtype=string)

tf.data.Dataset 変換を使用して、このデータを上記で読み込んだ文字列 RNN のトレーニング用に準備します。

# Input pre-processing parameters
SEQ_LENGTH = 100
BATCH_SIZE = 8
BUFFER_SIZE = 10000  # For dataset shuffling
# Construct a lookup table to map string chars to indexes,
# using the vocab loaded above:
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=vocab, values=tf.constant(list(range(len(vocab))),
                                       dtype=tf.int64)),
    default_value=0)


def to_ids(x):
  s = tf.reshape(x['snippets'], shape=[1])
  chars = tf.strings.bytes_split(s).values
  ids = table.lookup(chars)
  return ids


def split_input_target(chunk):
  input_text = tf.map_fn(lambda x: x[:-1], chunk)
  target_text = tf.map_fn(lambda x: x[1:], chunk)
  return (input_text, target_text)


def preprocess(dataset):
  return (
      # Map ASCII chars to int64 indexes using the vocab
      dataset.map(to_ids)
      # Split into individual chars
      .unbatch()
      # Form example sequences of SEQ_LENGTH +1
      .batch(SEQ_LENGTH + 1, drop_remainder=True)
      # Shuffle and form minibatches
      .shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
      # And finally split into (input, target) tuples,
      # each of length SEQ_LENGTH.
      .map(split_input_target))

元のシーケンスの形成と上記のバッチの形成では、drop_remainder=True を使って単純化しています。つまり、少なくともテキストの (SEQ_LENGTH + 1) * BATCH_SIZE 文字を持たない登場人物(client))のデータセットは空となります。この状況を解消するために使用される一般的なアプローチはバッチを特殊なトークンでパッドし、パディングトークンを考慮しないように損失量をマスクする方法です。

これではサンプルが複雑化してしまうため、このチュートリアルでは標準的なチュートリアルと同様にフルバッチのみを使用します。ただし、多数のユーザーが小さなデータセットを持つことになるため、フェデレーテッドの設定ではこの問題はより明確に現れます。

では、raw_example_dataset を事前処理し、型を確認しましょう。

example_dataset = preprocess(raw_example_dataset)
print(example_dataset.element_spec)
(TensorSpec(shape=(8, 100), dtype=tf.int64, name=None), TensorSpec(shape=(8, 100), dtype=tf.int64, name=None))

モデルをコンパイルし、事前処理済みのデータでテストする

コンパイルされていない Keras モデルを読み込みましたが、keras_model.evaluate を実行するには、損失とメトリックとともにコンパイルする必要があります。また、オプティマイザにコンパイルし、フェデレーテッドラーニングのオンデバイスオプティマイザとして使用します。

元のチュートリアルには文字レベルの精度(最高確率が適切な次の文字に配置される予測の割合)が含まれていませんでした。これは便利なメトリックであるため、追加することにします。ただし、予測の階数は 3(BATCH_SIZE * SEQ_LENGTH の各予測に対するロジットのベクトル)であり、SparseCategoricalAccuracy は階数 2 の予測のみを期待するため、新しいメトリッククラスを定義する必要があります。

class FlattenedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):

  def __init__(self, name='accuracy', dtype=tf.float32):
    super().__init__(name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.reshape(y_true, [-1, 1])
    y_pred = tf.reshape(y_pred, [-1, len(vocab), 1])
    return super().update_state(y_true, y_pred, sample_weight)

これで、モデルをコンパイルし、example_dataset で評価できるようになりました。

BATCH_SIZE = 8  # The training and eval batch size for the rest of this tutorial.
keras_model = load_model(batch_size=BATCH_SIZE)
keras_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[FlattenedCategoricalAccuracy()])

# Confirm that loss is much lower on Shakespeare than on random data
loss, accuracy = keras_model.evaluate(example_dataset.take(5), verbose=0)
print(
    'Evaluating on an example Shakespeare character: {a:3f}'.format(a=accuracy))

# As a sanity check, we can construct some completely random data, where we expect
# the accuracy to be essentially random:
random_guessed_accuracy = 1.0 / len(vocab)
print('Expected accuracy for random guessing: {a:.3f}'.format(
    a=random_guessed_accuracy))
random_indexes = np.random.randint(
    low=0, high=len(vocab), size=1 * BATCH_SIZE * (SEQ_LENGTH + 1))
data = collections.OrderedDict(
    snippets=tf.constant(
        ''.join(np.array(vocab)[random_indexes]), shape=[1, 1]))
random_dataset = preprocess(tf.data.Dataset.from_tensor_slices(data))
loss, accuracy = keras_model.evaluate(random_dataset, steps=10, verbose=0)
print('Evaluating on completely random data: {a:.3f}'.format(a=accuracy))
Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel
16195584/16193984 [==============================] - 0s 0us/step
16203776/16193984 [==============================] - 0s 0us/step
Evaluating on an example Shakespeare character: 0.402750
Expected accuracy for random guessing: 0.012
Evaluating on completely random data: 0.013

フェデレーテッドラーニングでモデルを微調整する

TFF はすべての TensorFlow 計算をシリアル化するため、非 Python 環境で実行することが可能です(現時点では、Python で実装されたシミュレーションランタイムのみを利用できます)。Eager モードで実行してはいますが(TF 2.0)、現時点では、TFF は "with tf.Graph.as_default()" 文のコンテキスト内に必要な演算を作成して、TensorFlow 計算をシリアル化しています。したがって、モデルを関数が制御するグラフに導入するために TFF が使用できる関数を指定する必要があります。これを次のようにして行います。

# Clone the keras_model inside `create_tff_model()`, which TFF will
# call to produce a new copy of the model inside the graph that it will 
# serialize. Note: we want to construct all the necessary objects we'll need 
# _inside_ this method.
def create_tff_model():
  # TFF uses an `input_spec` so it knows the types and shapes
  # that your model expects.
  input_spec = example_dataset.element_spec
  keras_model_clone = tf.keras.models.clone_model(keras_model)
  return tff.learning.from_keras_model(
      keras_model_clone,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[FlattenedCategoricalAccuracy()])

これで、フェデレーテッドアベレージングのイテレーション処理を構築する準備が整いました。これをモデルの改善に使用します(フェデレーテッドアベレージングアルゴリズムの詳細は、論文「Communication-Efficient Learning of Deep Networks from Decentralized Data」をご覧ください)。

コンパイルされた Keras モデルを使用して、フェデレーテッドトレーニングの各ラウンドの後に標準的な(非フェデレーテッド)評価を実行します。これは、シミュレーションされたフェデレーテッドラーニングを行う場合で標準的なテストデータセットがある場合の研究目的に役立ちます。

現実的な実稼働環境では、これと同じテクニックを使用してフェデレーテッドラーニングでモデルをトレーニングし、テストや QA を行えるように分散ベンチマークデータセットで評価します。

# This command builds all the TensorFlow graphs and serializes them: 
fed_avg = tff.learning.build_federated_averaging_process(
    model_fn=create_tff_model,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(lr=0.5))

次は最も単純なループで、1 つのバッチの単一の client における 1 つのラウンドで、フェデレーテッドアベレージングを実行します。

state = fed_avg.initialize()
state, metrics = fed_avg.next(state, [example_dataset.take(5)])
train_metrics = metrics['train']
print('loss={l:.3f}, accuracy={a:.3f}'.format(
    l=train_metrics['loss'], a=train_metrics['accuracy']))
loss=4.403, accuracy=0.132

では、もう少し興味深いトレーニングと評価ループを記述してみましょう。

このシミュレーションは比較的素早く実行します。各ラウンドで同一の 3 つの client でトレーニングしますが、それぞれで 2 つのミニバッチのみを考慮しています。

def data(client, source=train_data):
  return preprocess(source.create_tf_dataset_for_client(client)).take(5)


clients = [
    'ALL_S_WELL_THAT_ENDS_WELL_CELIA', 'MUCH_ADO_ABOUT_NOTHING_OTHELLO',
]

train_datasets = [data(client) for client in clients]

# We concatenate the test datasets for evaluation with Keras by creating a 
# Dataset of Datasets, and then identity flat mapping across all the examples.
test_dataset = tf.data.Dataset.from_tensor_slices(
    [data(client, test_data) for client in clients]).flat_map(lambda x: x)

fed_avg.initialize() で生成されるモデルの最初の状態は、読み込まれた重みではなく、Keras モデルのランダムなイニシャライザに基づきます。clone_model() は重みを複製しないためです。トレーニング済みのモデルからトレーニングを始めるには、読み込んだモデルから直接、サーバー状態のモデルの重みを設定します。

NUM_ROUNDS = 5

# The state of the FL server, containing the model and optimization state.
state = fed_avg.initialize()

# Load our pre-trained Keras model weights into the global model state.
state = tff.learning.state_with_new_model_weights(
    state,
    trainable_weights=[v.numpy() for v in keras_model.trainable_weights],
    non_trainable_weights=[
        v.numpy() for v in keras_model.non_trainable_weights
    ])


def keras_evaluate(state, round_num):
  # Take our global model weights and push them back into a Keras model to
  # use its standard `.evaluate()` method.
  keras_model = load_model(batch_size=BATCH_SIZE)
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[FlattenedCategoricalAccuracy()])
  state.model.assign_weights_to(keras_model)
  loss, accuracy = keras_model.evaluate(example_dataset, steps=2, verbose=0)
  print('\tEval: loss={l:.3f}, accuracy={a:.3f}'.format(l=loss, a=accuracy))


for round_num in range(NUM_ROUNDS):
  print('Round {r}'.format(r=round_num))
  keras_evaluate(state, round_num)
  state, metrics = fed_avg.next(state, train_datasets)
  train_metrics = metrics['train']
  print('\tTrain: loss={l:.3f}, accuracy={a:.3f}'.format(
      l=train_metrics['loss'], a=train_metrics['accuracy']))

print('Final evaluation')
keras_evaluate(state, NUM_ROUNDS + 1)
Round 0
    Eval: loss=3.372, accuracy=0.395
    Train: loss=4.317, accuracy=0.083
Round 1
    Eval: loss=4.300, accuracy=0.129
    Train: loss=4.172, accuracy=0.184
Round 2
    Eval: loss=4.152, accuracy=0.201
    Train: loss=4.077, accuracy=0.191
Round 3
    Eval: loss=4.031, accuracy=0.189
    Train: loss=3.965, accuracy=0.192
Round 4
    Eval: loss=3.946, accuracy=0.183
    Train: loss=3.877, accuracy=0.196
    Eval: loss=3.885, accuracy=0.168

デフォルトの変更により、大きな違いを得るほどのトレーニングはまだ行われていませんが、より長時間、より多くの Shakespeare データをトレーニングする場合、更新したモデルに生成されるテキストのスタイルに違いがみられるようになります。

# Set our newly trained weights back in the originally created model.
keras_model_batch1.set_weights([v.numpy() for v in keras_model.weights])
# Text generation requires batch_size=1
print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))
What of TensorFlow Federated, you ask? Shalways, I will call your
compet with any city brought their faces uncompany," besumed him. "When he
sticked Madame Defarge pushed the lamps.

"Have I often but no unison. She had probably come,

推奨される拡張

このチュートリアルは導入ステップにしかすぎません!次に、このノートブックを拡張するためのアイデアをいくつか示しています。

  • トレーニングする client をランダムにサンプリングするより現実的なトレーニングループを記述する。
  • client データセットに ".repeat(NUM_EPOCHS)" を使用して、ローカルトレーニングの複数のエポックを試してみる(McMahan et. al. で示す例)。これを行っている画像分類のフェデレーテッドラーニングもご覧ください。
  • compile() コマンドを変更して、client でさまざまな最適化アルゴリズムを使った実験を行う。
  • build_federated_averaging_processserver_optimizer 属性を使用し、サーバー上にモデルの更新を適用するためのさまざまなアルゴリズムを試してみる。
  • build_federated_averaging_processclient_weight_fn 属性を使用して、client のさまざまな重みづけを試してみる。デフォルトは、client のサンプル数で client の更新を重みづけしますが、client_weight_fn=lambda _: tf.constant(1.0) などのように行うことができます。