Public API for the tf.experimental.dtensor namespace.
Classes
class DTensorCheckpoint
: Manages saving/restoring trackable values to disk, for DTensor. (deprecated)
class DTensorDataset
: A dataset of DTensors.
class DVariable
: A replacement for tf.Variable which follows initial value placement.
class Layout
: Represents the layout information of a DTensor.
class Mesh
: Represents a Mesh configuration over a certain list of Mesh Dimensions.
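As a rough illustration of how a Mesh and a Layout relate, the following pure-Python sketch models a mesh as named dimensions with sizes and a layout as one sharding spec per tensor axis. It does not call TensorFlow. The helper `local_shape` and the dict-based mesh are hypothetical names used only for illustration, and the sketch assumes every sharded axis divides evenly across its mesh dimension.

```python
from math import prod

# Same string value as the UNSHARDED member listed on this page.
UNSHARDED = "unsharded"


def local_shape(global_shape, sharding_specs, mesh_dims):
    """Return the per-device component shape (hypothetical helper).

    global_shape:   shape of the whole tensor, e.g. (8, 16)
    sharding_specs: one entry per tensor axis, either a mesh
                    dimension name or UNSHARDED
    mesh_dims:      dict mapping mesh dimension name -> size
    """
    if len(global_shape) != len(sharding_specs):
        raise ValueError("need one sharding spec per tensor axis")
    shape = []
    for size, spec in zip(global_shape, sharding_specs):
        if spec == UNSHARDED:
            # This axis is not split; every device holds all of it.
            shape.append(size)
            continue
        parts = mesh_dims[spec]
        if size % parts != 0:
            raise ValueError(
                f"axis of size {size} does not split evenly over {parts}")
        shape.append(size // parts)
    return tuple(shape)


# A 2 x 4 mesh: 8 devices arranged along dimensions "x" and "y".
mesh_dims = {"x": 2, "y": 4}
num_devices = prod(mesh_dims.values())

# Shard axis 0 over "x" and leave axis 1 unsharded.
component_shape = local_shape((8, 16), ["x", UNSHARDED], mesh_dims)
print(num_devices, component_shape)
```

In this model, a tensor whose axes are all UNSHARDED is fully replicated: each device holds a component with the global shape.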
Functions
barrier(...)
: Runs a barrier on the mesh.
call_with_layout(...)
: Calls a function in the DTensor device scope if layout is not None.
check_layout(...)
: Asserts that the layout of the DTensor is layout.
client_id(...)
: Returns this client's ID.
copy_to_mesh(...)
: Copies a tf.Tensor onto the DTensor device with the given layout.
create_distributed_mesh(...)
: Creates a distributed mesh.
create_mesh(...)
: Creates a single-client mesh.
create_tpu_mesh(...)
: Returns a distributed TPU mesh optimized for AllReduce ring reductions.
device_name(...)
: Returns the singleton DTensor device's name.
enable_save_as_bf16(...)
: Allows float32 DVariables to be checkpointed and restored as bfloat16.
fetch_layout(...)
: Fetches the layout of a DTensor.
full_job_name(...)
: Returns the fully qualified TF job name for this or another task.
heartbeat_enabled(...)
: Returns true if DTensor heartbeat service is enabled.
initialize_accelerator_system(...)
: Initializes accelerators and communication fabrics for DTensor.
initialize_multi_client(...)
: Initializes accelerators and communication fabrics for DTensor.
initialize_tpu_system(...)
: Initializes accelerators and communication fabrics for DTensor.
is_dtensor(...)
: Checks whether the input tensor is a DTensor.
job_name(...)
: Returns the job name used by all clients in this DTensor cluster.
jobs(...)
: Returns a list of job names of all clients in this DTensor cluster.
local_devices(...)
: Returns a list of device specs configured on this client.
name_based_restore(...)
: Restores from checkpoint_prefix to name-based DTensors.
name_based_save(...)
: Saves name-based tensors into a checkpoint.
num_clients(...)
: Returns the number of clients in this DTensor cluster.
num_global_devices(...)
: Returns the number of devices of device_type in this DTensor cluster.
num_local_devices(...)
: Returns the number of devices of device_type configured on this client.
pack(...)
: Packs tf.Tensor components into a DTensor.
preferred_device_type(...)
: Returns the preferred device type for the accelerators.
relayout(...)
: Changes the layout of tensor.
run_on(...)
: Runs enclosed functions in the DTensor device scope.
sharded_save(...)
: Saves given named tensor slices in a sharded, multi-client safe fashion.
shutdown_accelerator_system(...)
: Shuts down the accelerator system.
shutdown_tpu_system(...)
: Shuts down the accelerator system.
unpack(...)
: Unpacks a DTensor into tf.Tensor components.
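To make the idea of components behind pack and unpack more concrete, here is a conceptual pure-Python sketch over nested lists. It is not the library's implementation and does not call TensorFlow. The functions `unpack_rows` and `pack_rows` are hypothetical, and the sketch assumes a one-dimensional mesh where the tensor is either split along its first axis or fully replicated.

```python
def unpack_rows(matrix, num_devices, sharded=True):
    """Split a matrix into one component per device (hypothetical helper)."""
    if not sharded:
        # Replicated: every device gets its own full copy.
        return [[row[:] for row in matrix] for _ in range(num_devices)]
    if len(matrix) % num_devices != 0:
        raise ValueError("rows must divide evenly across devices")
    step = len(matrix) // num_devices
    # Sharded: device i gets a contiguous block of rows.
    return [matrix[i * step:(i + 1) * step] for i in range(num_devices)]


def pack_rows(components, sharded=True):
    """Reassemble the global matrix from per-device components."""
    if not sharded:
        # All replicas are assumed identical, so any one of them will do.
        return [row[:] for row in components[0]]
    return [row for component in components for row in component]


matrix = [[0, 1], [2, 3], [4, 5], [6, 7]]
shards = unpack_rows(matrix, num_devices=2)
replicas = unpack_rows(matrix, num_devices=2, sharded=False)
print(shards)
print(pack_rows(shards) == matrix)
```

The round trip only recovers the original value when the components are consistent with the chosen sharding, which is why this sketch passes the same `sharded` flag to both helpers.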
Other Members | |
---|---|
MATCH | 'match' |
UNSHARDED | 'unsharded' |