tf.distribute.experimental.ValueContext

A class wrapping information needed by a distribute function.

tf.distribute.experimental.ValueContext(
    replica_id_in_sync_group=0, num_replicas_in_sync=1
)

This context class is passed to the value_fn in strategy.experimental_distribute_values_from_function and contains information about the compute replicas. Its num_replicas_in_sync and replica_id_in_sync_group attributes can be used to customize the value on each replica.

Example usage:

  1. Directly constructed.

>>> import tensorflow as tf
>>> def value_fn(context):
...   return context.replica_id_in_sync_group / context.num_replicas_in_sync
>>> context = tf.distribute.experimental.ValueContext(
...   replica_id_in_sync_group=2, num_replicas_in_sync=4)
>>> per_replica_value = value_fn(context)
>>> per_replica_value
0.5

  2. Passed in by experimental_distribute_values_from_function.

>>> strategy = tf.distribute.MirroredStrategy()
>>> def value_fn(value_context):
...   return value_context.num_replicas_in_sync
>>> distributed_values = (
...     strategy.experimental_distribute_values_from_function(
...         value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(1,)
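
As a further illustration of customizing the value on each replica, the sketch below shards a list so that every replica receives a disjoint, strided slice. It is a minimal sketch under stated assumptions, not part of the TensorFlow API: StandInValueContext and make_shard_value_fn are hypothetical names introduced here, and the stand-in class only mimics the two attributes documented for ValueContext so that the snippet runs without TensorFlow or any devices.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StandInValueContext:
    """Hypothetical stand-in with the two attributes documented for ValueContext."""
    replica_id_in_sync_group: int = 0
    num_replicas_in_sync: int = 1


def make_shard_value_fn(items):
    """Build a value_fn that returns the strided shard of `items` for one replica."""
    def value_fn(context):
        # Replica k takes items k, k + n, k + 2n, ... where n is the replica count.
        return items[context.replica_id_in_sync_group::context.num_replicas_in_sync]
    return value_fn


items = list(range(10))
value_fn = make_shard_value_fn(items)

# Simulate four replicas by constructing one context per replica ID.
shards = [
    value_fn(StandInValueContext(replica_id_in_sync_group=i, num_replicas_in_sync=4))
    for i in range(4)
]
print(shards)
```

Because value_fn takes a single context argument, the same function could presumably be passed to strategy.experimental_distribute_values_from_function, which would supply a real ValueContext for each replica instead of the stand-in.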

Args:

  • replica_id_in_sync_group: the ID of the current replica; should be an int in [0, num_replicas_in_sync).
  • num_replicas_in_sync: the number of replicas that are in sync.

Attributes:

  • num_replicas_in_sync: Returns the number of compute replicas in sync.
  • replica_id_in_sync_group: Returns the replica ID.