このページは Cloud Translation API によって翻訳されました。
Switch to English

テキスト生成のための連合学習

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示

このチュートリアルは、画像分類のための連合学習チュートリアルの概念に基づいており、連合学習のための他のいくつかの有用なアプローチを示しています。

特に、以前にトレーニングされたKerasモデルをロードし、(シミュレートされた)分散データセットでフェデレーショントレーニングを使用してモデルを改良します。これはいくつかの理由で実際に重要です。シリアル化されたモデルを使用できるため、フェデレーション学習を他のMLアプローチと簡単に組み合わせることができます。さらに、これにより、事前にトレーニングされたモデルの範囲が広がります。たとえば、事前にトレーニングされたモデルが多数利用できるようになったため、言語モデルを最初からトレーニングする必要はほとんどありません( TF Hubなどを参照)。代わりに、事前にトレーニングされたモデルから開始し、特定のアプリケーションの分散データの特定の特性に適応して、フェデレーション学習を使用してモデルを改良する方が理にかなっています。

このチュートリアルでは、ASCII文字を生成するRNNから始めて、フェデレーション学習を介してそれを改良します。また、最終的な重みを元のKerasモデルにフィードバックして、標準ツールを使用して簡単に評価およびテキスト生成できるようにする方法も示します。

!pip install --quiet --upgrade tensorflow_federated_nightly
!pip install --quiet --upgrade nest_asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import functools
import os
import time

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

# Test the TFF is working:
tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

事前にトレーニングされたモデルをロードする

熱心な実行を伴うRNNを使用したTensorFlowチュートリアルテキスト生成に従って事前トレーニングされたモデルをロードします。ただし、シェイクスピア全集でトレーニングするではなく、チャールズ・ディケンズの「二都物語」「クリスマスキャロル」のテキストでモデルを事前にトレーニングしました。

語彙を拡張する以外は、元のチュートリアルを変更しなかったため、この初期モデルは最先端ではありませんが、妥当な予測を生成し、チュートリアルの目的には十分です。最終的なモデルはtf.keras.models.save_model(include_optimizer=False)保存されました。

このチュートリアルでは、TFFが提供するデータのフェデレーションバージョンを使用して、フェデレーション学習を使用してシェイクスピアのこのモデルを微調整します。

語彙ルックアップテーブルを生成する

# A fixed vocabularly of ASCII chars that occur in the works of Shakespeare and Dickens:
vocab = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\naeimquyAEIMQUY]!%)-159\r')

# Creating a mapping from unique characters to indices
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

事前にトレーニングされたモデルをロードし、テキストを生成します

def load_model(batch_size):
  urls = {
      1: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel',
      8: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel'}
  assert batch_size in urls, 'batch_size must be in ' + str(urls.keys())
  url = urls[batch_size]
  local_file = tf.keras.utils.get_file(os.path.basename(url), origin=url)  
  return tf.keras.models.load_model(local_file, compile=False)
def generate_text(model, start_string):
  # From https://www.tensorflow.org/tutorials/sequences/text_generation
  num_generate = 200
  input_eval = [char2idx[s] for s in start_string]
  input_eval = tf.expand_dims(input_eval, 0)
  text_generated = []
  temperature = 1.0

  model.reset_states()
  for i in range(num_generate):
    predictions = model(input_eval)
    predictions = tf.squeeze(predictions, 0)
    predictions = predictions / temperature
    predicted_id = tf.random.categorical(
        predictions, num_samples=1)[-1, 0].numpy()
    input_eval = tf.expand_dims([predicted_id], 0)
    text_generated.append(idx2char[predicted_id])

  return (start_string + ''.join(text_generated))
# Text generation requires a batch_size=1 model.
keras_model_batch1 = load_model(batch_size=1)
print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))
Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel
16195584/16193984 [==============================] - 0s 0us/step
16203776/16193984 [==============================] - 0s 0us/step
What of TensorFlow Federated, you ask? Sall
yesterday. Received the Bailey."

"Mr. Lorry, grimmering himself, or low varked thends the winter, and the eyes of Monsieur
Defarge. "Let his mind, hon in his
life and message; four declare 

フェデレーションシェイクスピアデータをロードして前処理する

tff.simulation.datasetsパッケージは、「クライアント」に分割されたさまざまなデータセットを提供します。各クライアントは、フェデレーション学習に参加する可能性のある特定のデバイス上のデータセットに対応します。

これらのデータセットは、実際の分散データのトレーニングの課題をシミュレーションで再現する現実的な非IIDデータ分布を提供します。このデータの前処理の一部は、 Leafプロジェクトgithub )のツールを使用して行われました。

