此页面由 Cloud Translation API 翻译。
Switch to English

tf.distribute.DistributedValues

GitHub上查看源代码

基类,代表分布值。

的一个子类的实例tf.distribute.DistributedValues一个分销策略中创建变量时,一个迭代创建tf.distribute.DistributedDataset或通过tf.distribute.Strategy.run 。这个基类不应该被直接实例化。 tf.distribute.DistributedValues包含每个副本的值。根据不同的子类,该值可以既可以在更新同步,同步的需求,或从不同步。

tf.distribute.DistributedValues可以减小,以获得跨越副本单个值,作为输入到tf.distribute.Strategy.run或使用每复制品值检查tf.distribute.Strategy.experimental_local_results

实例:

  1. 从创建tf.distribute.DistributedDataset
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
  1. 通过返回run
strategy = tf.distribute.MirroredStrategy()
@tf.function
def run():
  ctx = tf.distribute.get_replica_context()
  return ctx.replica_id_in_sync_group
distributed_values = strategy.run(run)
  1. 由于投入run
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
@tf.function
def run(input):
  return input + 1.0
updated_value = strategy.run(run, args=(distributed_values,))
  1. 降低值:
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                distributed_values,
                                axis = 0)
  1. 每个副本值检查:
strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
per_replica_values = strategy.experimental_local_results(
   distributed_values)
per_replica_values
(<tf.Tensor: shape=(2,), dtype=float32,
 numpy=array([5., 6.], dtype=float32)>,)