tf.distribute.ReplicaContext

Class ReplicaContext

Aliases:

  • Class tf.contrib.distribute.ReplicaContext
  • Class tf.distribute.ReplicaContext

Defined in tensorflow/python/distribute/distribute_lib.py.

The tf.distribute.Strategy API available when in a replica context.

Use it inside your replicated step function, for example a function passed to tf.distribute.StrategyExtended.call_for_each_replica.

__init__

__init__(
    strategy,
    replica_id_in_sync_group
)

Initialize self. See help(type(self)) for accurate signature.

Properties

devices

The devices this replica is to be executed on, as a tuple of strings.

num_replicas_in_sync

Returns the number of replicas over which gradients are aggregated.

replica_id_in_sync_group

The index of the replica being defined, from 0 to num_replicas_in_sync - 1.

strategy

The current tf.distribute.Strategy object.
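
As an illustration of why num_replicas_in_sync matters, the following pure-Python sketch works through the arithmetic of scaling a per-replica loss. It assumes that per-replica contributions are aggregated by summation and that every replica sees the same number of examples. The helper scaled_replica_loss and the sample data are hypothetical and not part of TensorFlow.

```python
def scaled_replica_loss(per_example_losses, num_replicas_in_sync):
    """Mean loss on one replica, divided by the replica count.

    Hypothetical helper: with sum aggregation across replicas, this scaling
    makes the aggregated value equal the mean over the global batch.
    """
    local_mean = sum(per_example_losses) / len(per_example_losses)
    return local_mean / num_replicas_in_sync


# Two replicas, each holding two per-example losses (made-up numbers).
replica_batches = [[1.0, 3.0], [2.0, 6.0]]
num_replicas = len(replica_batches)

# Sum of the scaled per-replica losses (the assumed aggregation).
aggregated = sum(scaled_replica_loss(b, num_replicas) for b in replica_batches)

# Mean over all examples in the global batch, for comparison.
all_losses = [loss for batch in replica_batches for loss in batch]
global_mean = sum(all_losses) / len(all_losses)
```

Under these assumptions, aggregated and global_mean agree. With unequal per-replica batch sizes they would generally differ.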

Methods

__enter__

__enter__()

__exit__

__exit__(
    exception_type,
    exception_value,
    traceback
)

merge_call

merge_call(
    merge_fn,
    args=(),
    kwargs=None
)

Merge args across replicas and run merge_fn in a cross-replica context.

This allows communication and coordination when there are multiple calls to a model function triggered by a call to strategy.extended.call_for_each_replica(model_fn, ...).

See tf.distribute.StrategyExtended.call_for_each_replica for an explanation.

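If not inside a distributed scope, this is equivalent to the following pseudocode, where cross-replica-context stands for entering the cross-replica context of strategy: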

strategy = tf.distribute.get_strategy()
with cross-replica-context(strategy):
  return merge_fn(strategy, *args, **kwargs)

Args:

  • merge_fn: Function that joins arguments from threads that are given as PerReplica. It accepts a tf.distribute.Strategy object as its first argument.
  • args: List or tuple with positional per-thread arguments for merge_fn.
  • kwargs: Dict with keyword per-thread arguments for merge_fn.

Returns:

The return value of merge_fn, except for PerReplica values, which are unpacked.
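
The merge semantics described above can be illustrated with a toy model in plain Python. This is only a sketch and not how TensorFlow implements merge_call. It assumes one thread per replica, and the names ToyCoordinator, ToyReplicaContext, and toy_call_for_each_replica are hypothetical. Each replica thread deposits its arguments, all threads wait at a barrier, merge_fn runs exactly once with the gathered per-replica values, and every replica receives the same result.

```python
import threading


class ToyCoordinator:
    """Hypothetical stand-in for a strategy; gathers per-replica arguments."""

    def __init__(self, num_replicas):
        self.num_replicas = num_replicas
        self.merge_fn_calls = 0
        self._barrier = threading.Barrier(num_replicas)
        self._args = [None] * num_replicas
        self._kwargs = [None] * num_replicas
        self._result = None

    def merge(self, replica_id, merge_fn, args, kwargs):
        self._args[replica_id] = args
        self._kwargs[replica_id] = kwargs
        # Wait until every replica has deposited its arguments.
        arrival_index = self._barrier.wait()
        if arrival_index == 0:
            # Exactly one thread runs merge_fn, with one tuple of
            # per-replica values for each positional and keyword argument.
            merged_args = [tuple(values) for values in zip(*self._args)]
            merged_kwargs = {
                key: tuple(kw[key] for kw in self._kwargs)
                for key in self._kwargs[0]
            }
            self._result = merge_fn(self, *merged_args, **merged_kwargs)
            self.merge_fn_calls += 1
        # Wait until the merged result is available to all replicas.
        self._barrier.wait()
        return self._result


class ToyReplicaContext:
    """Hypothetical stand-in for a replica context."""

    def __init__(self, coordinator, replica_id_in_sync_group):
        self._coordinator = coordinator
        self.replica_id_in_sync_group = replica_id_in_sync_group

    def merge_call(self, merge_fn, args=(), kwargs=None):
        return self._coordinator.merge(
            self.replica_id_in_sync_group, merge_fn, args, kwargs or {})


def toy_call_for_each_replica(num_replicas, fn):
    """Run fn once per replica, each call in its own thread."""
    coordinator = ToyCoordinator(num_replicas)
    outputs = [None] * num_replicas

    def run(replica_id):
        outputs[replica_id] = fn(ToyReplicaContext(coordinator, replica_id))

    threads = [threading.Thread(target=run, args=(i,))
               for i in range(num_replicas)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return outputs, coordinator


def step_fn(ctx):
    # Each replica contributes a different local value: 10, 20, 30.
    local_value = (ctx.replica_id_in_sync_group + 1) * 10

    def merge_fn(strategy, values):
        # values holds one entry per replica.
        return sum(values)

    return ctx.merge_call(merge_fn, args=(local_value,))


outputs, coordinator = toy_call_for_each_replica(3, step_fn)
```

In this toy model, merge_fn receives a plain tuple with one entry per replica, whereas the real API passes PerReplica values. The sketch does not handle exceptions raised inside a replica thread, which would leave the remaining threads waiting at the barrier.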