Runs a barrier on the mesh.
tf.experimental.dtensor.barrier(
    mesh: tf.experimental.dtensor.Mesh,
    barrier_name: Optional[str] = None,
    timeout_in_ms: Optional[int] = None
)
Upon returning from the barrier, all operations issued before the barrier are guaranteed to have completed across all clients. Currently the barrier is implemented by allocating a fully sharded tensor with the mesh's shape and running an all_reduce on it.
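To illustrate that note, the following is a minimal sketch of the described mechanism, not the actual implementation; it assumes `mesh` is an existing one-dimensional mesh with a 'batch' dimension of size 8.

import tensorflow as tf
from tensorflow.experimental import dtensor

# Sketch only (assumption): allocate a tensor sharded over every mesh
# dimension, then reduce it; the reduction forces an all-reduce in which
# every client participates, so returning from it acts as a barrier.
sharded = dtensor.call_with_layout(
    tf.ones, dtensor.Layout(['batch'], mesh), shape=[8])
_ = tf.reduce_sum(sharded)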
Example:
A barrier can be used before application exit to ensure completion of pending ops.
import sys

from tensorflow.experimental import dtensor

# `mesh` is an existing dtensor.Mesh with a 'batch' dimension.
x = [1, 2, 3]
x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))
dtensor.barrier(mesh)
# At this point all devices on all clients in the mesh have completed the
# operations issued before the barrier, so it is safe to tear down the clients.
sys.exit()
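For a fuller picture, here is a hedged single-client sketch that also exercises the optional barrier_name and timeout_in_ms arguments; the 8-device CPU mesh, the 'batch' dimension name, and the specific values are illustrative assumptions, not part of the API above.

import tensorflow as tf
from tensorflow.experimental import dtensor

# Assumption: expose 8 logical CPU devices on this single client.
tf.config.set_logical_device_configuration(
    tf.config.list_physical_devices('CPU')[0],
    [tf.config.LogicalDeviceConfiguration()] * 8)
mesh = dtensor.create_mesh([('batch', 8)],
                           devices=[f'CPU:{i}' for i in range(8)])

x = dtensor.copy_to_mesh(tf.range(8, dtype=tf.float32),
                         dtensor.Layout.replicated(mesh, rank=1))
x = dtensor.relayout(x, dtensor.Layout.batch_sharded(mesh, 'batch', 1))

# A named barrier with a bound on how long the underlying collective waits.
dtensor.barrier(mesh, barrier_name='before_exit', timeout_in_ms=60_000)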