train_data, test_data = tff.simulation.datasets.shakespeare.load_data()

Tensors shakespeare.load_data()によって提供されるデータセットは、シェイクスピア劇の特定のキャラクターによって話されたTensorsに1つずつ、文字列Tensorsシーケンスで構成されています。クライアントキーは、そう、たとえば、文字の名前で参加しました遊びの名前で構成さMUCH_ADO_ABOUT_NOTHING_OTHELLOプレイから騒ぎ内の文字オセロのためのラインに対応しています。実際のフェデレーション学習シナリオでは、クライアントがIDによって識別または追跡されることはありませんが、シミュレーションでは、キー付きデータセットを操作すると便利です。

ここでは、たとえば、リア王からのいくつかのデータを見ることができます。

# Here the play is "The Tragedy of King Lear" and the character is "King".
raw_example_dataset = train_data.create_tf_dataset_for_client(
    'THE_TRAGEDY_OF_KING_LEAR_KING')
# To allow for future extensions, each entry x
# is an OrderedDict with a single key 'snippets' which contains the text.
for x in raw_example_dataset.take(2):
  print(x['snippets'])
tf.Tensor(b'', shape=(), dtype=string)
tf.Tensor(b'What?', shape=(), dtype=string)

ここで、tf.data.Dataset変換を使用して、上記でロードしたtf.data.Datasetをトレーニングするためにこのデータを準備します。

# Input pre-processing parameters
SEQ_LENGTH = 100
BATCH_SIZE = 8
BUFFER_SIZE = 100  # For dataset shuffling
# Construct a lookup table to map string chars to indexes,
# using the vocab loaded above:
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=vocab, values=tf.constant(list(range(len(vocab))),
                                       dtype=tf.int64)),
    default_value=0)


def to_ids(x):
  s = tf.reshape(x['snippets'], shape=[1])
  chars = tf.strings.bytes_split(s).values
  ids = table.lookup(chars)
  return ids


def split_input_target(chunk):
  input_text = tf.map_fn(lambda x: x[:-1], chunk)
  target_text = tf.map_fn(lambda x: x[1:], chunk)
  return (input_text, target_text)


def preprocess(dataset):
  return (
      # Map ASCII chars to int64 indexes using the vocab
      dataset.map(to_ids)
      # Split into individual chars
      .unbatch()
      # Form example sequences of SEQ_LENGTH +1
      .batch(SEQ_LENGTH + 1, drop_remainder=True)
      # Shuffle and form minibatches
      .shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
      # And finally split into (input, target) tuples,
      # each of length SEQ_LENGTH.
      .map(split_input_target))

元のシーケンスの形成および上記のバッチの形成では、簡単にするためにdrop_remainder=Trueを使用していることに注意してください。これは、少なくとも(SEQ_LENGTH + 1) * BATCH_SIZE文字のテキストを持たない文字(クライアント)には空のデータセットがあることを(SEQ_LENGTH + 1) * BATCH_SIZEます。これに対処するための一般的なアプローチは、バッチに特別なトークンをパディングし、パディングトークンを考慮しないように損失をマスクすることです。

これにより例が多少複雑になるため、このチュートリアルでは、標準のチュートリアルと同様に、完全なバッチのみを使用します。ただし、フェデレーション設定では、多くのユーザーが小さなデータセットを使用している可能性があるため、この問題はより重大です。

これで、 raw_example_datasetを前処理し、タイプを確認できます。

example_dataset = preprocess(raw_example_dataset)
print(example_dataset.element_spec)
(TensorSpec(shape=(8, 100), dtype=tf.int64, name=None), TensorSpec(shape=(8, 100), dtype=tf.int64, name=None))

モデルをコンパイルし、前処理されたデータでテストします

コンパイルされていないkerasモデルをロードしましたが、 keras_model.evaluateを実行するには、損失とメトリックを使用してコンパイルする必要があります。また、フェデレーションラーニングでオンデバイスオプティマイザーとして使用されるオプティマイザーでコンパイルします。

元のチュートリアルには、文字レベルの精度(最も高い確率が正しい次の文字に配置された予測の割合)がありませんでした。これは便利な指標なので、追加します。ただし、予測にはランク3( BATCH_SIZE * SEQ_LENGTH予測ごとのロジットのベクトル)があり、 SparseCategoricalAccuracyはランク2の予測のみを期待するため、このための新しいメトリッククラスを定義する必要があります。

class FlattenedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):

  def __init__(self, name='accuracy', dtype=tf.float32):
    super().__init__(name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.reshape(y_true, [-1, 1])
    y_pred = tf.reshape(y_pred, [-1, len(vocab), 1])
    return super().update_state(y_true, y_pred, sample_weight)

これで、モデルをコンパイルして、 example_dataset評価できます。

BATCH_SIZE = 8  # The training and eval batch size for the rest of this tutorial.
keras_model = load_model(batch_size=BATCH_SIZE)
keras_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[FlattenedCategoricalAccuracy()])

# Confirm that loss is much lower on Shakespeare than on random data
loss, accuracy = keras_model.evaluate(example_dataset.take(5), verbose=0)
print(
    'Evaluating on an example Shakespeare character: {a:3f}'.format(a=accuracy))

# As a sanity check, we can construct some completely random data, where we expect
# the accuracy to be essentially random:
random_guessed_accuracy = 1.0 / len(vocab)
print('Expected accuracy for random guessing: {a:.3f}'.format(
    a=random_guessed_accuracy))
random_indexes = np.random.randint(
    low=0, high=len(vocab), size=1 * BATCH_SIZE * (SEQ_LENGTH + 1))
data = collections.OrderedDict(
    snippets=tf.constant(
        ''.join(np.array(vocab)[random_indexes]), shape=[1, 1]))
random_dataset = preprocess(tf.data.Dataset.from_tensor_slices(data))
loss, accuracy = keras_model.evaluate(random_dataset, steps=10, verbose=0)
print('Evaluating on completely random data: {a:.3f}'.format(a=accuracy))
Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel
16195584/16193984 [==============================] - 0s 0us/step
16203776/16193984 [==============================] - 0s 0us/step
Evaluating on an example Shakespeare character: 0.402000
Expected accuracy for random guessing: 0.012
Evaluating on completely random data: 0.011

FederatedLearningを使用してモデルを微調整します

TFFはすべてのTensorFlow計算をシリアル化するため、Python以外の環境で実行できる可能性があります(現時点では、Pythonで実装されたシミュレーションランタイムのみが利用可能です)。イーガーモード(TF 2.0)で実行している場合でも、現在TFFは、「 with tf.Graph.as_default() 」ステートメントのコンテキスト内で必要なopsを構築することにより、TensorFlow計算with tf.Graph.as_default()ます。したがって、TFFが制御するグラフにモデルを導入するために使用できる関数を提供する必要があります。これは次のように行います。

# Clone the keras_model inside `create_tff_model()`, which TFF will
# call to produce a new copy of the model inside the graph that it will 
# serialize. Note: we want to construct all the necessary objects we'll need 
# _inside_ this method.
def create_tff_model():
  # TFF uses an `input_spec` so it knows the types and shapes
  # that your model expects.
  input_spec = example_dataset.element_spec
  keras_model_clone = tf.keras.models.clone_model(keras_model)
  return tff.learning.from_keras_model(
      keras_model_clone,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[FlattenedCategoricalAccuracy()])

これで、モデルを改善するために使用するFederated Averaging反復プロセスを構築する準備が整いました(Federated Averagingアルゴリズムの詳細については、「分散データからのディープネットワークの通信効率の高い学習」を参照してください)。

コンパイルされたKerasモデルを使用して、フェデレーショントレーニングの各ラウンド後に標準(非フェデレーション)評価を実行します。これは、シミュレートされた連合学習を行う場合の研究目的に役立ち、標準のテストデータセットがあります。

現実的な生産環境では、これと同じ手法を使用して、連合学習でトレーニングされたモデルを取得し、テストまたは品質保証の目的で一元化されたベンチマークデータセットで評価することができます。

# This command builds all the TensorFlow graphs and serializes them: 
fed_avg = tff.learning.build_federated_averaging_process(
    model_fn=create_tff_model,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(lr=0.5))

これは可能な限り最も単純なループであり、単一のバッチの単一のクライアントで1ラウンドのフェデレーション平均を実行します。

state = fed_avg.initialize()
state, metrics = fed_avg.next(state, [example_dataset.take(5)])
train_metrics = metrics['train']
print('loss={l:.3f}, accuracy={a:.3f}'.format(
    l=train_metrics['loss'], a=train_metrics['accuracy']))
loss=4.403, accuracy=0.132

それでは、もう少し興味深いトレーニングと評価のループを書いてみましょう。

このシミュレーションが比較的高速に実行されるように、各ラウンドで同じ3つのクライアントでトレーニングを行い、それぞれに2つのミニバッチのみを考慮します。

def data(client, source=train_data):
  return preprocess(source.create_tf_dataset_for_client(client)).take(5)


