tf.experimental.dtensor.barrier

Runs a barrier on the mesh.

Upon returning from the barrier, all operations issued before the barrier are guaranteed to have completed across all clients. The current implementation allocates a fully sharded tensor with the mesh shape and runs an all_reduce on it.
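
To illustrate that mechanism, here is a minimal sketch of the same idea, not the library's actual implementation; the helper name `barrier_sketch` is hypothetical:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

def barrier_sketch(mesh):
  # Fully sharded layout: shard over every mesh dimension, so each device
  # holds exactly one element of a tensor with the mesh shape.
  layout = dtensor.Layout(mesh.dim_names, mesh)
  ones = dtensor.call_with_layout(tf.ones, layout, shape=mesh.shape())
  # Reducing across the sharded dimensions triggers an all_reduce, so no
  # client returns from here until every device has contributed its shard.
  return tf.reduce_sum(ones)
```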

Example:

A barrier can be used before application exit to ensure completion of pending ops.


```python
import sys

import tensorflow as tf
from tensorflow.experimental import dtensor

x = tf.constant([1., 2., 3.])
x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
dtensor.barrier(mesh)

# At this point all devices on all clients in the mesh have completed the
# operations issued before the barrier, so it is safe to tear down the clients.
sys.exit()
```
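
The example assumes a `mesh` already exists. A hypothetical single-client setup might look like the following; the dimension name 'batch' and size 8 are illustrative, and the logical CPU devices must be configured before the mesh is created:

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Illustrative only: expose 8 logical CPU devices, then build a 1-D mesh
# with a single 'batch' dimension over them.
phy = tf.config.list_physical_devices('CPU')
tf.config.set_logical_device_configuration(
    phy[0], [tf.config.LogicalDeviceConfiguration()] * 8)
mesh = dtensor.create_mesh([('batch', 8)], device_type='CPU')
```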

Args:

mesh: The mesh to run the barrier on.
barrier_name: The name of the barrier; mainly used for logging purposes.
timeout_in_ms: The timeout of the barrier in milliseconds. If omitted, blocks indefinitely until the barrier has been reached by all clients.
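
Both optional arguments can be supplied together; the barrier name and timeout value below are arbitrary examples:

```python
# Name the barrier for easier log correlation, and fail after 60 s instead
# of blocking indefinitely if some client never reaches the barrier.
dtensor.barrier(mesh, barrier_name='pre_checkpoint', timeout_in_ms=60 * 1000)
```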