-cluster-ops-by-policy
: Clusters ops according to specified policy.
This pass clusters ops according to the policy specified by the pass options. Clustered ops are moved into the region of a tf_device.cluster op. The only currently supported option is 'oplist='. This option specifies the names of the ops that should be clustered if they form a single-use def-use chain, that is, the next op in the list uses the result of the previous op and is the only user of that result. The ops must be located in the same block, be assigned to the same device, and have no side effects.
For example, running this pass with option oplist="tf.Cast, tf.Add" on:
func @cluster_oplist(%arg0 : tensor<f32>, %arg1 : tensor<i32>) -> tensor<i32> {
%0 = "tf.Cast"(%arg0) : (tensor<f32>) -> tensor<i32>
%1 = "SomeOp" (%arg1) : (tensor<i32>) -> tensor<i32>
%2 = "tf.Add"(%0, %1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %2 : tensor<i32>
}
will produce a tf_device.cluster op enclosing tf.Cast and tf.Add:
func @cluster_oplist(%arg0: tensor<f32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = "SomeOp"(%arg1) : (tensor<i32>) -> tensor<i32>
%1 = "tf_device.cluster"() ( {
%2 = "tf.Cast"(%arg0) : (tensor<f32>) -> tensor<i32>
%3 = "tf.Add"(%2, %0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_device.return %3 : tensor<i32>
}) : () -> tensor<i32>
return %1 : tensor<i32>
}
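Conversely, ops that do not form a single-use def-use chain are left unclustered. The following is a hedged sketch based on the description above (not taken from the pass documentation): because the tf.Cast result has a second user, running the pass with oplist="tf.Cast, tf.Add" would leave these ops outside any tf_device.cluster:
func @no_cluster(%arg0 : tensor<f32>, %arg1 : tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  // %0 has two users (tf.Add and the return), so tf.Cast and tf.Add do not
  // form a single-use def-use chain and are not clustered together.
  %0 = "tf.Cast"(%arg0) : (tensor<f32>) -> tensor<i32>
  %1 = "tf.Add"(%0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  return %0, %1 : tensor<i32>, tensor<i32>
}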
Options
-oplist : Cluster listed ops when they form a single use def-use chain, such that each op's single user is the next op in the list.
-tf-device-cluster-outlining
: Outlines regions of tf_device.cluster operations
This pass outlines the body of a tf_device.cluster into a function and replaces the tf_device.cluster op with an equivalent tf_device.cluster_func op. Implicit operands will be captured and materialized as explicit arguments to the newly created functions and associated tf_device.cluster_func ops.
For example, the following:
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
%cluster = "tf_device.cluster"() ( {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) : () -> (tensor<i32>)
return %cluster : tensor<i32>
}
will be transformed into:
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
%cluster = "tf_device.cluster_func"(%arg0) {func = @_func} : (tensor<i32>) -> tensor<i32>
return %cluster : tensor<i32>
}
func @_func(%arg0: tensor<i32>) -> tensor<i32> {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
return %identity : tensor<i32>
}
-tf-device-constant-sinking
: Sinks constants implicitly captured in a tf_device.cluster region.
This pass sinks implicitly captured constants (tf.Const ops) used by and into a tf_device.cluster region. Performing this prior to outlining will reduce the number of arguments of the outlined function.
For example, the following:
func @cluster() -> tensor<i32> {
%const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%cluster = "tf_device.cluster"() ( {
%identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) : () -> (tensor<i32>)
return %cluster : tensor<i32>
}
will be transformed into:
func @cluster() -> tensor<i32> {
%cluster = "tf_device.cluster"() ( {
%const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) : () -> (tensor<i32>)
return %cluster : tensor<i32>
}
-tf-executor-graph-pruning
: Prunes unreachable ops in a tf_executor.graph
This pass removes ops from a tf_executor.graph that are not transitively, via data or control dependencies, connected to the associated tf_executor.fetch op. The order of ops will be preserved. Functions named main with no tf.entry_function attribute will not be pruned, as such graphs/functions may have been imported from a V1 TensorFlow graph, where feeds/fetches/targets are not provided at certain stages of IR transformation (e.g. pre-placement).
The option ops-to-preserve allows specifying ops that should not be pruned, regardless of their reachability.
For example, the following:
func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%graph = tf_executor.graph {
%transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
%unreachable_data:2 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
%reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
%unreachable_control = tf_executor.island wraps "tf.NoOp"() : () -> tensor<i32>
tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
}
return %graph : tensor<i32>
}
will be transformed into:
func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%graph = tf_executor.graph {
%transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
%transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
%reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
}
return %graph : tensor<i32>
}
Options
-ops-to-preserve : Comma separated list of ops that should not be pruned regardless of reachability
-tf-executor-to-functional-conversion
: Lifts tf_executor.island inner ops from a tf_executor.graph
This pass converts tf_executor.graphs consisting of only tf_executor.islands and a tf_executor.fetch into a sea of nodes consisting of TensorFlow Dialect ops by lifting such ops out of a tf_executor.graph's tf_executor.islands. If V1 control flow ops are present in a tf_executor.graph, an error will be returned.
For example, the following:
func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
%graph_results:2 = tf_executor.graph {
%island_0_result, %island_0_control = tf_executor.island {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
tf_executor.yield %identity : tensor<i32>
}
%island_1_result, %island_1_control = tf_executor.island {
%identity_n:2 = "tf.IdentityN"(%arg1, %island_0_result) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
tf_executor.yield %identity_n#0
}
tf_executor.fetch %island_0_result, %island_1_result : tensor<i32>, tensor<i32>
}
return %graph_results#0, %graph_results#1 : tensor<i32>, tensor<i32>
}
will be transformed into:
func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
%identity_n:2 = "tf.IdentityN"(%arg1, %identity) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
return %identity, %identity_n#0 : tensor<i32>, tensor<i32>
}
-tf-functional-control-flow-to-regions
: Transforms functional control flow operations to their region-based counterparts
This pass transforms functional control flow operations in the TensorFlow dialect to their region-based counterparts, i.e., tf.If is transformed to tf.IfRegion and tf.While is transformed to tf.WhileRegion.
For example, this functional operation
%0 = "tf.If"(%arg0, %arg1) {
then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false
} : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
will be transformed into this region-based operation
%0 = "tf.IfRegion"(%arg0) ( {
%1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
}, {
%1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
-tf-mark-ops-for-outside-compilation
: Marks ops in device cluster for outside compilation if they are unsupported on device.
This pass marks unsupported ops in a device cluster with the _xla_outside_compilation attribute so that those operations will run on the host instead of the device. Unsupported ops are ops for which code cannot be generated to run on the cluster's device, including:
- String operations on TPUs.
- Operations that don't have a kernel defined for the device.
This pass is conservative in that it marks all ops that cannot be compiled for the device for outside compilation. Exceptions are made for ops that will be rewritten or decomposed before device compilation.
For example, a tf_device.cluster op with an unsupported op, tf.UnsupportedOp:
func @unsupported_op() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.UnsupportedOp"() : () -> tensor<i32>
%2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
tf_device.return %2 : tensor<i32>
}) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
return %0 : tensor<i32>
}
will mark tf.UnsupportedOp with the _xla_outside_compilation attribute:
func @unsupported_op() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.UnsupportedOp"() {_xla_outside_compilation = "auto0"} : () -> tensor<i32>
%2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
tf_device.return %2 : tensor<i32>
}) {allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<i32>
return %0 : tensor<i32>
}
-tf-region-control-flow-to-functional
: Transforms region-based control flow operations to their functional counterparts
This pass transforms region-based control flow operations in the TensorFlow dialect to their functional counterparts, i.e., tf.IfRegion is transformed to tf.If and tf.WhileRegion is transformed to tf.While.
For example, this region-based operation
%0 = "tf.IfRegion"(%arg0) ( {
%1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
}, {
%1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%1) : (tensor<*xf32>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
will be transformed into this functional operation
%0 = "tf.If"(%arg0, %arg1) {
then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false
} : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
-tf-shape-inference
: Simple Shape Inference on TensorFlow Dialect
Options
-max-iterations : Maximum shape inference iterations
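This pass has no example above, so the following is a hedged, minimal sketch of the kind of refinement it performs (not taken from the pass documentation): shape inference propagates known operand shapes into more precise result and function return types. For example:
func @refine_shapes(%arg0: tensor<1xi32>) -> tensor<*xi32> {
  // The result type is unrefined before shape inference.
  %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32>
  return %0 : tensor<*xi32>
}
could be refined to:
func @refine_shapes(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
  return %0 : tensor<1xi32>
}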
-tf-tpu-cluster-formation
: Forms clusters from operations assigned to the same TPU computation
TPU computations from the frontend are composed of a tf.TPUReplicateMetadata op, a subgraph of ops (TensorFlow Dialect) each with a matching _tpu_replicate attribute relative to the associated tf.TPUReplicateMetadata op, and optionally tf.TPUReplicatedInput and tf.TPUReplicatedOutput ops feeding in inputs and outputs to and from a replicated TPU computation. The number of times a TPU computation is replicated is defined in the tf.TPUReplicateMetadata op (num_replicas attribute), and the operand and result sizes of tf.TPUReplicatedInput and tf.TPUReplicatedOutput respectively must match, excluding packed tensors. It is also assumed that ops of the same TPU computation do not have ops outside of the TPU computation that are both inputs and outputs to the same TPU computation.
This pass takes the TPU computation subgraph, moves it into a tf_device.cluster, and copies over attributes from the associated tf.TPUReplicateMetadata op to the newly created tf_device.cluster. If the computation is replicated (num_replicas > 1), the num_replicas attribute is not copied over; instead, the tf_device.cluster is further wrapped with a tf_device.replicate, and the associated tf.TPUReplicatedInput and tf.TPUReplicatedOutput ops are replaced by the tf_device.replicate operands and results. Otherwise, the single operands and results of the associated tf.TPUReplicatedInput and tf.TPUReplicatedOutput ops are simply forwarded to the tf_device.cluster.
For example, the following non-replicated computation:
func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
// Metadata op for cluster `cluster` with 1 replica, 1 core per replica and
// with topology `<topology>`.
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_relicas = 1, num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> ()
%replicated_input = "tf.TPUReplicatedInput"(%arg0) : (tensor<i32>) -> tensor<i32>
%identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
%replicated_output = "tf.TPUReplicatedOutput(%identity) : (tensor<i32>) -> tensor<i32>
return %replicated_output : tensor<i32>
}
will be transformed into:
func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
%cluster = "tf_device.cluster"() ( {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
return %cluster : tensor<i32>
}
The following replicated computation:
func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
"tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_relicas = 2, num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> ()
%replicated_input = "tf.TPUReplicatedInput"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor<i32>) -> tensor<i32>
%replicated_output:2 = "tf.TPUReplicatedOutput(%identity) : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
return %replicated_output#0, %replicated_output#1 : tensor<i32>, tensor<i32>
}
will be transformed into:
func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
%replicate:2 = tf_device.replicate([%arg0, %arg1] as %replicated_input) {n = 2 : i32} {
%cluster = "tf_device.cluster"() ( {
%identity = "tf.Identity"(%replicated_input) : (tensor<i32>) -> tensor<i32>
tf_device.return %identity : tensor<i32>
}) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
tf_device.return %cluster : tensor<i32>
}
return %replicate#0, %replicate#1 : tensor<i32>, tensor<i32>
}
-tf-tpu-extract-outside-compilation
: Extracts TPU outside compilation computation to a separate tf_device.parallel_execute region.
This pass extracts a CPU computation cluster with the _xla_outside_compilation annotation, which denotes ops that should be run on the CPU/host, from a TPU cluster. Each outside compilation cluster is moved to a tf_device.parallel_execute region. The TPU cluster is also moved to a tf_device.parallel_execute region. Communication ops between device and host are added to pass inputs/outputs to/from the outside compiled region.
For example, the following tf_device.cluster with an op marked for _xla_outside_compilation:
func @outside_compilation() -> tensor<f32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor<f32>} : () -> (tensor<f32>)
%2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor<f32>) -> (tensor<f32>)
%3 = "tf.AddV2"(%1, %2) : (tensor<f32>, tensor<f32>) -> (tensor<f32>)
tf_device.return %3 : tensor<f32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<f32>
return %0 : tensor<f32>
}
will become a tf_device.parallel_execute op with a CPU/host region and a tf_device.cluster with communication ops to send data to/from device/host:
func @outside_compilation() -> tensor<f32> {
%0 = "tf_device.parallel_execute"() ( {
"tf_device.launch"() ( {
%1 = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf.string>
%2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf.string>) -> tensor<f32>
%3 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
"tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor<f32>, tensor<3x!tf.string>) -> ()
tf_device.return
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
tf_device.return
}, {
%1 = "tf_device.cluster"() ( {
%2 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor<f32>) -> tensor<f32>
%4 = "tf.AddV2"(%2, %3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
tf_device.return %4 : tensor<f32>
}) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<f32>
tf_device.return %1 : tensor<f32>
}) : () -> tensor<f32>
return %0 : tensor<f32>
}
-tf-tpu-reorder-replicate-partitioned-inputs
: Reorders replicated and partitioned input ops.
This pass rewrites how data parallelism and model parallelism are expressed for inputs. It reorders tf.TPUPartitionedInput (model parallelism) and tf.TPUReplicatedInput (data parallelism) ops. It transforms a DAG where multiple tf.TPUPartitionedInput ops feed into a single tf.TPUReplicatedInput into a DAG where multiple tf.TPUReplicatedInput ops feed into a single tf.TPUPartitionedInput. Transforming the IR in this manner allows the subsequent cluster formation pass to handle IR with both data and model parallelism more easily.
For example, the following:
!rtype = type tensor<!tf.resource<tensor<10x3xf32>>>
func @data_and_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg2: !rtype, %arg3: !rtype) -> !rtype {
%pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype
%pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype
%ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1) : (!rtype, !rtype) -> !rtype
return %ri : !rtype
}
will be transformed into:
!rtype = type tensor<!tf.resource<tensor<10x3xf32>>>
func @data_and_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg2: !rtype, %arg3: !rtype) -> !rtype {
%ri_0 = "tf.TPUReplicatedInput"(%arg0, %arg2) : (!rtype, !rtype) -> !rtype
%ri_1 = "tf.TPUReplicatedInput"(%arg1, %arg3) : (!rtype, !rtype) -> !rtype
%pi = "tf.TPUPartitionedInput"(%ri_0, %ri_1) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype
return %pi : !rtype
}
-tf-tpu-resource-partition
: Partitions unpartitioned resource read/write to partitioned resource variables.
This pass creates individual resource reads/writes from the unpartitioned resource variable (from tf.TPUPartitionedInput) to individual partitioned resource variables (tf.TPUPartitionedInput operands). As resource op decomposition/lifting occurs with the unpartitioned resource variables, transforming the IR in this manner will allow subsequent passes to operate on individual resource variable handles per core/device.
For example, the following:
func @cluster(%arg0: tensor<!tf.resource<tensor<i32>>>, %arg1: tensor<!tf.resource<tensor<i32>>>) {
%partitioned_variable = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<!tf.resource<tensor<i32>>>, tensor<!tf.resource<tensor<i32>>>) -> tensor<!tf.resource<tensor<i32>>>
%read = "tf.ReadVariableOp"(%partitioned_variable) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
%computation = "tf_device.cluster_func"(%read) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
"tf.AssignVariableOp"(%partitioned_variable, %computation) : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
return %arg0: tensor<i32>
}
will be transformed into:
func @cluster(%arg0: tensor<!tf.resource<tensor<i32>>>, %arg1: tensor<!tf.resource<tensor<i32>>>) {
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<i32>>>) -> tensor<i32>
%partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%computation = "tf_device.cluster_func"(%partitioned_input) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
%partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
"tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
"tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor<!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
return %arg0: tensor<i32>
}
-tf-tpu-resource-read-for-write
: Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes with no reads
This pass materializes tf.ReadVariableOp inputs to an outlined TPU computation for resource variables where only writes are present, so that later in the pipeline such resource variables can be fused with generated tf.TPUExecute ops, which only support resource variable reads or reads + writes. For all TPU computations, resource variables are required to be initialized prior to execution. Write-only resource variable uses can currently be generated via packed tensor uses.
For example, the following:
func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf.resource<tensor<i32>>>) {
%0 = "tf_device.cluster_func"(%value) {func = @cluster} : (tensor<i32>) -> tensor<i32>
"tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
func @cluster(%arg0: tensor<i32>) -> tensor<i32> {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
return %identity : tensor<i32>
}
will be transformed into:
func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf.resource<tensor<i32>>>) {
%resource_read = "tf.ReadVariableOp"(%resource) : (tensor<*x!tf.resource<tensor<i32>>>) -> tensor<i32>
%0 = "tf_device.cluster_func"(%value, %resource_read) {func = @cluster} : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf.resource<tensor<i32>>>, tensor<i32>) -> ()
return
}
func @cluster(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
return %identity : tensor<i32>
}
-tf-tpu-rewrite
: Rewrites a tf_device.cluster_func on TPUs into TPU runtime operations.
This pass rewrites a tf_device.cluster_func operation into a sequence of tf._TPUCompileMlir and tf.TPUExecute operations. tf._TPUCompileMlir contains an MLIR module that is functionally equivalent to the function referenced by tf_device.cluster_func. This allows the module to be JIT-compiled and executed on the TPU. If it is not possible to rewrite the operation or device assignment fails, a failure will be returned.
Note: many parameters of the tf_device.cluster_func ops are omitted in this and the following examples.
For example, a non-replicated tf_device.cluster_func:
func @tf_tpu_rewrite(%arg0: tensor<i8>) {
%0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
return
}
will be rewritten as:
func @tf_tpu_rewrite(%arg0: tensor<i8>) {
%0:2 = "tf_device.launch"() ( {
%compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
%1 = "tf_device.launch"() ( {
%2 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
tf_device.return %2 : tensor<i8>
}) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : () -> tensor<i8>
return
}
A replicated tf_device.cluster_func:
func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
%0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i8>) {n = 2 : i32} {
%1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
tf_device.return %1 : tensor<i8>
}
return
}
will be rewritten as:
func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
%0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor<i8>) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} {
%1:2 = "tf_device.launch"() ( {
%compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
tf_device.return %compilation_status, %program : tensor<!tf.string>, tensor<3x!tf.string>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
%2 = "tf_device.launch"() ( {
%3 = "tf.TPUExecute"(%arg2, %1#1) : (tensor<i8>, tensor<3x!tf.string>) -> tensor<i8>
tf_device.return %3 : tensor<i8>
}) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i8>
tf_device.return %2 : tensor<i8>
}
return
}
A non-replicated `tf_device.cluster_func` with model parallelism:
func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
%0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func, num_cores_per_replica = 2, input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32>
return %0 : tensor<8xi32>
}
will be rewritten as:
func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
%0:3 = "tf_device.launch"() ( {
%compilation_status, %program:2 = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
tf_device.return %compilation_status, %program#0, %program#1 : tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>
}) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<3x!tf.string>, tensor<3x!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> ()
%1 = "tf_device.parallel_execute"() ( {
%2 = "tf_device.launch"() ( {
%3 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<8xi32>, tensor<3x!tf.string>) -> tensor<8xi32>
tf_device.return %3 : tensor<8xi32>
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<8xi32>
tf_device.return %2 : tensor<8xi32>
}, {
"tf_device.launch"() ( {
"tf.TPUExecute"(%0#2) : (tensor<3x!tf.string>) -> ()
tf_device.return
}) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> ()
tf_device.return
}) : () -> tensor<8xi32>
return %1 : tensor<8xi32>
}
-tf-verify-for-export
: Verify module is suitable for export back to TF Graph
Verifies that every function in the module consists of a single tf_executor.graph and that each tf_executor.island in a tf_executor.graph wraps only a single op.
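For illustration, the following is a hedged sketch (not taken from the pass documentation) of a function in the form this pass accepts: a single tf_executor.graph whose islands each wrap exactly one op:
func @export_ready(%arg0: tensor<i32>) -> tensor<i32> {
  %graph = tf_executor.graph {
    // Each island wraps exactly one op, as required for export.
    %island:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    tf_executor.fetch %island#0 : tensor<i32>
  }
  return %graph : tensor<i32>
}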