
tf.distribute.DistributedValues


Base class for representing distributed values.

tf.distribute.DistributedValues(
    values
)

A subclass instance of DistributedValues is created when you create variables within a distribution strategy, iterate over a distributed tf.data.Dataset, or call strategy.run. This base class should never be instantiated directly. A DistributedValues contains one value per replica. Depending on the subclass, those values may be synced on update, synced on demand, or never synced.
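For instance, a variable created under a MirroredStrategy scope is itself a DistributedValues subclass instance (a minimal sketch, assuming the default MirroredStrategy; the variable is mirrored across the available replicas):

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  v = tf.Variable(1.0)  # a mirrored variable, one component per replica
isinstance(v, tf.distribute.DistributedValues)  # expected: True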

A DistributedValues can be reduced to obtain a single value across replicas, passed as input into strategy.run, or its per-replica values can be inspected using experimental_local_results.

Example usage:

  1. Created from Dataset:
strategy = tf.distribute.MirroredStrategy() 
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2) 
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset)) 
distributed_values = next(dataset_iterator) 
  2. Returned by run:
strategy = tf.distribute.MirroredStrategy() 
@tf.function 
def run(): 
  ctx = tf.distribute.get_replica_context() 
  return ctx.replica_id_in_sync_group 
distributed_values = strategy.run(run) 
  3. As input into run:
strategy = tf.distribute.MirroredStrategy() 
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2) 
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset)) 
distributed_values = next(dataset_iterator) 
@tf.function 
def run(input): 
  return input + 1.0 
updated_value = strategy.run(run, args=(distributed_values,)) 
  4. Reduce value:
strategy = tf.distribute.MirroredStrategy() 
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2) 
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset)) 
distributed_values = next(dataset_iterator) 
reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                distributed_values,
                                axis=0)
  5. Inspect per-replica values:
strategy = tf.distribute.MirroredStrategy() 
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2) 
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset)) 
distributed_values = next(dataset_iterator)
per_replica_values = strategy.experimental_local_results(
    distributed_values)
per_replica_values 
(<tf.Tensor: shape=(2,), dtype=float32, 
 numpy=array([5., 6.], dtype=float32)>,)
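The pieces above can be combined into a typical per-step pattern. A minimal sketch (not one of the original examples), assuming the same toy dataset: distribute the dataset, run a computation on each replica, then reduce the per-replica results to a single value.

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def step(x):
  # per-replica computation; x is the replica's slice of the batch
  return x * 2.0

for distributed_values in dist_dataset:
  per_replica_result = strategy.run(step, args=(distributed_values,))
  # sum across replicas and across the batch dimension
  total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_result, axis=0)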