
tf.distribute.ReplicaContext

A class with a collection of APIs that can be called in a replica context.

You can use tf.distribute.get_replica_context to get an instance of ReplicaContext. Its APIs can only be called inside the function passed to tf.distribute.Strategy.run.

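import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
def func():
  # Inside strategy.run, this returns the ReplicaContext of the current replica.
  replica_context = tf.distribute.get_replica_context()
  return replica_context.replica_id_in_sync_group
strategy.run(func)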
PerReplica:{
  0: <tf.Tensor: shape=(), dtype=int32, numpy=0>,
  1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
}
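The example above needs two GPUs. The sketch below is one way to try the same API on a CPU-only machine, assuming a TensorFlow 2.x install: it splits the physical CPU into two logical devices with tf.config.set_logical_device_configuration (this has to happen before TensorFlow initializes its devices), then uses replica_id_in_sync_group together with num_replicas_in_sync to compute where each replica's slice of a global batch would start. GLOBAL_BATCH and shard_start are hypothetical names chosen for illustration, and strategy.experimental_local_results is used to unpack the PerReplica result into one tensor per replica.

```python
import tensorflow as tf

# Assumption: CPU-only setup. Split the first physical CPU into two logical
# devices so MirroredStrategy can create two replicas. This must run before
# TensorFlow initializes its devices.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu,
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()],
)

strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])

GLOBAL_BATCH = 8  # hypothetical global batch size, for illustration only


def shard_start():
    ctx = tf.distribute.get_replica_context()
    # num_replicas_in_sync is a Python int; replica_id_in_sync_group is an
    # int32 Tensor inside strategy.run.
    per_replica_batch = GLOBAL_BATCH // ctx.num_replicas_in_sync
    # Offset of this replica's slice within the (hypothetical) global batch.
    return ctx.replica_id_in_sync_group * per_replica_batch


per_replica = strategy.run(shard_start)
# Unpack the PerReplica value into one tensor per local replica, in replica order.
starts = [int(t.numpy()) for t in strategy.experimental_local_results(per_replica)]
print(starts)  # [0, 4]
```

Running strategy.run eagerly like this may log a performance warning; wrapping the call in a tf.function is the usual choice for real workloads.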

Args:
strategy: A tf.distribute.Strategy.
replica_id_in_sync_group: An integer, a Tensor or None. Prefer an integer whenever possible to avoid issues with nested tf.function. It accepts a Tensor only to be compatible with tpu.replicate.

Attributes:
devices: (deprecated) Returns the devices this replica is to be executed on, as a tuple of strings.

num_replicas_in_sync: Returns the number of replicas that are kept in sync.
replica_id_in_sync_group: Returns the id of the replica.

This identifies the replica amon