clients = [
    'ALL_S_WELL_THAT_ENDS_WELL_CELIA', 'MUCH_ADO_ABOUT_NOTHING_OTHELLO',
]

train_datasets = [data(client) for client in clients]

# We concatenate the test datasets for evaluation with Keras by creating a 
# Dataset of Datasets, and then identity flat mapping across all the examples.
test_dataset = tf.data.Dataset.from_tensor_slices(
    [data(client, test_data) for client in clients]).flat_map(lambda x: x)

fed_avg.initialize()によって生成されたモデルの初期状態は、 fed_avg.initialize()が重みを複製しないため、ロードされた重みではなく、 clone_model()モデルのランダム初期化子に基づいています。事前にトレーニングされたモデルからトレーニングを開始するには、ロードされたモデルから直接サーバー状態のモデルの重みを設定します。

NUM_ROUNDS = 5

# The state of the FL server, containing the model and optimization state.
state = fed_avg.initialize()

# Load our pre-trained Keras model weights into the global model state.
state = tff.learning.state_with_new_model_weights(
    state,
    trainable_weights=[v.numpy() for v in keras_model.trainable_weights],
    non_trainable_weights=[
        v.numpy() for v in keras_model.non_trainable_weights
    ])


def keras_evaluate(state, round_num):
  # Take our global model weights and push them back into a Keras model to
  # use its standard `.evaluate()` method.
  keras_model = load_model(batch_size=BATCH_SIZE)
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[FlattenedCategoricalAccuracy()])
  state.model.assign_weights_to(keras_model)
  loss, accuracy = keras_model.evaluate(example_dataset, steps=2, verbose=0)
  print('\tEval: loss={l:.3f}, accuracy={a:.3f}'.format(l=loss, a=accuracy))


for round_num in range(NUM_ROUNDS):
  print('Round {r}'.format(r=round_num))
  keras_evaluate(state, round_num)
  state, metrics = fed_avg.next(state, train_datasets)
  train_metrics = metrics['train']
  print('\tTrain: loss={l:.3f}, accuracy={a:.3f}'.format(
      l=train_metrics['loss'], a=train_metrics['accuracy']))

print('Final evaluation')
keras_evaluate(state, NUM_ROUNDS + 1)
Round 0
    Eval: loss=3.324, accuracy=0.401
    Train: loss=4.360, accuracy=0.155
Round 1
    Eval: loss=4.361, accuracy=0.049
    Train: loss=4.235, accuracy=0.164
Round 2
    Eval: loss=4.219, accuracy=0.177
    Train: loss=4.081, accuracy=0.221
Round 3
    Eval: loss=4.080, accuracy=0.174
    Train: loss=3.940, accuracy=0.226
Round 4
    Eval: loss=3.991, accuracy=0.176
    Train: loss=3.840, accuracy=0.226
Final evaluation
    Eval: loss=3.909, accuracy=0.171

デフォルトの変更では、大きな違いを生むのに十分なトレーニングを行っていませんが、より多くのシェイクスピアデータでより長くトレーニングすると、更新されたモデルで生成されたテキストのスタイルに違いが見られるはずです。

# Set our newly trained weights back in the originally created model.
keras_model_batch1.set_weights([v.numpy() for v in keras_model.weights])
# Text generation requires batch_size=1
print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))
What of TensorFlow Federated, you ask? Shalways, I will call your
compet with any city brought their faces uncompany," besumed him. "When he
sticked Madame Defarge pushed the lamps.

"Have I often but no unison. She had probably come, 

提案された拡張機能

このチュートリアルは最初のステップにすぎません。このノートブックを拡張する方法について、いくつかのアイデアがあります。

  • クライアントをサンプリングしてランダムにトレーニングする、より現実的なトレーニングループを作成します。
  • クライアントデータセットで「 .repeat(NUM_EPOCHS) 」を使用して、ローカルトレーニングの複数のエポックを試してください(たとえば、 .repeat(NUM_EPOCHS) 。のように)。これを行う画像分類の連合学習も参照してください。
  • compile()コマンドを変更して、クライアントでさまざまな最適化アルゴリズムを使用してみてください。
  • 試してみてくださいserver_optimizerに引数をbuild_federated_averaging_processサーバー上のモデルの更新を適用するための異なるアルゴリズムをしようとします。
  • 試してみてくださいclient_weight_fnにに引数をbuild_federated_averaging_processクライアントの異なる重み付けをしようとします。デフォルトでは、クライアントの更新をクライアント上の例の数で重み付けしますが、たとえばclient_weight_fn=lambda _: tf.constant(1.0)実行できます。