tf.keras.distribution.DeviceMesh

A cluster of computation devices for distributed computation.

This API is aligned with jax.sharding.Mesh and tf.dtensor.Mesh, which represent the computation devices in the global context.

For more details, see jax.sharding.Mesh and tf.dtensor.Mesh.

shape Tuple or list of integers. The shape of the overall DeviceMesh, e.g. (8,) for a data-parallel-only distribution, or (4, 2) for a combined model and data parallel distribution.
axis_names List of strings. The logical name of each axis of the DeviceMesh. The length of axis_names should match the rank of shape. The axis_names are used to match or create the TensorLayout when distributing the data and variables.
devices Optional list of devices. Defaults to all the locally available devices returned by keras.distribution.list_devices().
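
The relationship between shape, axis_names and devices can be illustrated with a small pure-Python sketch. It does not call Keras. The helper name check_mesh_args, the axis names "batch" and "model", and the "gpu:N" device strings are all hypothetical. The second check (the product of shape equals the number of devices) is an assumption about how a mesh would normally be laid out, not something stated in the argument descriptions above.

```python
import math


def check_mesh_args(shape, axis_names, devices):
    """Hypothetical helper, not part of Keras: validate DeviceMesh-style arguments."""
    # Rule from the argument description: one axis name per mesh dimension.
    if len(shape) != len(axis_names):
        raise ValueError(
            f"shape has rank {len(shape)} but {len(axis_names)} axis names were given"
        )
    # Assumption: the mesh uses every device exactly once, so the product
    # of the shape should equal the device count.
    if math.prod(shape) != len(devices):
        raise ValueError(
            f"shape {tuple(shape)} needs {math.prod(shape)} devices, got {len(devices)}"
        )
    # Map each logical axis name to its size.
    return dict(zip(axis_names, shape))


# Placeholder device identifiers; real ones would come from
# keras.distribution.list_devices().
devices = [f"gpu:{i}" for i in range(8)]

# Data parallel only: a single axis spanning all eight devices.
data_parallel = check_mesh_args((8,), ["batch"], devices)

# Model + data parallel: the same eight devices arranged as a 4 x 2 grid.
model_and_data = check_mesh_args((4, 2), ["batch", "model"], devices)

print(data_parallel)   # {'batch': 8}
print(model_and_data)  # {'batch': 4, 'model': 2}
```

With Keras installed and eight devices available, the same arguments would presumably be passed as DeviceMesh(shape=(4, 2), axis_names=["batch", "model"], devices=devices). Which axis is treated as the data axis and which as the model axis depends on how the names are used later in a TensorLayout, so the ordering shown here is only one possible choice.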

axis_names

devices

shape