質問があります? TensorFlowフォーラム訪問フォーラムでコミュニティとつながる

tff.federated_select を使用して特定のクライアントに異なるデータを送信する

TensorFlow.org で見る Google Colabで実行 GitHub でソースを表示ノートをダウンロード

このチュートリアルでは、さまざまなデータをさまざまなクライアントに送信する必要があるカスタム フェデレーション アルゴリズムを TFF に実装する方法を示します。サーバーに配置された単一の値をすべてのクライアントに送信するtff.federated_broadcastについては、すでにご存知かもしれません。このチュートリアルでは、サーバーベースの値のさまざまな部分がさまざまなクライアントに送信される場合に焦点を当てています。これは、モデル全体を単一のクライアントに送信しないようにするために、モデルの一部を異なるクライアントに分割するのに役立つ場合があります。

tensorflowtensorflow_federated両方をインポートすることから始めましょう。

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import tensorflow as tf
import tensorflow_federated as tff

クライアント データに基づいて異なる値を送信する

サーバーに配置されたリストがあり、クライアントが配置したデータに基づいて各クライアントにいくつかの要素を送信する場合を考えてみましょう。たとえば、サーバー上の文字列のリスト、およびクライアント上の、ダウンロードするインデックスのコンマ区切りリストです。次のようにそれを実装できます。

list_of_strings_type = tff.TensorType(tf.string, [None])
# We only ever send exactly two values to each client. The number of keys per
# client must be a fixed number across all clients.
number_of_keys_per_client = 2
keys_type = tff.TensorType(tf.int32, [number_of_keys_per_client])
get_size = tff.tf_computation(lambda x: tf.size(x))
select_fn = tff.tf_computation(lambda val, index: tf.gather(val, index))
client_data_type = tf.string

# A function from our client data to the indices of the values we'd like to
# select from the server.
@tff.tf_computation(client_data_type)
@tff.check_returns_type(keys_type)
def keys_for_client(client_string):
  # We assume our client data is a single string consisting of exactly three
  # comma-separated integers indicating which values to grab from the server.
  split = tf.strings.split([client_string], sep=',')[0]
  return tf.strings.to_number([split[0], split[1]], tf.int32)

@tff.tf_computation(tff.SequenceType(tf.string))
@tff.check_returns_type(tf.string)
def concatenate(values):
  def reduce_fn(acc, item):
    return tf.cond(tf.math.equal(acc, ''),
                   lambda: item,
                   lambda: tf.strings.join([acc, item], ','))
  return values.reduce('', reduce_fn)

@tff.federated_computation(tff.type_at_server(list_of_strings_type), tff.type_at_clients(client_data_type))
def broadcast_based_on_client_data(list_of_strings_at_server, client_data):
  keys_at_clients = tff.federated_map(keys_for_client, client_data)
  max_key = tff.federated_map(get_size, list_of_strings_at_server)
  values_at_clients = tff.federated_select(keys_at_clients, max_key, list_of_strings_at_server, select_fn)
  value_at_clients = tff.federated_map(concatenate, values_at_clients)
  return value_at_clients

次に、サーバーに配置された文字列のリストと各クライアントの文字列データを提供することで、計算をシミュレートできます。

client_data = ['0,1', '1,2', '2,0']
broadcast_based_on_client_data(['a', 'b', 'c'], client_data)
[<tf.Tensor: shape=(), dtype=string, numpy=b'a,b'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'b,c'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'c,a'>]

ランダム化された要素を各クライアントに送信する

または、サーバー データのランダムな部分を各クライアントに送信すると便利な場合があります。これを実装するには、最初に各クライアントでランダム キーを生成し、次に上記で使用したものと同様の選択プロセスに従います。

@tff.tf_computation(tf.int32)
@tff.check_returns_type(tff.TensorType(tf.int32, [1]))
def get_random_key(max_key):
  return tf.random.uniform(shape=[1], minval=0, maxval=max_key, dtype=tf.int32)

list_of_strings_type = tff.TensorType(tf.string, [None])
get_size = tff.tf_computation(lambda x: tf.size(x))
select_fn = tff.tf_computation(lambda val, index: tf.gather(val, index))

@tff.tf_computation(tff.SequenceType(tf.string))
@tff.check_returns_type(tf.string)
def get_last_element(sequence):
  return sequence.reduce('', lambda _initial_state, val: val)

@tff.federated_computation(tff.type_at_server(list_of_strings_type))
def broadcast_random_element(list_of_strings_at_server):
  max_key_at_server = tff.federated_map(get_size, list_of_strings_at_server)
  max_key_at_clients = tff.federated_broadcast(max_key_at_server)
  key_at_clients = tff.federated_map(get_random_key, max_key_at_clients)
  random_string_sequence_at_clients = tff.federated_select(
      key_at_clients, max_key_at_server, list_of_strings_at_server, select_fn)
  # Even though we only passed in a single key, `federated_select` returns a
  # sequence for each client. We only care about the last (and only) element.
  random_string_at_clients = tff.federated_map(get_last_element, random_string_sequence_at_clients)
  return random_string_at_clients

当社以来broadcast_random_element機能は、任意のクライアント・配置されたデータにはなりません、我々は使用するようにクライアントのデフォルト数とTFFシミュレーションランタイムを設定する必要があります。

tff.backends.native.set_local_execution_context(default_num_clients=3)

次に、選択をシミュレートできます。上記のdefault_num_clientsと以下の文字列のリストを変更して異なる結果を生成するか、単純に計算を再実行して異なるランダム出力を生成することができます。

broadcast_random_element(tf.convert_to_tensor(['foo', 'bar', 'baz']))