# TensorFlow passes


## TF dialect passes

### `-cluster-tf-ops-by-host`: Cluster the TensorFlow ops by host so that each function only contains ops placed on the same host {: .hide-from-toc }

### `-constant-op-device-assignment`: Assign device for tf.Const ops {: .hide-from-toc }

### `-convert-tf-control-flow-to-scf`: Convert TensorFlow control flow to SCF. {: .hide-from-toc }

This pass can be used for all direct control flow lowerings from the TensorFlow dialect to the SCF dialect.

Prepares the TPU computation module attached to the `_TPUCompileMlir` op for TensorFlow graph export by applying transformations such as replacing or removing MLIR- or XLA-specific attributes that are not legal in a TensorFlow graph.

### `-tf-batch-matmul-to-tf-einsum`: Replace TF BatchMatMul op by TF Einsum op. {: .hide-from-toc }

### `-tf-broadcast-fold`: Fold explicit broadcasts into the following operations if they support implicit broadcasting on their operand. {: .hide-from-toc }

### `-tf-canonicalize-compile-and-replicate-attributes`: Canonicalize compilation and replication attributes. {: .hide-from-toc }

A pass that converts existing compilation and replication attributes into unified attributes. For example, `_tpu_replicate="cluster"` in the following code

```mlir
%control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_tpu = true, use_spmd_for_xla_partitioning = false} : () -> ()
```

will be replaced by `_replication_info="cluster"` and `_xla_compile_device_type="TPU"`:

```mlir
%control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_replication_info = "cluster", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> ()
```

Similarly, `_XlaMustCompile=true` in the following code

```mlir
%outputs_67, %control_68 = tf_executor.island wraps "tf.PartitionedCall"(%arg0, %outputs_0) {_XlaMustCompile = true, _collective_manager_ids = [], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\00\0A\07\0A\03TPU\10\02\0A\0E\0A\0ATPU_SYSTEM\10\012\02J\008\01\82\01\05h\01\88\01\01", device = "", executor_type = "", f = @__inference__jit_compiled_convolution_op_1510} : (tensor<4x32x32x8xf32>, tensor<*xf32>) -> tensor<*xf32>
```

will be replaced by `_xla_compile_device_type`, with its value set to the value of `device`:

```mlir
%outputs_67, %control_68 = tf_executor.island wraps "tf.PartitionedCall"(%arg0, %outputs_0) {_collective_manager_ids = [], _read_only_resource_inputs = [], _xla_compile_device_type = "", config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\00\0A\07\0A\03TPU\10\02\0A\0E\0A\0ATPU_SYSTEM\10\012\02J\008\01\82\01\05h\01\88\01\01", device = "", executor_type = "", f = @__inference__jit_compiled_convolution_op_1510} : (tensor<4x32x32x8xf32>, tensor<*xf32>) -> tensor<*xf32>
```
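The two rewrites above can be modeled with a minimal Python sketch in which an op's attributes are a plain `dict`. The function name `canonicalize_attrs` and the dict representation are hypothetical and are not a TensorFlow or MLIR API. The sketch handles only the two documented cases and leaves every other attribute unchanged, including `_XlaMustCompile=false`, which the description does not cover.

```python
from typing import Any, Dict


def canonicalize_attrs(attrs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical model of the legacy -> unified attribute rewrite."""
    out = dict(attrs)  # leave the caller's dict untouched
    if "_tpu_replicate" in out:
        # _tpu_replicate="X" becomes _replication_info="X" plus a TPU device type.
        out["_replication_info"] = out.pop("_tpu_replicate")
        out["_xla_compile_device_type"] = "TPU"
    if out.get("_XlaMustCompile") is True:
        # _XlaMustCompile=true becomes _xla_compile_device_type=<value of device>.
        # Assumption for this sketch: a missing device attribute counts as "".
        del out["_XlaMustCompile"]
        out["_xla_compile_device_type"] = out.get("device", "")
    return out


# Usage: the TPUReplicateMetadata example, reduced to a few attributes.
print(canonicalize_attrs({"_tpu_replicate": "cluster", "device": "", "use_tpu": True}))
```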

### `-tf-convert-to-legacy-compile-and-replicate-attributes`: Convert unified compilation and replication attributes back to legacy attributes. {: .hide-from-toc }
This transformation pass converts the unified compilation and replication attributes (`_replication_info` and `_xla_compile_device_type`) into legacy attributes. In some cases, this keeps the unified attributes from being exposed outside the MLIR bridge when the V1 pipeline is used. The pass expects each op to carry either both unified attributes or neither. If an op carries only one of them, the pass fails.

For example, `_replication_info="cluster"` and `_xla_compile_device_type="TPU"` in the following code

```mlir
%control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_replication_info = "cluster", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> ()
```

will be replaced by `_tpu_replicate="cluster"` as follows:

```mlir
%control = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_tpu = true, use_spmd_for_xla_partitioning = false} : () -> ()
```
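The reverse direction, including the "both or neither" precondition, can be sketched in Python using the same hypothetical dict model. Only the TPU case is documented above. For any other `_xla_compile_device_type` value, this sketch simply leaves the attributes unchanged. That is a placeholder choice made for the sketch and is not the pass's documented behavior.

```python
from typing import Any, Dict


def convert_to_legacy_attrs(attrs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical model of the unified -> legacy attribute rewrite."""
    has_info = "_replication_info" in attrs
    has_type = "_xla_compile_device_type" in attrs
    if has_info != has_type:
        # Exactly one unified attribute is present: the pass fails.
        raise ValueError(
            "expected both or neither of _replication_info and _xla_compile_device_type")
    out = dict(attrs)
    if has_info and out["_xla_compile_device_type"] == "TPU":
        # Documented case: the pair collapses into a single _tpu_replicate.
        out["_tpu_replicate"] = out.pop("_replication_info")
        del out["_xla_compile_device_type"]
    # Other device types are not described above; this sketch leaves them as-is.
    return out


print(convert_to_legacy_attrs(
    {"_replication_info": "cluster", "_xla_compile_device_type": "TPU", "device": ""}))
```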

### `-tf-data-optimization`: Performs tf.data optimizations {: .hide-from-toc }

### `-tf-decompose-reduce-dataset`: Decomposes ReduceDataset op into dataset operations. {: .hide-from-toc }
Decomposes the `ReduceDataset` op into a while loop that iterates over the dataset and calls the reduction function. This decomposition is only performed if the `ReduceDataset` op is marked for compilation with the `_xla_compile_device_type` attribute.

For example, the `ReduceDataset` op in the following function:

```mlir
func.func @single_state_single_dataset_type_no_arguments(
  %arg0: tensor<!tf_type.variant>,
  %arg1: tensor<i64>
) {
  %1 = "tf.ReduceDataset"(%arg0, %arg1) {
    Targuments = [],
    Tstate = [i64], device = "",
    f = @__reduce_func_1, f._tf_data_function = true,
    output_shapes = [#tf_type.shape<>],
    output_types = [i64], use_inter_op_parallelism = true, _xla_compile_device_type="TPU"} :
    (tensor<!tf_type.variant>, tensor<i64>) -> (tensor<i64>)
  func.return
}
```

with the following reduction function:

```mlir
func.func private @__reduce_func_1(%arg0: tensor<i64> {tf._user_specified_name = "args_0"},
  %arg1: tensor<32xf32> {tf._user_specified_name = "args_1"}) -> (tensor<i64>)
  attributes {tf._tf_data_function = true, tf.signature.is_stateful} {
    %0 = "tf.JustPretend"(%arg1) : (tensor<32xf32>) -> (tensor<i64>)
    func.return %0 : tensor<i64>
}
```

will be transformed into:

```mlir
func.func @single_state_single_dataset_type_no_arguments(%arg0: tensor<!tf_type.variant>, %arg1: tensor<i64>) {
  %0 = "tf.AnonymousIteratorV3"() {output_shapes = [#tf_type.shape<32>], output_types = [f32]} : () -> tensor<!tf_type.resource>
  "tf.MakeIterator"(%arg0, %0) : (tensor<!tf_type.variant>, tensor<!tf_type.resource>) -> ()
  %cst = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
  %1:2 = "tf.WhileRegion"(%cst, %arg1) ({
  ^bb0(%arg2: tensor<i1>, %arg3: tensor<i64>):
    "tf.Yield"(%arg2) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg2: tensor<i1>, %arg3: tensor<i64>):
    %2 = "tf.IteratorGetNextAsOptional"(%0) {output_shapes = [#tf_type.shape<32>], output_types = [f32]} : (tensor<!tf_type.resource>) -> tensor<!tf_type.variant>
    %3 = "tf.OptionalHasValue"(%2) : (tensor<!tf_type.variant>) -> tensor<i1>
    %4 = "tf.IfRegion"(%3) ({
      %5 = "tf.OptionalGetValue"(%2) : (tensor<!tf_type.variant>) -> tensor<32xf32>
      %6 = func.call @__reduce_func_1(%arg3, %5) {_xla_compile_device_type = "TPU"} : (tensor<i64>, tensor<32xf32>) -> tensor<i64>
      "tf.Yield"(%6) : (tensor<i64>) -> ()
    }, {
      "tf.Yield"(%arg3) : (tensor<i64>) -> ()
    }) {_lower_using_switch_merge = true, is_stateless = false} : (tensor<i1>) -> tensor<i64>
    "tf.Yield"(%3, %4) : (tensor<i1>, tensor<i64>) -> ()
  }) {_lower_using_switch_merge = true, is_stateless = false, parallel_iterations = 10 : i64} : (tensor<i1>, tensor<i64>) -> (tensor<i1>, tensor<i64>)
  return
}
```
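The control flow of the transformed function can be mirrored in plain Python, which may make the `tf.WhileRegion`/`tf.IfRegion` nesting easier to follow. This is a hypothetical sketch: a Python iterator stands in for the dataset iterator resource, and a sentinel object stands in for an empty optional. Each comment names the op that the line imitates.

```python
from typing import Callable, Iterable, TypeVar

S = TypeVar("S")  # reduction state, e.g. the tensor<i64> accumulator
E = TypeVar("E")  # dataset element, e.g. a tensor<32xf32>

_NO_VALUE = object()  # stands in for an empty optional


def decomposed_reduce(dataset: Iterable[E], state: S,
                      reduce_fn: Callable[[S, E], S]) -> S:
    it = iter(dataset)                        # tf.AnonymousIteratorV3 + tf.MakeIterator
    has_value = True                          # the constant `true` fed to tf.WhileRegion
    while has_value:                          # loop condition yields the has-value flag
        element = next(it, _NO_VALUE)         # tf.IteratorGetNextAsOptional
        has_value = element is not _NO_VALUE  # tf.OptionalHasValue
        if has_value:                         # tf.IfRegion "then" branch
            state = reduce_fn(state, element) # tf.OptionalGetValue + call to reduce fn
        # "else" branch: the state is yielded unchanged
    return state


# Usage: sum a toy "dataset" of integers starting from state 0.
print(decomposed_reduce([1, 2, 3], 0, lambda acc, x: acc + x))
```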
### `-tf-device-assignment-by-func-attr`: Device assignment in TF dialect using the device specified in the function attribute. {: .hide-from-toc }
### `-tf-device-cluster-formation`: Form clusters from instructions assigned to the same device {: .hide-from-toc }
Clusters operations with the same device assignment (the `device` attribute).
For each cluster, creates a `tf_device.launch` op with a region containing the
ops in that cluster and replaces the ops with the new launch op.

For example, given the following program:

```mlir
  %2 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  %3 = "tf.B"(%2) {device = "tpu0"} : (tensor<?xi32>) -> tensor<?xi32>
  %4 = "tf.C"(%2, %3) {device = "tpu0"} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
  %5 = "tf.D"(%4) : (tensor<?xi32>) -> tensor<?xi32>
```

After the pass, we will have:

```mlir
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  %1 = "tf_device.launch"() ( {
    %3 = "tf.B"(%0) : (tensor<?xi32>) -> tensor<?xi32>
    %4 = "tf.C"(%0, %3) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
    tf_device.return %4 : tensor<?xi32>
  }) {device = "tpu0"} : () -> tensor<?xi32>
  %2 = "tf.D"(%1) : (tensor<?xi32>) -> tensor<?xi32>
  return %2 : tensor<?xi32>
```
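The grouping step can be sketched in Python. This is a deliberate simplification: `Op`, `Launch`, and `form_clusters` are hypothetical names, and the sketch groups only *adjacent* ops that share a device. The real pass also has to handle non-adjacent ops and rewire the uses of the launch results, and the sketch does neither.

```python
from dataclasses import dataclass, field
from itertools import groupby
from typing import List, Optional, Union


@dataclass
class Op:
    """Hypothetical stand-in for an op: its name and optional device attribute."""
    name: str
    device: Optional[str] = None


@dataclass
class Launch:
    """Hypothetical stand-in for a tf_device.launch region."""
    device: str
    body: List[Op] = field(default_factory=list)


def form_clusters(ops: List[Op]) -> List[Union[Op, Launch]]:
    """Group runs of adjacent ops that share a device into Launch regions."""
    result: List[Union[Op, Launch]] = []
    for device, run in groupby(ops, key=lambda op: op.device):
        members = list(run)
        if device is None:
            result.extend(members)  # unassigned ops stay outside any launch
        else:
            # The launch carries the device; the wrapped ops drop their own.
            result.append(Launch(device, [Op(op.name) for op in members]))
    return result


# Usage: the tf.A / tf.B / tf.C / tf.D program from the example.
program = [Op("tf.A"), Op("tf.B", "tpu0"), Op("tf.C", "tpu0"), Op("tf.D")]
for item in form_clusters(program):
    print(item)
```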
### `-tf-device-cluster-outlining`: Outlines regions of tf_device.cluster operations {: .hide-from-toc }
This pass outlines the body of a `tf_device.cluster` into a function and
replaces the `tf_device.cluster` op with an equivalent `tf_device.cluster_func`
op. Implicit operands will be captured and materialized as explicit arguments to
the newly created functions and associated `tf_device.cluster_func` ops.

For example, the following:

```mlir
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  %cluster = "tf_device.cluster"() ( {
    %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) : () -> (tensor<i32>)
  return %cluster : tensor<i32>
}
```

will be transformed into:

```mlir
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  %cluster = "tf_device.cluster_func"(%arg0) {func = @_func} : (tensor<i32>) -> tensor<i32>
  return %cluster : tensor<i32>
}

func @_func(%arg0: tensor<i32>) -> tensor<i32> {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  return %identity : tensor<i32>
}
```
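Capturing implicit operands amounts to collecting the values that the region uses but does not define itself. Below is a minimal Python sketch. It uses a hypothetical `(result, operands)` tuple per op and ignores nested regions, multi-result ops, and types.

```python
from typing import List, Tuple

# Hypothetical representation: (result name, operand names) for each op in a region.
RegionOp = Tuple[str, List[str]]


def implicit_captures(region: List[RegionOp]) -> List[str]:
    """Return values used inside the region but defined outside it, in first-use order."""
    defined = set()
    captured: List[str] = []
    for result, operands in region:
        for value in operands:
            if value not in defined and value not in captured:
                captured.append(value)  # becomes an explicit argument of the outlined function
        defined.add(result)
    return captured


# Usage: the tf.Identity cluster from the example captures %arg0.
print(implicit_captures([("%identity", ["%arg0"])]))
```

In the outlined function, each captured value becomes a parameter, in that order, and the `tf_device.cluster_func` op passes the captured values as its operands.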
### `-tf-device-constant-sinking`: Sinks constants implicitly captured in a tf_device.cluster region. {: .hide-from-toc }
This pass sinks implicitly captured constants (`tf.Const` ops) that are used by a
`tf_device.cluster` region into that region. Performing this prior to outlining
reduces the number of arguments of the outlined function.

For example, the following:

```mlir
func @cluster() -> tensor<i32> {
  %const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %cluster = "tf_device.cluster"() ( {
    %identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) : () -> (tensor<i32>)
  return %cluster : tensor<i32>
}
```

will be transformed into:

```mlir
func @cluster() -> tensor<i32> {
  %cluster = "tf_device.cluster"() ( {
    %const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
    %identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) : () -> (tensor<i32>)
  return %cluster : tensor<i32>
}
```
### `-tf-device-convert-launch-func-to-tf-call`: Rewrites tf_device::LaunchFuncOp to TF::PartitionedCallOp {: .hide-from-toc }
This pass converts tf_device::LaunchFuncOp into an equivalent
TF::PartitionedCallOp so that it can be exported to TensorFlow GraphDef.
### `-tf-device-index-selector`: Fold tf.DeviceIndex to constant. {: .hide-from-toc }
### `-tf-device-launch-outlining`: Outlines regions of tf_device.launch operations {: .hide-from-toc }
This pass outlines the body of a `tf_device.launch` into a function and
replaces the `tf_device.launch` op with an equivalent `tf_device.launch_func`
op. Implicit operands will be captured and materialized as explicit arguments to
the newly created functions and associated `tf_device.launch_func` ops. The
`device` attribute from the `launch` op is transferred to `launch_func`.

For example, the following:

```mlir
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  %launch = "tf_device.launch"() ( {
    %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) {device = "some_device"} : () -> (tensor<i32>)
  return %launch : tensor<i32>
}
```

will be transformed into:

```mlir
func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  %launch = "tf_device.launch_func"(%arg0) {device = "some_device", func = @_func} : (tensor<i32>) -> tensor<i32>
  return %launch : tensor<i32>
}

func @_func(%arg0: tensor<i32>) -> tensor<i32> {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  return %identity : tensor<i32>
}
```
### `-tf-device-mark-input-output-aliases`: Marks device cluster inputs-output pairs that read/write to the same variable as aliases {: .hide-from-toc }
This pass analyzes the inputs and outputs of a device cluster and marks
input-output pairs that read from and write to the same resource as aliases
(using the `tf.aliasing_output` attribute). This aliasing information can then
be propagated to the XLA compiler for input/output buffer space optimizations.
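Conceptually, the pairing resembles the following Python sketch. Resources are represented by hypothetical string IDs, whereas the real pass relies on resource analysis of the cluster's operands and results. The returned map from input index to output index plays the role of the `tf.aliasing_output` annotation in this sketch. The rule that each input aliases at most one output is also an assumption of the sketch.

```python
from typing import Dict, List, Optional


def mark_aliases(inputs: List[Optional[str]],
                 outputs: List[Optional[str]]) -> Dict[int, int]:
    """Map input index -> aliased output index for pairs on the same resource.

    Entries are hypothetical resource IDs, or None for non-resource values.
    """
    first_input: Dict[str, int] = {}
    for i, res in enumerate(inputs):
        if res is not None:
            first_input.setdefault(res, i)
    aliases: Dict[int, int] = {}
    for j, res in enumerate(outputs):
        if res is not None and res in first_input:
            i = first_input[res]
            if i not in aliases:  # in this sketch, an input aliases at most one output
                aliases[i] = j
    return aliases


# Usage: var_a is read at input 0 and written at output 2, and so on.
print(mark_aliases(["var_a", None, "var_b"], ["var_b", None, "var_a"]))
```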
### `-tf-drop-while-shape-invariant`: Drop `shape_invariant` attribute from While/WhileRegion ops. {: .hide-from-toc }
Drops the `shape_invariant` attribute from `tf.While` and `tf.WhileRegion` ops.
This allows the shape inference pass to further refine the operand/result
shapes of these ops. This is only safe when compiling to XLA.
### `-tf-drop-while-shape-invariant-in-device-cluster`: Drop `shape_invariant` attribute from While/WhileRegion ops inside device cluster. {: .hide-from-toc }
Drops the `shape_invariant` attribute from `tf.While` and `tf.WhileRegion` ops,
but only inside device clusters. This allows the shape inference pass to
further refine the operand/result shapes of these ops. This is only safe when
compiling to XLA.
### `-tf-einsum`: Transform Einsum to other TF Ops for the supported variants {: .hide-from-toc }
### `-tf-executor-break-up-islands`: Transform from TF control dialect to TF executor dialect. {: .hide-from-toc }
### `-tf-executor-check-control-dependencies`: Checks control dependencies {: .hide-from-toc }
This pass analyzes control dependencies between islands and warns about
dependencies that are not explainable by side effects of the involved ops.
More precisely, for every minimal unexplainable control dependency path
we emit op warnings for all involved ops. The pass does not report
intermediate dummy ops for grouping control dependencies (Identity, NoOp),
unless they are part of an unexplainable path between other ops.
This pass is useful for understanding how conservative the control
dependencies in a given MLIR module are.

For example, the following function
```mlir
func.func @path_with_intermediate_ops(
  %arg0: tensor<!tf_type.resource<tensor<f32>>>,
  %arg1: tensor<!tf_type.resource<tensor<f32>>>,
  %arg2: tensor<f32>) -> () {
  tf_executor.graph {
    %island1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
    %island2 = tf_executor.island(%island1) wraps "tf.NoOp"() : () -> ()
    %island3 = tf_executor.island(%island2) wraps "tf.NoOp"() : () -> ()
    %island4 = tf_executor.island(%island3) wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
    tf_executor.fetch
  }
  func.return
}
```
produces the following warnings
```mlir
  6:45: warning: unexpected control dependency path: path 0, node 0 (source)
  %island1 = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
                                      ^
  6:45: note: see current operation: %control = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
  7:55: warning: unexpected control dependency path: path 0, node 1 (intermediate)
  %island2 = tf_executor.island(%island1) wraps "tf.NoOp"() : () -> ()
                                                ^
  7:55: note: see current operation: %control_0 = tf_executor.island(%control) wraps "tf.NoOp"() : () -> ()
  8:55: warning: unexpected control dependency path: path 0, node 2 (intermediate)
  %island3 = tf_executor.island(%island2) wraps "tf.NoOp"() : () -> ()
                                                ^
  8:55: note: see current operation: %control_1 = tf_executor.island(%control_0) wraps "tf.NoOp"() : () -> ()
  9:55: warning: unexpected control dependency path: path 0, node 3 (target)
  %island4 = tf_executor.island(%island3) wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
                                                ^
  9:55: note: see current operation: %control_2 = tf_executor.island(%control_1) wraps "tf.AssignVariableOp"(%arg1, %arg2) : (tensor<!tf_type.resource<tensor<f32>>>, tensor<f32>) -> ()
```
because the first and last `AssignVariableOp`s access different resources
and therefore should be independent. Note that the `NoOp`s are considered
as intermediate ops for control dependency grouping.
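Whether a control dependency is "explainable" can be sketched as a conflict check between the side effects of the source and target ops. This is a simplified, hypothetical model in which effects are sets of resource IDs that an op reads or writes. TensorFlow's side effect analysis additionally handles unknown resources and other kinds of effects, and this sketch covers none of that.

```python
from dataclasses import dataclass
from typing import FrozenSet


@dataclass(frozen=True)
class Effects:
    """Hypothetical summary of the resources an op reads and writes."""
    reads: FrozenSet[str] = frozenset()
    writes: FrozenSet[str] = frozenset()


def dependency_explained(src: Effects, dst: Effects) -> bool:
    """True if ordering src before dst is justified by conflicting effects."""
    write_conflict = src.writes & (dst.reads | dst.writes)  # write-read or write-write
    read_then_write = src.reads & dst.writes                # read-write
    return bool(write_conflict or read_then_write)


# Usage: the two AssignVariableOps in the example write different resources.
assign_arg0 = Effects(writes=frozenset({"arg0"}))
assign_arg1 = Effects(writes=frozenset({"arg1"}))
print(dependency_explained(assign_arg0, assign_arg1))
```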
### `-tf-executor-convert-control-to-data-outputs`: Chain control outputs of while loop body {: .hide-from-toc }
This pass converts the control outputs of a while loop body function to data
outputs, so inter-iteration control dependencies become data dependencies.
Because data dependencies express which particular operations in the while
loop body depend on which inputs, they capture inter-iteration parallelism in
the while loop. Control dependencies, on the other hand, create a barrier at
the end of the while loop body, which blocks any parallelism across iterations.

For example, the following while loop body has a `%barrier` at the end.
Although there is no data or control dependency between the `tf.AssignVariableOp`
for `%arg0` and the `tf.AssignVariableOp` for `%arg1` across any iteration, the
control barrier (`%barrier`) at the end of the while loop body forces a
dependency: the two assign variable ops must wait for each other to complete
before the next iteration starts. Transforming these control outputs to data
outputs removes the dependency between the two assign variable ops, so they
can run in parallel across iterations.

Before:

```mlir
!tf_res = type tensor<!tf_type.resource<tensor<f32>>>
func @while_body(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor<f32>, %arg3: tensor<f32>) -> (!tf_res, !tf_res, tensor<f32>, tensor<f32>) {
  %graph:4 = tf_executor.graph {
    %assign_0_control = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg2) : (!tf_res, tensor<f32>) -> ()
    %assign_1_control = tf_executor.island wraps "tf.AssignVariableOp"(%arg1, %arg3) : (!tf_res, tensor<f32>) -> ()
    %add_out, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %mul_out, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %barrier = tf_executor.island(%assign_0_control, %assign_1_control, %add_control, %mul_control) wraps "tf.NoOp"() : () -> ()
    tf_executor.fetch %arg0, %arg1, %add_out, %mul_out, %barrier : !tf_res, !tf_res, tensor<f32>, tensor<f32>, !tf_executor.control
  }
  return %graph#0, %graph#1, %graph#2, %graph#3 : !tf_res, !tf_res, tensor<f32>, tensor<f32>
}
```

After:

```mlir
func @while_body(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor<f32>, %arg3: tensor<f32>, %chain_0: tensor<i32>, %chain_1: tensor<i32>) -> (!tf_res, !tf_res, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i32>) {
  %graph:6 = tf_executor.graph {
    %_, %chain_0_src = tf_executor.island wraps "tf.Identity"(%chain_0) : (tensor<i32>) -> tensor<i32>
    %_, %chain_1_src = tf_executor.island wraps "tf.Identity"(%chain_1) : (tensor<i32>) -> tensor<i32>
    %assign_0_control = tf_executor.island(%chain_0_src) wraps "tf.AssignVariableOp"(%arg0, %arg2) : (!tf_res, tensor<f32>) -> ()
    %assign_1_control = tf_executor.island(%chain_1_src) wraps "tf.AssignVariableOp"(%arg1, %arg3) : (!tf_res, tensor<f32>) -> ()
    %add_out, %add_control = tf_executor.island wraps "tf.Add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %mul_out, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %chain_0_sink, %_ = tf_executor.island(%assign_0_control) wraps "tf.Identity"(%chain_0) : (tensor<i32>) -> tensor<i32>
    %chain_1_sink, %_ = tf_executor.island(%assign_1_control) wraps "tf.Identity"(%chain_1) : (tensor<i32>) -> tensor<i32>
    tf_executor.fetch %arg0, %arg1, %add_out, %mul_out, %chain_0_sink, %chain_1_sink : !tf_res, !tf_res, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i32>
  }
  return %graph#0, %graph#1, %graph#2, %graph#3, %graph#4, %graph#5 : !tf_res, !tf_res, tensor<f32>, tensor<f32>, tensor<i32>, tensor<i32>
}
```
### `-tf-executor-graph-pruning`: Prunes unreachable ops in a tf_executor.graph {: .hide-from-toc }
This pass removes ops from a `tf_executor.graph` that are not transitively, via
data or control dependencies, connected to the associated `tf_executor.fetch`
op. The order of ops will be preserved. Functions named `main` with no
`tf.entry_function` attribute will not be pruned, as such graphs/functions may
have been imported from a V1 TensorFlow graph, where feeds/fetches/targets are
not provided at certain stages of IR transformation (e.g. pre-placement).

The `ops-to-preserve` option specifies ops that should not be pruned,
regardless of their reachability.

For example, the following:

```mlir
func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  %graph = tf_executor.graph {
    %transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    %reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
    %unreachable_data:2 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
    %transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
    %reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
    %unreachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
    tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
  }
  return %graph : tensor<i32>
}
```

will be transformed into:

```mlir
func @graph(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  %graph = tf_executor.graph {
    %transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    %reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor<i32>) -> tensor<i32>
    %transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> ()
    %reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> ()
    tf_executor.fetch %reachable_data#0, %reachable_control : tensor<i32>, !tf_executor.control
  }
  return %graph : tensor<i32>
}
```

#### Options {: .hide-from-toc }
```
-ops-to-preserve : Comma separated list of ops that should not be pruned regardless of reachability
```
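At its core, the pruning is a backward reachability walk from the fetch. Here is a minimal Python sketch. Graph nodes are hypothetical string names, and a single adjacency map encodes both data and control edges. The sketch omits the `ops-to-preserve` option and the exemption for `main` functions without `tf.entry_function`.

```python
from collections import deque
from typing import Dict, Iterable, List


def prune(ops: List[str], deps: Dict[str, List[str]],
          fetched: Iterable[str]) -> List[str]:
    """Keep only ops that the fetch reaches via data or control edges.

    ops: op names in graph order; deps: op -> ops it depends on (hypothetical
    adjacency-list encoding of both data and control edges).
    """
    live = set()
    work = deque(fetched)
    while work:  # backward breadth-first walk starting at the fetch
        op = work.popleft()
        if op not in live:
            live.add(op)
            work.extend(deps.get(op, []))
    return [op for op in ops if op in live]  # original order is preserved


# Usage: the graph from the example above.
ops = ["transitive_reachable_data", "reachable_data", "unreachable_data",
       "transitive_reachable_control", "reachable_control", "unreachable_control"]
deps = {"reachable_data": ["transitive_reachable_data"],
        "reachable_control": ["transitive_reachable_control"]}
print(prune(ops, deps, ["reachable_data", "reachable_control"]))
```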
### `-tf-executor-island-coarsening`: Walks tf_executor::GraphOp and merges individual tf_executor::IslandOps. {: .hide-from-toc }
This pass performs whole graph analysis for a graph encapsulated into tf_executor::GraphOp.
The analysis identifies all IslandOps within the graph which could be merged together.
The goal is to merge as many islands as possible.
Once analysis is completed, the pass merges all IslandOps in a single scan.

For example, given the following program with two separate islands:

```mlir
  func @test(%arg0 : tensor<i1>) -> tensor<f32> {
    %0 = tf_executor.graph {
      %1:2 = tf_executor.island {
        %3 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
        tf_executor.yield %3 : tensor<i1>
      }
      %2:2 = tf_executor.island(%1#1) {
        %4 = "tf.opB"() : () -> tensor<f32>
        tf_executor.yield %4 : tensor<f32>
      }
      tf_executor.fetch %2#0 : tensor<f32>
    }
    return %0 : tensor<f32>
  }
```

After running this pass, the two islands are merged:

```mlir
  func @test(%arg0: tensor<i1>) -> tensor<f32> {
    %0 = tf_executor.graph {
      %outputs, %control = tf_executor.island {
        %1 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
        %2 = "tf.opB"() : () -> tensor<f32>
        tf_executor.yield %2 : tensor<f32>
      }
      tf_executor.fetch %outputs : tensor<f32>
    }
    return %0 : tensor<f32>
  }
```
### `-tf-executor-split-into-island-per-op`: Transform from TF control dialect to TF executor dialect. {: .hide-from-toc }
Splits an island with multiple ops into multiple islands (one per op). Does
not create any control dependencies between new islands, and does not
propagate control dependencies that potentially existed between the old
islands into the new islands. Maintains existing data dependencies between
ops wrapped by the new islands.

Example: original program:

```mlir
    func.func @dangling_print(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32>, tensor<*xi32>) {
      %graph:2 = tf_executor.graph {
        %island1:3 = tf_executor.island {
          %add1 = "tf.Add"(%arg0, %arg1) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
          %add2 = "tf.Add"(%add1, %arg1) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
          %res = "tf.Print"(%add2) { message = "add result" } : (tensor<*xi32>) -> (tensor<*xi32>)
          tf_executor.yield %add1, %add2 : tensor<*xi32>, tensor<*xi32>
        }
        tf_executor.fetch %island1#0, %island1#1 : tensor<*xi32>, tensor<*xi32>
      }
      func.return %graph#0, %graph#1 : tensor<*xi32>, tensor<*xi32>
    }
```

will be converted by this pass into:

```mlir
    func.func @dangling_print(%arg0: tensor<*xi32>, %arg1: tensor<i32>) -> (tensor<*xi32>, tensor<*xi32>) {
      %0:2 = tf_executor.graph {
        %outputs, %control = tf_executor.island wraps "tf.Add"(%arg0, %arg1) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
        %outputs_0, %control_1 = tf_executor.island wraps "tf.Add"(%outputs, %arg1) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
        %outputs_2, %control_3 = tf_executor.island wraps "tf.Print"(%outputs_0) {message = "add result"} : (tensor<*xi32>) -> tensor<*xi32>
        tf_executor.fetch %outputs, %outputs_0 : tensor<*xi32>, tensor<*xi32>
      }
      return %0#0, %0#1 : tensor<*xi32>, tensor<*xi32>
    }
```
### `-tf-executor-to-functional-conversion`: Lifts tf_executor.island inner ops from a tf_executor.graph {: .hide-from-toc }
This pass converts tf_executor.graphs consisting of only tf_executor.islands and
a tf_executor.fetch into a sea of nodes consisting of TensorFlow Dialect ops by
lifting such ops out of a tf_executor.graph's tf_executor.islands. If V1 control
flow ops are present in a tf_executor.graph, an error will be returned.

For example, the following:

```mlir
func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  %graph_results:2 = tf_executor.graph {
    %island_0_result, %island_0_control = tf_executor.island {
      %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
      tf_executor.yield %identity : tensor<i32>
    }
    %island_1_result, %island_1_control = tf_executor.island {
      %identity_n:2 = "tf.IdentityN"(%arg1, %island_0_result) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
      tf_executor.yield %identity_n#0 : tensor<i32>
    }
    tf_executor.fetch %island_0_result, %island_1_result : tensor<i32>, tensor<i32>
  }
  return %graph_results#0, %graph_results#1 : tensor<i32>, tensor<i32>
}
```

will be transformed into:

```mlir
func @my_fn(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  %identity_n:2 = "tf.IdentityN"(%arg1, %identity) : (tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
  return %identity, %identity_n#0 : tensor<i32>, tensor<i32>
}
```
### `-tf-executor-tpu-v1-island-coarsening`: Merges TPU clusters IslandOps, intended for V1 compatibility mode {: .hide-from-toc }
This pass is a variant of ExecutorIslandCoarseningPass that is limited to
TPU-annotated operations and intended to preserve backward compatibility with
TFv1.
### `-tf-executor-tpu-v1-island-inlining`: Inline calls to the nested TPU module. {: .hide-from-toc }
This pass inlines the islands calling into the nested module that was
outlined, thus reversing the effect of the
`-tf-executor-tpu-v1-island-outlining` pass.

For example, the following:
```mlir
module {
  func @foo(%arg0: tensor<f32>) -> tensor<f32> {
    %0 = tf_executor.graph {
      %outputs, %control = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {f = @_tpu_v1_compat_outlined::@bar} : (tensor<f32>) -> tensor<f32>
      tf_executor.fetch %outputs : tensor<f32>
    }
    return %0 : tensor<f32>
  }
  module @_tpu_v1_compat_outlined {
    func nested @bar(%arg0: tensor<f32>) -> tensor<f32> {
      %0 = "tf.opA"(%arg0) : (tensor<f32>) -> tensor<f32>
      return %0 : tensor<f32>
    }
  }
}
```

will be transformed into:

```mlir
module  {
  func @foo(%arg0: tensor<f32>) -> tensor<f32> {
    %0 = tf_executor.graph {
      %outputs, %control = tf_executor.island {
        %1 = "tf.opA"(%arg0) : (tensor<f32>) -> tensor<f32>
        tf_executor.yield %1 : tensor<f32>
      }
      tf_executor.fetch %outputs : tensor<f32>
    }
    return %0 : tensor<f32>
  }
}
```
### `-tf-executor-tpu-v1-island-outlining`: Outline TPU clusters from island into a nested module, so it can be processed like a V2 module, intended for V1 compatibility mode {: .hide-from-toc }
Extract the islands containing a TPU cluster computation into an outlined
function in a nested module. This allows the usual bridge to run on this
nested module, which now exhibits a more friendly "V2-like" structure.
This is only intended for V1 compatibility mode, where the bridge runs without
feeds/fetches on session create/extend.

For example, given:

```mlir
  func @test() -> tensor<i32> {
    %0 = tf_executor.graph {
      %output, %control = tf_executor.island {
        ...
        tf_executor.yield %result : tensor<i32>
      }
      tf_executor.fetch %output : tensor<i32>
    }
    return %0 : tensor<i32>
  }
```

This pass will create an additional function containing the code in
tf_executor.island:

```mlir
  func nested @_tpu_v1_compat_outlined_func0() -> tensor<i32> {
    ...
  }
```

and will then replace the island with the wrapped call:

```mlir
  func @test() -> tensor<i32> {
    %0 = tf_executor.graph {
      %outputs, %control = tf_executor.island wraps "tf.PartitionedCall"() {
          f = @_tpu_v1_compat_outlined::@_tpu_v1_compat_outlined_func0
      } : () -> tensor<i32>
      tf_executor.fetch %outputs : tensor<i32>
    }
    return %0 : tensor<i32>
  }
```
### `-tf-executor-update-control-dependencies`: Computes and applies all necessary control dependencies based on side effect analysis. {: .hide-from-toc }
This pass is intended to run after the `-tf-executor-split-into-island-per-op`
pass. That pass splits multi-op islands into multiple islands, each wrapping a
single op, without adding any control dependencies between the new islands.
This pass is therefore needed to make the ordering relationships between ops,
as determined by side effect analysis, explicit in the IR so that they are
preserved.

Example: original program:

```mlir
    func.func @example(%arg0: tensor<*x!tf_type.resource<tensor<32xf32>>>, %arg1: tensor<32xf32>) -> (tensor<32xf32>) {
      %graph = tf_executor.graph {
        %read0, %read0_control = tf_executor.island wraps "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>) -> tensor<32xf32>
        %assign0_control = tf_executor.island wraps "tf.AssignVariableOp"(%arg0, %arg1) : (tensor<*x!tf_type.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
        %read1, %read1_control = tf_executor.island wraps "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>) -> tensor<32xf32>
        %print, %print_control = tf_executor.island wraps "tf.Print"(%read1) { message = "read1 value" } : (tensor<32xf32>) -> (tensor<32xf32>)
        tf_executor.fetch %read1 : tensor<32xf32>
      }
      func.return %graph : tensor<32xf32>
    }
```

will be converted by this pass into:

```mlir
    func.func @example(%arg0: tensor<*x!tf_type.resource<tensor<32xf32>>>, %arg1: tensor<32xf32>) -> tensor<32xf32> {
      %0 = tf_executor.graph {
        %read0, %read0_control = tf_executor.island wraps "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>) -> tensor<32xf32>
        %assign0_control = tf_executor.island(%read0_control) wraps "tf.AssignVariableOp"(%arg0, %arg1) : (tensor<*x!tf_type.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
        %read1, %read1_control = tf_executor.island(%assign0_control) wraps "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource<tensor<32xf32>>>) -> tensor<32xf32>
        %print, %print_control = tf_executor.island(%read1_control) wraps "tf.Print"(%read1) {message = "read1 value"} : (tensor<32xf32>) -> tensor<32xf32>
        tf_executor.fetch %read1, %print_control : tensor<32xf32>, !tf_executor.control
      }
      return %0 : tensor<32xf32>
    }
```
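The ordering rule can be illustrated with a simplified Python model. It assumes each op is described only by the resources it reads and writes, which is a hypothetical representation: the actual pass relies on TensorFlow's side effect analysis and also handles other kinds of effects, such as the unknown side effects of `tf.Print`. Under that assumption, every access must follow the latest write to the same resource, and a write must also follow every read since that write. This reproduces the `read0`, `assign0`, `read1` chain in the example above.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical model: (op name, resources read, resources written).
Op = Tuple[str, Set[str], Set[str]]


def control_deps(ops: List[Op]) -> Dict[str, List[str]]:
    """Return, for each op, the earlier ops it must wait for."""
    last_writer: Dict[str, str] = {}
    readers_since_write: Dict[str, List[str]] = {}
    deps: Dict[str, List[str]] = {}
    for name, reads, writes in ops:
        preds: List[str] = []
        for res in reads | writes:
            # Any access must follow the most recent write (RAW / WAW).
            if res in last_writer:
                preds.append(last_writer[res])
        for res in writes:
            # A write must also follow every read since the last write (WAR).
            preds.extend(readers_since_write.get(res, []))
        deps[name] = sorted(set(preds))
        for res in reads:
            readers_since_write.setdefault(res, []).append(name)
        for res in writes:
            last_writer[res] = name
            readers_since_write[res] = []
    return deps


ops = [
    ("read0", {"v"}, set()),
    ("assign0", set(), {"v"}),
    ("read1", {"v"}, set()),
]
print(control_deps(ops))
```

Two reads of the same resource get no edge between them, because reads commute. Only writes serialize accesses.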
### `-tf-extract-head-tail-outside-compilation`: Extracts head or tail outside compilation to separate host launches before/after device cluster. {: .hide-from-toc }
This pass extracts a CPU computation cluster with `_xla_outside_compilation`
annotation from the head or tail of a device cluster.

For example:

```mlir
  %cluster = "tf_device.cluster"() ( {
    %a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
    %b = "tf.B"(%a) : (tensor<i32>) -> tensor<i32>
    %c = "tf.C"(%b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
    tf_device.return %c : tensor<i32>
  }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor<i32>
  return %cluster : tensor<i32>
```

becomes:

```mlir
%0 = "tf_device.launch"() ( {
  %3 = "tf.A"(%arg0) : (tensor<i32>) -> tensor<i32>
  tf_device.return %3 : tensor<i32>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> tensor<i32>
%1 = "tf_device.cluster"() ( {
  %3 = "tf.B"(%0) : (tensor<i32>) -> tensor<i32>
  tf_device.return %3 : tensor<i32>
}) {device_assignment = [], num_cores_per_replica = 1 : i64, padding_map = [], step_marker_location = "", topology = ""} : () -> tensor<i32>
%2 = "tf_device.launch"() ( {
  %3 = "tf.C"(%1) : (tensor<i32>) -> tensor<i32>
  tf_device.return %3 : tensor<i32>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> tensor<i32>
return %2 : tensor<i32>
```
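The head/tail split can be sketched as follows. The model is simplified: the head is taken to be the maximal leading run of ops carrying `_xla_outside_compilation`, and the tail is the maximal trailing run. This is an assumption for illustration only. The actual pass must also respect data dependencies between the moved ops and the rest of the cluster, and this sketch ignores them.

```python
from typing import List, Tuple

# Hypothetical model: (op name, has _xla_outside_compilation).
Op = Tuple[str, bool]


def split_head_tail(ops: List[Op]) -> Tuple[List[str], List[str], List[str]]:
    """Split a cluster body into (head host ops, device ops, tail host ops)."""
    start = 0
    while start < len(ops) and ops[start][1]:
        start += 1  # extend the head while ops are marked for the host
    end = len(ops)
    while end > start and ops[end - 1][1]:
        end -= 1  # extend the tail backwards while ops are marked
    names = [name for name, _ in ops]
    return names[:start], names[start:end], names[end:]


print(split_head_tail([("tf.A", True), ("tf.B", False), ("tf.C", True)]))
```

Applied to the example, `tf.A` goes to the launch before the cluster, `tf.B` stays on the device, and `tf.C` goes to the launch after it.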
### `-tf-extract-outside-compilation`: Extracts device outside compilation computation to a separate tf_device.parallel_execute region. {: .hide-from-toc }
This pass extracts a CPU computation cluster with `_xla_outside_compilation`
annotation, which denotes ops that should be run on CPU/host, from a device cluster.
Each outside compilation cluster is moved to
a tf_device.parallel_execute region. The device cluster is also moved to a
tf_device.parallel_execute region. Communication ops between device and host are
added to pass inputs/outputs to/from the outside compiled region.

For example, the following `tf_device.cluster`, in which `tf.Const` and `tf.Identity` are marked with `_xla_outside_compilation`:

```mlir
func @outside_compilation() -> tensor<f32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor<f32>} : () -> (tensor<f32>)
    %2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor<f32>) -> (tensor<f32>)
    %3 = "tf.AddV2"(%1, %2) : (tensor<f32>, tensor<f32>) -> (tensor<f32>)
    tf_device.return %3 : tensor<f32>
  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<f32>
  return %0 : tensor<f32>
}
```

will become a tf_device.parallel_execute op with a CPU/host region and
a tf_device.cluster with communication ops to send data to/from device/host:

```mlir
func @outside_compilation() -> tensor<f32> {
  %0 = "tf_device.parallel_execute"() ( {
    "tf_device.launch"() ( {
      %1 = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf_type.string>
      %2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf_type.string>) -> tensor<f32>
      %3 = "tf.Identity"(%2) : (tensor<f32>) -> tensor<f32>
      "tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor<f32>, tensor<3x!tf_type.string>) -> ()
      tf_device.return
    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
    tf_device.return
  },  {
    %1 = "tf_device.cluster"() ( {
      %2 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
      %3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor<f32>) -> tensor<f32>
      %4 = "tf.AddV2"(%2, %3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
      tf_device.return %4 : tensor<f32>
    }) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<f32>
    tf_device.return %1 : tensor<f32>
  }) : () -> tensor<f32>
  return %0 : tensor<f32>
}
```
### `-tf-extract-tpu-copy-with-dynamic-shape-op`: Extract the TPUCopyWithDynamicShapeOp out of the host launch and place it on device launch {: .hide-from-toc }
This pass looks for `TPUCopyWithDynamicShapeOp` ops that are wrapped in a
`tf_device.launch` with a host device attribute. It extracts these ops and
wraps them in a `tf_device.launch` with a TPU device attribute, so that they
run on the TPU instead of the CPU while still being compiled on the host.
### `-tf-functional-control-flow-to-cfg`: Transform functional control flow Ops to MLIR Control Flow Graph (CFG) form {: .hide-from-toc }
### `-tf-functional-control-flow-to-regions`: Transforms functional control flow operations to their region-based counterparts {: .hide-from-toc }
This pass transforms functional control flow operations in the TensorFlow
dialect to their region-based counterparts, i.e., `tf.If` is transformed to
`tf.IfRegion` and `tf.While` is transformed to `tf.WhileRegion`.

For example, this functional operation

```mlir
  %0 = "tf.If"(%arg0, %arg1) {
    then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false
  } : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
```

will be transformed into this region-based operation

```mlir
    %0 = "tf.IfRegion"(%arg0) ( {
      %1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
      "tf.Yield"(%1) : (tensor<*xf32>) -> ()
    },  {
      %1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
      "tf.Yield"(%1) : (tensor<*xf32>) -> ()
    }) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
```
### `-tf-functional-to-executor-conversion`: Transform from func op to TF executor dialect. {: .hide-from-toc }
### `-tf-fused-kernel-matcher`: Matches computations corresponding to optimized fused kernels {: .hide-from-toc }
### `-tf-gpu-op-fusion`: Fusion optimization for GPU targets {: .hide-from-toc }
This pass performs fusion specific to GPU targets. It is currently an ad-hoc
pass, but it should eventually be integrated with some notion of a "target" in
the MLIR pipeline.
### `-tf-group-by-dialect`: Groups ops into functions that only contain one dialect. {: .hide-from-toc }
Factors operations into subroutines such that all functions only
contain a single dialect. Which of the dialects are allowed in the
"top" function is configurable.

For example, the code

```mlir
x.a()
x.b()
%c = y.c()
x.d(%c)
```

would be transformed into something like

```mlir
call @x_1()
%c = call @y_1()
call @x_2(%c)
```

with the bodies of `@x_1`, `@x_2`, and `@y_1` filled in accordingly.
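The grouping step amounts to splitting a straight-line op sequence into maximal single-dialect runs. The Python sketch below assumes that ops are given by name, that the dialect is the prefix before the first `.`, and that outlined functions follow the `@<dialect>_<n>` naming of the example. All of these are illustrative assumptions, not the pass's actual data structures or naming guarantees.

```python
from itertools import groupby
from typing import Dict, List, Tuple


def dialect_of(op_name: str) -> str:
    """Return the dialect prefix of an op name such as "x.a"."""
    return op_name.split(".", 1)[0]


def group_by_dialect(ops: List[str]) -> List[Tuple[str, List[str]]]:
    """Split a straight-line op sequence into maximal single-dialect runs."""
    return [(d, list(run)) for d, run in groupby(ops, key=dialect_of)]


def outline_names(groups: List[Tuple[str, List[str]]]) -> List[str]:
    """Name one outlined function per run, numbering runs per dialect."""
    counters: Dict[str, int] = {}
    names = []
    for dialect, _ in groups:
        counters[dialect] = counters.get(dialect, 0) + 1
        names.append(f"@{dialect}_{counters[dialect]}")
    return names


groups = group_by_dialect(["x.a", "x.b", "y.c", "x.d"])
print(groups, outline_names(groups))
```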
### `-tf-guarantee-all-funcs-one-use`: Guarantee all FuncOps have only a single use. {: .hide-from-toc }
### `-tf-hoist-loop-invariant`: Hoists loop invariant ops to the outside of the loop {: .hide-from-toc }
Hoists loop-invariant ops out of the loop. The pass is similar to the
LoopInvariantCodeMotion pass, but it also hoists `tf.ReadVariableOp` ops if
the variable is read-only.

For example, the following pseudo MLIR code (types are left out for
brevity)

```mlir
func.func @hoist_loop_invariant(%arg0, %arg1) {
  %var = "tf.VarHandleOp"() {container="", shared_name="var_name", device = "/device:CPU:0"}
  %results:2 = "tf.WhileRegion"(%arg0, %arg1) ({
  ^bb0(%arg2, %arg3):
    %0 = "tf.OpA"() {is_stateless = true}
    "tf.Yield"(%0)
  }, {
  ^bb0(%arg2, %arg3):
    %1 = "tf.ReadVariableOp"(%var)
    %2 = "tf.OpB"(%1) {is_stateless = true}
    %3 = "tf.OpC"(%arg2, %2) {is_stateless = true}
    %4 = "tf.OpD"(%arg3, %2) {is_stateless = true}
    "tf.Yield"(%3, %4)
  }) {is_stateless = true}
  return %results#0, %results#1
}
```

would be transformed to

```mlir
func.func @hoist_loop_invariant(%arg0, %arg1) {
  %var = "tf.VarHandleOp"() {container="", shared_name="var_name", device = "/device:CPU:0"}
  %1 = "tf.ReadVariableOp"(%var)
  %2 = "tf.OpB"(%1) {is_stateless = true}
  %results:2 = "tf.WhileRegion"(%arg0, %arg1) ({
  ^bb0(%arg2, %arg3):
    %0 = "tf.OpA"() {is_stateless = true}
    "tf.Yield"(%0)
  }, {
  ^bb0(%arg2, %arg3):
    %3 = "tf.OpC"(%arg2, %2) {is_stateless = true}
    %4 = "tf.OpD"(%arg3, %2) {is_stateless = true}
    "tf.Yield"(%3, %4)
  }) {is_stateless = true}
  return %results#0, %results#1
}
```

Here `tf.ReadVariableOp` and `tf.OpB` are hoisted out of the loop: neither
depends on a loop-carried value, and the variable is only read.
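The invariance test can be sketched in Python. The sketch assumes that each body op is a `(result, op name, operands, is_stateless)` tuple, that the block arguments are the loop-varying values, and that the set of read-only variables is already known. These are hypothetical inputs. Like the example, it only analyzes the body region. An op is hoisted if none of its operands vary per iteration and it is either stateless or a read of a read-only variable.

```python
from typing import List, Set, Tuple

# Hypothetical model of a body op: (result, op name, operands, is_stateless).
Op = Tuple[str, str, List[str], bool]


def hoistable(body: List[Op], loop_args: Set[str], readonly_vars: Set[str]) -> List[str]:
    """Return the results of body ops that can be hoisted, in program order."""
    variant: Set[str] = set(loop_args)  # values that may change per iteration
    hoisted: List[str] = []
    for result, name, operands, stateless in body:
        depends_on_loop = any(v in variant for v in operands)
        readonly_read = name == "tf.ReadVariableOp" and all(
            v in readonly_vars for v in operands)
        if not depends_on_loop and (stateless or readonly_read):
            hoisted.append(result)
        else:
            variant.add(result)  # anything depending on it stays in the loop
    return hoisted


body = [
    ("%1", "tf.ReadVariableOp", ["%var"], False),
    ("%2", "tf.OpB", ["%1"], True),
    ("%3", "tf.OpC", ["%arg2", "%2"], True),
    ("%4", "tf.OpD", ["%arg3", "%2"], True),
]
print(hoistable(body, {"%arg2", "%arg3"}, {"%var"}))
```

If `%var` were written somewhere, the read would stay in the loop, and so would `tf.OpB`, which depends on it.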

### `-tf-hoist-replicate-invariant-resource-writes`: Hoists writes to replicate invariant resource variables. {: .hide-from-toc }
This pass hoists writes to replicate-invariant resource variables out of the
`tf_device.replicate` op. Such writes may have been inserted by other passes,
such as resource op lifting. However, if the resource variable is not
replicated, writing it once per replica is redundant. These writes can be
replaced by a single write of the value from the first replica.

The benefit of this optimization is a reduced memory requirement on the host.
With one write per replica, the host would allocate buffer space to receive
the device output from every replica, which is unnecessary because the output
of the first replica suffices.
### `-tf-init-text-file-to-import`: convert InitializeTableFromTextFileV2 ops to LookupTableImportV2Op to remove the dependency on asset files {: .hide-from-toc }

#### Options {: .hide-from-toc }
```
-tf-saved-model-dir : Directory containing the model exported as a TensorFlow SavedModel. If your model is not based on the TensorFlow SavedModel, use an empty value.
```
### `-tf-layout-assignment`: Layout assignment pass. {: .hide-from-toc }

#### Options {: .hide-from-toc }
```
-force-data-format : Force data format for all layout sensitive ops.
```
### `-tf-legalize-hlo`: Legalize from HLO to the TF dialect {: .hide-from-toc }
### `-tf-localize-var-handles`: Creates VarHandleOps next to the operations that use them. {: .hide-from-toc }
Creates VarHandleOps right next to the operations that use them, one
per operation.
This is useful for transformations that only end up with a few small
snippets of remaining TF code, and wish for those snippets to be
self-contained.
For example, this would transform

```mlir
"tf_saved_model.global_tensor"() { sym_name = "v" ... }
func @f(%arg0 {tf_saved_model.bound_input = @v}) {
  %1 = "tf.ReadVariableOp"(%arg0)
  ...
}
```

to

```mlir
func @f(%arg0 {tf_saved_model.bound_input = @v}) {
  %0 = "tf.VarHandleOp"(sym_name = "v")
  %1 = "tf.ReadVariableOp"(%0)
  ...
}
```

Note that this pass might leave behind unused values (such as `%arg0` in the
example above), which can later be pruned using DCE.
### `-tf-lower-quantized`: Lowers ops that require quantized input or output. {: .hide-from-toc }
This pass rewrites all ops that have at least one input or output that must
be a quantized type to ops whose inputs and outputs allow non-quantized
types. Examples of quantized types are TF_Qint8 or TF_Quint8.

An example is TF_DequantizeOp, which converts a quantized type to a float.
This op is rewritten to generic ops that perform the scale and shift
and can operate on non-quantized types.

Currently, TF_DequantizeOp is the only op with a lowering that falls
in this category. When more lowerings are added (e.g. QuantizeV2Op),
they should be added to this pass.
### `-tf-mark-ops-for-outside-compilation`: Marks ops in device cluster for outside compilation if they are unsupported on device. {: .hide-from-toc }
This pass marks unsupported ops in a device cluster with
`_xla_outside_compilation` attribute so the operations will run on the host
instead of the device. Unsupported ops are ops that cannot be code-generated
to run on the device for the cluster, including:

1. String operations on TPUs.
2. Operations that don't have a kernel defined for the device.

This pass is conservative in that it marks for outside compilation every op
that cannot be compiled for the device. Exceptions are made for ops that will
be rewritten or decomposed before compiling on device.

For example, tf_device.cluster op with an unsupported op, tf.UnsupportedOp:

```mlir
func @unsupported_op() -> tensor<i32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.UnsupportedOp"() : () -> tensor<i32>
    %2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
    tf_device.return %2 : tensor<i32>
  }) {allow_soft_placement = true, num_cores_per_replica = 1, topology =  "", device_assignment =  []} : () -> tensor<i32>
  return %0 : tensor<i32>
}
```

will mark tf.UnsupportedOp with `_xla_outside_compilation` attribute:

```mlir
func @unsupported_op() -> tensor<i32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.UnsupportedOp"() {_xla_outside_compilation = "auto0"} : () -> tensor<i32>
    %2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
    tf_device.return %2 : tensor<i32>
  }) {allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor<i32>
  return %0 : tensor<i32>
}
```
### `-tf-materialize-passthrough-op`: Materialize the MlirPassthroughOp by replacing it with the MLIR module attached as an attribute {: .hide-from-toc }
A pass that replaces MlirPassthrough ops with the code they have in
their `mlir_module` string attribute.
### `-tf-merge-control-flow`: Merges IfRegion ops together with a common predicate. {: .hide-from-toc }
This pass merges IfRegion ops together if they have the same predicate and it
is safe to do so (there are no intermediate dependencies, they are in the
same block, etc).

For example:

```mlir
"tf.IfRegion"(%0) ( {
  %2 = "tf.A"() : () -> (tensor<f32>)
  "tf.Yield"() : () -> ()
  }, {
  "tf.Yield"() : () -> ()
 }) { is_stateless = true } : (tensor<i1>) -> ()
"tf.IfRegion"(%0) ( {
  %2 = "tf.B"() : () -> (tensor<f32>)
  "tf.Yield"() : () -> ()
  }, {
  "tf.Yield"() : () -> ()
  }) { is_stateless = true } : (tensor<i1>) -> ()
```

would be transformed to:

```mlir
"tf.IfRegion"(%0) ( {
  %2 = "tf.A"() : () -> (tensor<f32>)
  %3 = "tf.B"() : () -> (tensor<f32>)
  "tf.Yield"() : () -> ()
  }, {
  "tf.Yield"() : () -> ()
  }) { is_stateless = true } : (tensor<i1>) -> ()
```
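The merging step itself can be modeled in a few lines of Python. This is a deliberately simplified, hypothetical sketch. Each `IfRegion` is reduced to its predicate name and the op names in its two branches, and only adjacent regions are merged. The safety checks described above (no intermediate dependencies, same block, and so on) are not modeled.

```python
from typing import List, Tuple

# Hypothetical model: (predicate, then-branch ops, else-branch ops).
IfRegion = Tuple[str, List[str], List[str]]


def merge_adjacent_ifs(ifs: List[IfRegion]) -> List[IfRegion]:
    """Merge consecutive IfRegions that share the same predicate value."""
    merged: List[IfRegion] = []
    for pred, then_ops, else_ops in ifs:
        if merged and merged[-1][0] == pred:
            _, prev_then, prev_else = merged[-1]
            # Append this region's ops after those of the earlier region.
            merged[-1] = (pred, prev_then + then_ops, prev_else + else_ops)
        else:
            merged.append((pred, list(then_ops), list(else_ops)))
    return merged


print(merge_adjacent_ifs([("%0", ["tf.A"], []), ("%0", ["tf.B"], [])]))
```

Appending the later region's ops after the earlier ones keeps the original execution order within each branch, which is why `tf.A` precedes `tf.B` in the merged region above.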
### `-tf-move-transposes`: Move transposes pass. {: .hide-from-toc }

#### Options {: .hide-from-toc }
```
-fold-transpose-in-ops : Whether to fold transposes in ops which can support folding.
-direction             : Move transposes to the beginning or the end of the block where they are defined.
```
### `-tf-name-anonymous-iterators`: Converts anonymous iterators to named iterators {: .hide-from-toc }
This pass converts `AnonymousIterator` ops to `Iterator` ops, thus giving them a name.
For example, this will convert

```mlir
%0 = "tf.AnonymousIteratorV3"() {...}
```

to

```mlir
%0 = "tf.Iterator"() {shared_name = "_iterator1", ...}
```
### `-tf-optimize`: Optimize TensorFlow module {: .hide-from-toc }
### `-tf-order-by-dialect`: Reorders ops so ops of the same dialect are next to each other. {: .hide-from-toc }
Performs a reordering of ops so that

- ops of the same dialect are next to each other, and
- order within a dialect is preserved.

For example, this would transform

```mlir
%a = "x.f"()
%b = "y.f"(%a)
%c = "x.f"(%a)
```

to

```mlir
%a = "x.f"()
%c = "x.f"(%a)
%b = "y.f"(%a)
```

so that the two "x" dialect instructions are next to each other.
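One way to satisfy both properties while still respecting data dependencies is a greedy list scheduler. It keeps emitting ready ops from the dialect used last and switches dialect only when none are ready. The Python sketch below shows this strategy. It is an assumed technique and not necessarily the algorithm the pass implements. The inputs are hypothetical: the original block order, a map from each op to the in-block ops that produce its operands, and each op's dialect.

```python
from typing import Dict, List, Optional, Set


def order_by_dialect(ops: List[str],
                     operands: Dict[str, List[str]],
                     dialect: Dict[str, str]) -> List[str]:
    """Greedily reorder ops so that same-dialect ops end up adjacent.

    `ops` is the original block order; `operands` maps an op to the ops in
    the same block that produce its operands.
    """
    remaining = list(ops)
    done: Set[str] = set()
    order: List[str] = []
    current: Optional[str] = None

    def is_ready(op: str) -> bool:
        if any(dep not in done for dep in operands.get(op, [])):
            return False  # an operand is not available yet
        # Property (2): an op may not overtake an earlier op of its dialect.
        first = next(o for o in remaining if dialect[o] == dialect[op])
        return first == op

    while remaining:
        ready = [op for op in remaining if is_ready(op)]
        # Property (1): prefer staying in the dialect of the last op emitted.
        same = [op for op in ready if dialect[op] == current]
        pick = (same or ready)[0]
        order.append(pick)
        done.add(pick)
        remaining.remove(pick)
        current = dialect[pick]
    return order


print(order_by_dialect(["a", "b", "c"], {"b": ["a"], "c": ["a"]},
                       {"a": "x", "b": "y", "c": "x"}))
```

Both data dependencies and the within-dialect order point forward in the original sequence. As a result, the earliest remaining op is always ready and the loop cannot stall.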
### `-tf-outside-compiled-to-host-launch`: Wraps each op with the _xla_outside_compilation attribute in a separate tf_device.launch on replicated host device. {: .hide-from-toc }
This pass wraps ops with the same `_xla_outside_compilation`
attribute value in a tf_device.launch op with host device assignment.

A simple example:

```mlir
  "tf_device.cluster"() ( {
    "tf.A"()
    "tf.B"() {_xla_outside_compilation = "cluster1"}
    "tf.C"()
    tf_device.return
  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []}
```

would become the following ops (unimportant attributes and types are omitted):

```mlir
  "tf_device.cluster"() ( {
    "tf.A"()
    "tf_device.launch"() {
      "tf.B"() {_xla_outside_compilation = "cluster1"}
      tf_device.return
    } {device = "TPU_REPLICATED_HOST"} : () -> ()
    "tf.C"()
    tf_device.return
  }) {num_cores_per_replica = 1, topology =  "", device_assignment =  []}
```
### `-tf-parallel-execute-to-islands`: Lowers device parallel_execute to executor islands {: .hide-from-toc }

#### Options {: .hide-from-toc }
```
-legacy-graph-export : Determines whether or not this pass should execute logic that is reserved for the legacy graph export pipeline to maintain expected invariants. In the case of this pass, that means manually propagating controls to lifted parallel execute regions to the graph fetch to ensure the ops execute.
```
### `-tf-promote-resources-to-args`: Promote resource reads/writes to function inputs/outputs. {: .hide-from-toc }
This pass promotes resource accesses in function(s) (by default, the main)
to input arguments and outputs of the function(s).

Two types of resources are supported:
(1) A function argument of TF::ResourceType type (this pass).
(2) A VarHandleOp in the function (tf-promote-var-handles-to-args).

After the pass,

- The function will have an input argument for each resource that is
  already provided as an input argument or is read. The type of the input
  argument will become the shape of the value represented by the resource.

- The function will have an output for each resource that is written. The
  type of the output will become the shape of the resource.

Variable identification and input-output aliasing information is
recorded as named attributes of the input argument or output:

- `tf.resource_name` matches the `shared_name` of VarHandleOp, which represents
  the identifier of the corresponding resource. This attribute is added to
  an input argument if the initial value of the resource is read, or to the
  output if the initial value is not read.

- `tf.aliasing_output` is the index of the function output that is an alias
  of the input argument. This attribute is added only to the input argument
  when the initial value of the corresponding resource is read and the
  resource is written later.

Assumptions of this pass:

- Compound resource operations have already been decomposed.
- Dead functions have already been removed, as resource arguments in dead
  functions can cause the pass to fail.

#### Options {: .hide-from-toc }
```
-functions : Comma separated list of functions whose resources read/writes should be promoted to function inputs/outputs.
```
### `-tf-promote-var-handles-to-args`: Promote tf.VarHandleOps to function arguments. {: .hide-from-toc }
See the joint description in `-tf-promote-resources-to-args`.
### `-tf-readonly-references-to-resources`: Convert readonly reference variables to resource variables. {: .hide-from-toc }
### `-tf-region-control-flow-to-functional`: Transforms region-based control flow operations to their functional counterparts {: .hide-from-toc }
This pass transforms region-based control flow operations in the TensorFlow
dialect to their functional counterparts, i.e., `tf.IfRegion` is transformed to
`tf.If` and `tf.WhileRegion` is transformed to `tf.While`.

For example, this region-based operation

```mlir
    %0 = "tf.IfRegion"(%arg0) ( {
      %1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
      "tf.Yield"(%1) : (tensor<*xf32>) -> ()
    },  {
      %1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
      "tf.Yield"(%1) : (tensor<*xf32>) -> ()
    }) {is_stateless = false} : (tensor<i1>) -> tensor<*xf32>
```

will be transformed into this functional operation

```mlir
  %0 = "tf.If"(%arg0, %arg1) {
    then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false
  } : (tensor<i1>, tensor<*xf32>) -> tensor<*xf32>
```
### `-tf-remove-unused-arguments`: Removes unused args from private functions & their callers. {: .hide-from-toc }
Removes arguments from functions that aren't used in the function
body, outside of returns. Also adjusts the callers of said functions.

For example, the code

```mlir
func.func @f(%arg0, %arg1) {
  SomeOpThatUsesArg0(%arg0)
  return %arg0
}
...
call @f(%x, %y)
```

would be transformed into

```mlir
func.func @f(%arg0) {
  SomeOpThatUsesArg0(%arg0)
  return %arg0
}
...
call @f(%x)
```

Note that, in the above example, both arguments would be removed if the
`SomeOpThatUsesArg0(%arg0)` line were absent.
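The core bookkeeping keeps a parameter only if some non-return op uses it, and drops the matching operand at the call site. The Python sketch below is hypothetical. It takes the callee's parameters, the set of values used by non-return ops, and the operands of one call. It does not model what happens to a returned value whose argument is removed.

```python
from typing import List, Set, Tuple


def prune_unused_args(params: List[str],
                      used_in_body: Set[str],
                      call_operands: List[str]) -> Tuple[List[str], List[str]]:
    """Drop parameters unused by non-return ops, plus the matching operands.

    `used_in_body` must not include uses by the function's return op.
    """
    keep = [i for i, p in enumerate(params) if p in used_in_body]
    return [params[i] for i in keep], [call_operands[i] for i in keep]


# @f(%arg0, %arg1) where only SomeOpThatUsesArg0 uses %arg0; call @f(%x, %y).
print(prune_unused_args(["%arg0", "%arg1"], {"%arg0"}, ["%x", "%y"]))
```

Filtering both lists by the same kept indices keeps the callee's signature and every call site consistent.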
### `-tf-remove-unused-while-results`: Removes unused results from tf.WhileRegion ops {: .hide-from-toc }
Removes unused results from `tf.WhileRegion` ops along with the defining
ops in the body, if it is safe to do so.
Currently, the pass detects results with the following properties:
- the result is unused outside of the `tf.WhileRegion` op
- the defining op of the result in the body can be safely removed
- the operand corresponding to the result is not used by any other op in
  the condition or body (in particular, there must not be intermediate
  pass-through ops like `tf.Identity`)


For example, the following pseudo MLIR code (types are left out for
brevity)
```mlir
  func.func @remove_first_result(%arg0, %arg1) {
    %0:2 = "tf.WhileRegion"(%arg0, %arg1) ({
    ^bb0(%arg2, %arg3):
      %1 = "tf.OpA"() {is_stateless = true}
      "tf.Yield"(%1)
    }, {
    ^bb0(%arg2, %arg3):
      %1 = "tf.OpB"(%arg2) {is_stateless = true}
      %2 = "tf.OpC"(%arg3) {is_stateless = true}
      "tf.Yield"(%1, %2)
    }) {is_stateless = true}
    return %0#1
  }
```
would be transformed to
```mlir
  func.func @remove_first_result(%arg0, %arg1) {
    %0 = "tf.WhileRegion"(%arg1) ({
    ^bb0(%arg3):
      %1 = "tf.OpA"() {is_stateless = true}
      "tf.Yield"(%1)
    }, {
    ^bb0(%arg3):
      %1 = "tf.OpC"(%arg3) {is_stateless = true}
      "tf.Yield"(%1)
    }) {is_stateless = true}
    return %0
  }
```
(the first result can be removed along with its defining op `tf.OpB`).

### `-tf-replica-id-to-device-ordinal`: Set device ordinal with replica id {: .hide-from-toc }
This pass sets the device ordinal attribute of the ops using the replica id
attribute. This is run immediately after the replica_to_island pass which
sets the replica id attribute of these ops. Note that for the single-chip use
case, the pass checks whether there is only one such op and, if so, sets its
device ordinal attribute to zero.
### `-tf-replicate-invariant-op-hoisting`: Hoists replicate invariant operations out of replicate {: .hide-from-toc }
This pass looks for replicate invariant ops in a `tf_device.replicate` op
region and hoists them out. It also makes `tf.Shape` ops replicate invariant
if possible. This currently updates or replaces `tf.Shape` ops of replicated
arguments, either tensors or resources.

For example, the following

```mlir
tf_device.replicate([%0, %1] as %ri: tensor<*xi32>) {n = 2 : i32} {
  %2 = "tf.Shape"(%ri) : (tensor<*xi32>) -> tensor<?xi32>
  tf_device.return
}
```

gets converted to

```mlir
tf_device.replicate([%0, %1] as %ri: tensor<*xi32>) {n = 2 : i32} {
  %2 = "tf.Shape"(%0) : (tensor<*xi32>) -> tensor<?xi32>
  tf_device.return
}
```

and for resource variables the following

```mlir
tf_device.replicate([%0, %1] as %ri: tensor<*x!tf_type.resource>) {n = 2 : i32} {
  %2 = "tf.ReadVariableOp"(%ri) : (tensor<*x!tf_type.resource>) -> tensor<*xi32>
  %3 = "tf.Shape"(%2) : (tensor<*xi32>) -> tensor<?xi32>
  tf_device.return
}
```

gets converted to

```mlir
tf_device.replicate([%0, %1] as %ri: tensor<*x!tf_type.resource>) {n = 2 : i32} {
  %2 = "tf.ReadVariableOp"(%ri) : (tensor<*x!tf_type.resource>) -> tensor<*xi32>
  %3 = "tf.VariableShape"(%0) : (tensor<*x!tf_type.resource>) -> tensor<?xi32>
  tf_device.return
}
```
### `-tf-replicate-tensor-list-init-ops`: Replicate TensorList init ops for correct shape assignments in shape inference {: .hide-from-toc }
If the same TensorList is passed to a while op as multiple arguments, or is
used in multiple places where different `TensorListSetItem` ops assign its
elements, shape inference cannot identify the shapes of these arguments, and
thus the shape of the input TensorList cannot be determined. All of these
arguments are supposed to be independent of each other and unrelated to the
original creation of the TensorList.

This pass creates a separate TensorList instance for each argument of the
while op and for each use, so that there is no conflict in resolving the
shapes of these different inputs.
### `-tf-replicate-to-island`: Lowers device replicate to executor islands {: .hide-from-toc }

#### Options {: .hide-from-toc }
```
-legacy-graph-export : Determines whether or not this pass should execute logic that is reserved for the legacy graph export pipeline to maintain expected invariants. In the case of this pass, that means manually propagating controls to lifted parallel execute regions to the graph fetch to ensure the ops execute, as well as determining whether or not the islands created by this pass should be split after the replicated ops have been lifted.
```
### `-tf-resource-device-inference`: Propagates the device attribute on resources from callers to callees. {: .hide-from-toc }
A pass that propagates device assignment of resources on a module. It
performs in-function propagation, as well as cross-function propagation from
callers to callees.

This pass changes the module by adding "tf.device" attribute to function
arguments and adding "device" attribute to TF ops.

For example, given the function

```mlir
  !tf_res = type tensor<*x!tf_type.resource<tensor<32xf32>>>

  func @test(%arg0: !tf_res {tf.device = "/TPU:0"}) {
    tf_executor.graph {
      %control = tf_executor.island {
        %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res
        tf_executor.yield
      }
      tf_executor.fetch %control : !tf_executor.control
    }
    return
  }
```

Observe how the op inside the island obtains a `/TPU:0` device assignment:

```mlir
  !tf_res = type tensor<*x!tf_type.resource<tensor<32xf32>>>

  func @test(%arg0: !tf_res {tf.device = "/TPU:0"}) {
    tf_executor.graph {
      %control = tf_executor.island {
        %0 = "tf.Identity"(%arg0) {device = "/TPU:0"} : (!tf_res) -> !tf_res
        tf_executor.yield
      }
      tf_executor.fetch %control : !tf_executor.control
    }
    return
  }
```
### `-tf-rewrite-tpu-embedding-ops`: Rewrites TPU embedding send/recv ops by adding TPU embedding deduplication data {: .hide-from-toc }
### `-tf-shape-inference`: Shape inference on TF dialect and ops implementing InferTypeOpInterface {: .hide-from-toc }
Fixed point shape refinement pass that utilizes the shape functions
registered on ops using the InferTypeOpInterface as well as by bridging to
the TensorFlow op registry's shape functions. This is an interprocedural
pass that propagates information across function calls/control flow
operations where possible (the GuaranteeAllFuncsOneUsePass is often run
before this pass to enable more propagation opportunities). It refines
both the outermost element type of tensors as well as the nested component
type (e.g., for tensor lists).

During shape refinement this pass may insert additional cast operations and
fold some constant shape computations to enable more exact shape inference,
so it does mutate the graph to some extent. Constant folding required to
produce more exact shapes is also performed, but those folded values are only
kept in the pass context; the corresponding ops are not folded and the IR is
not mutated for them.

#### Options {: .hide-from-toc }
```
-max-iterations : Maximum shape inference iterations
```
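The fixed-point idea can be illustrated with a toy Python model. It assumes shapes are tuples in which `None` marks an unknown dimension (or, for the whole shape, an unranked tensor), and that each op carries a Python shape function. These names are hypothetical and unrelated to the pass's C++ implementation. The ops are deliberately listed out of dependency order, so more than one sweep is needed. Here `max_iterations` plays the role of the `-max-iterations` option.

```python
from typing import Callable, Dict, List, Optional, Tuple

# None means "unknown": an unranked shape or an unknown dimension.
Shape = Optional[Tuple[Optional[int], ...]]
ShapeFn = Callable[[List[Shape]], Shape]


def refine(old: Shape, new: Shape) -> Shape:
    """Combine two compatible shapes, keeping every known dimension."""
    if old is None:
        return new
    if new is None:
        return old
    if len(old) != len(new):
        raise ValueError(f"incompatible ranks: {old} vs {new}")
    return tuple(o if o is not None else n for o, n in zip(old, new))


def infer_shapes(ops: Dict[str, Tuple[List[str], ShapeFn]],
                 shapes: Dict[str, Shape],
                 max_iterations: int = 10) -> Dict[str, Shape]:
    """Re-run shape functions until no shape changes (a fixed point)."""
    shapes = dict(shapes)
    for _ in range(max_iterations):
        changed = False
        for name, (inputs, shape_fn) in ops.items():
            inferred = shape_fn([shapes.get(i) for i in inputs])
            refined = refine(shapes.get(name), inferred)
            if refined != shapes.get(name):
                shapes[name] = refined
                changed = True
        if not changed:
            break
    return shapes


ops = {
    "relu": (["add"], lambda s: s[0]),                   # shape-preserving
    "add": (["x", "y"], lambda s: refine(s[0], s[1])),  # same-shape elementwise
}
print(infer_shapes(ops, {"x": (2, None), "y": (None, 3)}))
```

Each sweep only ever adds information, so the loop stops as soon as a sweep changes nothing. If the iteration cap is hit first, some shapes may stay less refined.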
### `-tf-simple-device-assignment`: Simple device assignment in TF dialect. {: .hide-from-toc }
Assigns the default device to all ops that have an empty (or
nonexistent) device attribute.

For example, if we have the code

```mlir
  %0 = "tf.Const"() {value = dense<[[42.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
  %1 = "tf.Const"() {device = "", value = dense<[[42.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
  %2 = "tf.Const"() {device = "baz", value = dense<[[42.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
```

then running this pass with `default-device=foobar`, we get:

```mlir
  %0 = "tf.Const"() {device = "foobar", value = dense<[[42.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
  %1 = "tf.Const"() {device = "foobar", value = dense<[[42.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
  %2 = "tf.Const"() {device = "baz", value = dense<[[42.0]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
```

#### Options {: .hide-from-toc }
```
-default-device : The default device to assign.
```
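The assignment rule can be sketched with each op's attributes modeled as a plain Python dict. This is a hypothetical representation, not the pass's data structures. A missing `device` key and an empty string are treated the same way, and any non-empty device is left untouched.

```python
from typing import Dict, List


def assign_default_device(ops: List[Dict[str, str]],
                          default_device: str) -> List[Dict[str, str]]:
    """Give every op with a missing or empty "device" attribute the default."""
    result = []
    for attrs in ops:
        updated = dict(attrs)  # leave the input unmodified
        if not updated.get("device"):
            updated["device"] = default_device
        result.append(updated)
    return result


# Mirrors %0 (no device), %1 (device = "") and %2 (device = "baz") above.
print(assign_default_device([{}, {"device": ""}, {"device": "baz"}], "foobar"))
```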
### `-tf-stack-ops-decomposition`: Decompose stack operations into local variable operations. Needs static shapes. {: .hide-from-toc }
A pass that converts stack operations to tensor operations and read/assign
ops on local variables. A later resource lifting pass can further remove the
local variables.

This pass requires that the full shape of the stack can be inferred: 1) the
maximum size needs to be a constant and 2) a push op can be found with a
known shape, and all push ops need to have the same shape.

A stack creation op "tf.StackV2" will be turned into two zero-initialized
variables, for the buffer and current size. Each push will be turned into
```mlir
  %old_val = "tf.ReadVariableOp"(%buffer)
  %old_size = "tf.ReadVariableOp"(%size)
  %offsets = "tf.ConcatV2"(%old_size, %other_dims_0s, %const0)
  %new_val = "tf.XlaDynamicUpdateSlice"(%old_val, %push_val, %offsets)
  "tf.AssignVariableOp"(%buffer, %new_val)
  %new_size = "tf.AddV2"(%old_size, %const1)
  "tf.AssignVariableOp"(%size, %new_size)
```

and each pop will be turned into

```mlir
  %old_val = "tf.ReadVariableOp"(%buffer)
  %old_size = "tf.ReadVariableOp"(%size)
  %new_size = "tf.Sub"(%old_size, %const1)
  %offsets = "tf.ConcatV2"(%new_size, %other_dims_0s, %const0)
  %slice = "tf.Slice"(%old_val, %offsets, %slice_size_const)
  %pop_result = "tf.Reshape"(%slice, %elem_size_const)
  "tf.AssignVariableOp"(%size, %new_size)
```

The pass also works across control flow and functional calls.
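The buffer-plus-size representation behaves like the following Python model. This is a hypothetical sketch: it uses scalar elements, and plain Python values stand in for the buffer and size variables. The comments only indicate which TF ops in the listings above play the corresponding role. A push writes at offset `old_size`, and a pop reads the top element at offset `old_size - 1`.

```python
from typing import List


class DecomposedStack:
    """Models a tf.StackV2 after decomposition: a zero-initialized buffer
    variable plus a zero-initialized size variable (hypothetical sketch)."""

    def __init__(self, max_size: int) -> None:
        self.buffer: List[float] = [0.0] * max_size
        self.size = 0

    def push(self, value: float) -> None:
        if self.size >= len(self.buffer):
            raise IndexError("push exceeds the stack's maximum size")
        # Write at offset old_size (cf. tf.XlaDynamicUpdateSlice) ...
        new_buffer = list(self.buffer)
        new_buffer[self.size] = value
        self.buffer = new_buffer
        # ... then store old_size + 1 (cf. tf.AddV2 + tf.AssignVariableOp).
        self.size += 1

    def pop(self) -> float:
        if self.size == 0:
            raise IndexError("pop from an empty stack")
        # The top element lives at offset new_size = old_size - 1.
        new_size = self.size - 1
        value = self.buffer[new_size]  # cf. tf.Slice + tf.Reshape
        self.size = new_size           # the buffer itself is left unchanged
        return value


stack = DecomposedStack(max_size=3)
stack.push(1.0)
stack.push(2.0)
print(stack.pop(), stack.size)
```

The fixed `max_size` is why the pass needs the maximum size to be a constant: the buffer is allocated once, and a pop only moves the size counter.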
### `-tf-strip-noinline-attribute`: Strip the tf._noinline attribute from top-level functions. {: .hide-from-toc }
### `-tf-strip-tf-attributes`: Removes TF specific attributes {: .hide-from-toc }
Removes attributes that are TF specific (start with "tf.") or that
have a value from the TF dialect. Useful after legalizing TF graphs
to other dialects, to remove any TF remnants.
### `-tf-tensor-array-ops-decomposition`: Decompose tensor array operations into local variable operations. {: .hide-from-toc }
A pass that converts tensor array operations to tensor operations and
read/assign ops on local variables. A later resource lifting pass can further
remove the local variables.

This pass requires that the full shape of the tensor array can be inferred:
1) the size must be a constant, 2) the full element shape must either be
specified or be inferable from a later write, and 3) all elements must have the same
shape.
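The three requirements can be read as a shape-inference check. The hypothetical helper below (not part of TensorFlow) returns the static buffer shape `[size, *element_shape]` when the requirements hold and `None` otherwise; a `None` argument stands for a value that is not statically known. It treats element shapes as fully known and does not model partially known shapes.

```python
from typing import List, Optional, Sequence


def infer_buffer_shape(
    size: Optional[int],
    element_shape: Optional[Sequence[int]],
    write_shapes: Sequence[Sequence[int]],
) -> Optional[List[int]]:
    """Return [size, *element_shape] if statically inferable, else None."""
    if size is None:                 # 1) the size must be a constant
        return None
    shape = element_shape
    if shape is None:                # 2) fall back to the shape of a later write
        if not write_shapes:
            return None
        shape = write_shapes[0]
    if any(list(w) != list(shape) for w in write_shapes):
        return None                  # 3) all elements must share one shape
    return [size, *shape]


print(infer_buffer_shape(5, None, [[4], [4]]))  # [5, 4]
```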
### `-tf-tensor-device-copy`: Fold the tf.Identity op and the tf.IdentityN op if the op has the same device as its operand {: .hide-from-toc }
### `-tf-tensor-list-ops-decomposition`: Decomposes TensorList operations into generic operations on tensors. {: .hide-from-toc }
This pass rewrites TensorList operations into generic and non-mutating
operations on tensors. This results in operations that can be legalized to XLA.

The list is converted to a single large tensor that includes all list elements,
with a new first dimension for the list index. List update operations are
converted to operations that create a new tensor representing the list.

In the current implementation, the resulting operations are statically shaped,
which means it must be possible to infer a bound on the full shape of the
TensorList. That is, the `element_shape` and `num_elements` arguments to a
tensor list creation op must be constants.

A tensor list creation op `tf.EmptyTensorList`/`tf.TensorListReserve` will be
turned into a zero-initialized buffer, and the size is initialized to 0
for `tf.EmptyTensorList` or the specified size for `tf.TensorListReserve`.
Each push will be turned into `tf.XlaDynamicUpdateSlice` with the incremented
size, and each pop will be turned into a `tf.Slice` and a copy of the buffer
with decremented size. Each `tf.TensorListSetItem` will be turned into a
`tf.XlaDynamicUpdateSlice` with unchanged size, and each `tf.TensorListGetItem`
will be rewritten to a `tf.Slice`.

The pass also works across control flow and functional calls.

For example, the TensorList ops in the following function:

```mlir
func @main(%arg0: tensor<8x4xf32>) {
  %elem_shape = "tf.Const"() {value = dense<[8, 4]> : tensor<2xi32>} : () -> tensor<2xi32>
  %max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
  %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<2xi32>, tensor<i32>) -> tensor<!tf_type.variant<tensor<8x4xf32>>>
  %push = "tf.TensorListPushBack"(%tl, %arg0) : (tensor<!tf_type.variant<tensor<8x4xf32>>>, tensor<8x4xf32>) -> tensor<!tf_type.variant<tensor<8x4xf32>>>
  return
}
```

will be transformed to:

```mlir
func @main(%arg0: tensor<8x4xf32>) {
  // EmptyTensorList lowering
  %emptyi = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %emptyf = "tf.Cast"(%emptyi) : (tensor<i32>) -> tensor<f32>
  %size_shape = "tf.Const"() {value = dense<[10, 8, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
  %tl = "tf.BroadcastTo"(%emptyf, %size_shape) : (tensor<f32>, tensor<3xi32>) -> tensor<10x8x4xf32>
  // TensorListPushBack lowering
  %index_in_list = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
  %arg0_shape = "tf.Const"() {value = dense<[1, 8, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
  %arg0_reshaped = "tf.Reshape"(%arg0, %arg0_shape) : (tensor<8x4xf32>, tensor<3xi32>) -> tensor<1x8x4xf32>
  %zeroi2 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
  %axis = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %start_indices = "tf.ConcatV2"(%index_in_list, %zeroi2, %axis) : (tensor<1xi32>, tensor<2xi32>, tensor<i32>) -> tensor<3xi32>
  %push = "tf.XlaDynamicUpdateSlice"(%tl, %arg0_reshaped, %start_indices) : (tensor<10x8x4xf32>, tensor<1x8x4xf32>, tensor<3xi32>) -> tensor<10x8x4xf32>
  %one = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
  %next_index_in_list = "tf.AddV2"(%index_in_list, %one) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
  return
}
```
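The lowering rules above can be simulated with immutable Python tuples. In this sketch, which is an illustrative model rather than the pass's implementation, a list is a `(buffer, size)` pair, every operation returns a new pair (matching the non-mutating ops the pass emits), and scalars stand in for the `8x4` element tensors. All function names are hypothetical.

```python
from typing import Tuple

# A list is (buffer, size); scalars stand in for element tensors.
TensorList = Tuple[Tuple[float, ...], int]


def empty_tensor_list(max_num_elements: int) -> TensorList:
    return (0.0,) * max_num_elements, 0          # tf.EmptyTensorList


def tensor_list_reserve(num_elements: int) -> TensorList:
    return (0.0,) * num_elements, num_elements  # tf.TensorListReserve


def _dynamic_update(buffer, index, value):
    # Stand-in for tf.XlaDynamicUpdateSlice along the list dimension.
    return buffer[:index] + (value,) + buffer[index + 1:]


def push_back(tl: TensorList, value: float) -> TensorList:
    buffer, size = tl
    return _dynamic_update(buffer, size, value), size + 1


def pop_back(tl: TensorList) -> Tuple[TensorList, float]:
    buffer, size = tl
    # tf.Slice of the top element; the buffer is kept, the size decremented.
    return (buffer, size - 1), buffer[size - 1]


def set_item(tl: TensorList, index: int, value: float) -> TensorList:
    buffer, size = tl
    return _dynamic_update(buffer, index, value), size


def get_item(tl: TensorList, index: int) -> float:
    return tl[0][index]                           # tf.Slice


tl = push_back(empty_tensor_list(10), 1.5)
print(tl[1], tl[0][:2])  # 1 (1.5, 0.0)
```

As in the MLIR example, the first push writes at index 0 of a buffer whose leading dimension is the maximum number of elements.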
### `-tf-tpu-annotate-dynamic-shape-inputs`: Annotate the inputs returned by TPUCopyWithDynamicShapeOp with dynamic shape {: .hide-from-toc }
This pass looks for uses of the results of TPUCopyWithDynamicShapeOp
and marks those inputs as dynamically shaped. This ensures that the
generated HLO program correctly reflects the dynamic shapes.
### `-tf-tpu-cleanup-cluster-attributes`: Eliminate _replication_info and other attributes from ops in a cluster {: .hide-from-toc }
This pass eliminates the `_replication_info` and `device` attributes from operations
that are contained in a `tf_device.cluster` op.
### `-tf-tpu-cluster-formation`: Forms clusters from operations assigned to the same TPU computation {: .hide-from-toc }
TPU computations from the frontend are composed of a `tf.TPUReplicateMetadata`
op, a subgraph of ops (TensorFlow Dialect) each with a matching
`_replication_info` attribute relative to the associated
`tf.TPUReplicateMetadata` op, and optionally `tf.TPUReplicatedInput` and
`tf.TPUReplicatedOutput` ops feeding in inputs and outputs to and from a
replicated TPU computation. The number of times a TPU computation is
replicated is defined in the `tf.TPUReplicateMetadata` op (`num_replicas`
attribute) and operand and result sizes of `tf.TPUReplicatedInput` and
`tf.TPUReplicatedOutput` respectively must match, excluding packed tensors.
It is also assumed that no op outside of a TPU computation is both an input to
and an output of that TPU computation. Furthermore, we assume that every node
has either none or both of the `_replication_info` and `_xla_compile_device_type`
attributes defined.

This pass takes the TPU computation subgraph, moves it into a
`tf_device.cluster`, and copies over attributes from the associated
`tf.TPUReplicateMetadata` op to the newly created `tf_device.cluster`. If the
computation is replicated (`num_replicas` > 1), the `num_replicas` attribute is
not copied over but instead the `tf_device.cluster` is further wrapped with a
`tf_device.replicate`, and associated `tf.TPUReplicatedInput` and
`tf.TPUReplicatedOutput` ops are replaced as the `tf_device.replicate` operands
and results. Otherwise, the single operands and results of the associated
`tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops are simply forwarded to
the `tf_device.cluster`.

For example, the following non-replicated computation:

```mlir
func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
  // Metadata op for cluster `cluster` with 1 replica, 1 core per replica and
  // with topology `<topology>`.
  "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_replicas = 1, num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> ()
  %replicated_input = "tf.TPUReplicatedInput"(%arg0) : (tensor<i32>) -> tensor<i32>
  %identity = "tf.Identity"(%replicated_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster"} : (tensor<i32>) -> tensor<i32>
  %replicated_output = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> tensor<i32>
  return %replicated_output : tensor<i32>
}
```

will be transformed into:

```mlir
func @tpu_computation(%arg0: tensor<i32>) -> tensor<i32> {
  %cluster = "tf_device.cluster"() ( {
    %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
    tf_device.return %identity : tensor<i32>
  }) {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_cores_per_replica = 1, topology = "<topology>", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
  return %cluster : tensor<i32>
}
```

The following replicated computation:

```mlir
func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_replicas = 2, num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> ()
  %replicated_input = "tf.TPUReplicatedInput"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %identity = "tf.Identity"(%replicated_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster"} : (tensor<i32>) -> tensor<i32>
  %replicated_output:2 = "tf.TPUReplicatedOutput"(%identity) : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
  return %replicated_output#0, %replicated_output#1 : tensor<i32>, tensor<i32>
}
```

will be transformed into:

```mlir
func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  %replicate:2 = tf_device.replicate([%arg0, %arg1] as %replicated_input: tensor<i32>) {n = 2 : i32} {
    %cluster = "tf_device.cluster"() ( {
      %identity = "tf.Identity"(%replicated_input) : (tensor<i32>) -> tensor<i32>
      tf_device.return %identity : tensor<i32>
    }) {_xla_compile_device_type = "TPU", _replication_info = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor<i32>)
    tf_device.return %cluster : tensor<i32>
  }
  return %replicate#0, %replicate#1 : tensor<i32>, tensor<i32>
}
```
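The grouping and wrapping decisions described above can be summarized in a small Python model. It is a sketch under simplifying assumptions: ops are plain `Op` records, `metadata` maps a cluster name to the attributes of its `tf.TPUReplicateMetadata` op, and the result only records which ops belong to each cluster, which attributes the `tf_device.cluster` receives, and the `n` of a `tf_device.replicate` wrapper (if any). Operand forwarding and the data-dependency assumptions are not modeled, and all names are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Op:
    name: str
    attrs: Dict[str, object] = field(default_factory=dict)


def form_clusters(ops: List[Op], metadata: Dict[str, Dict[str, object]]):
    """Group ops by `_replication_info` and decide how each cluster is wrapped."""
    members: Dict[str, List[str]] = {}
    for op in ops:
        has_info = "_replication_info" in op.attrs
        # Every op must carry both attributes or neither.
        if has_info != ("_xla_compile_device_type" in op.attrs):
            raise ValueError(f"{op.name}: expected both attributes or neither")
        if has_info:
            members.setdefault(op.attrs["_replication_info"], []).append(op.name)

    clusters = {}
    for name, op_names in members.items():
        cluster_attrs = dict(metadata[name])
        # num_replicas does not appear on the cluster in the examples above.
        num_replicas = cluster_attrs.pop("num_replicas", 1)
        clusters[name] = {
            "ops": op_names,
            "cluster_attrs": cluster_attrs,
            # Only replicated computations are wrapped in tf_device.replicate.
            "replicate_n": num_replicas if num_replicas > 1 else None,
        }
    return clusters


ops = [
    Op("tf.Identity", {"_replication_info": "cluster", "_xla_compile_device_type": "TPU"}),
    Op("tf.Const"),
]
print(form_clusters(ops, {"cluster": {"num_replicas": 2, "num_cores_per_replica": 1}}))
```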
### `-tf-tpu-colocate-composite-resource-ops`: Colocate resource with composite device assignment to TPU device. {: .hide-from-toc }
Pass that co-locates resource ops that use composite device resources
(packed tensors) with the underlying physical TPU device.

For example, if a function contains the following op (inside a `tf_device.replicate`):

```mlir
  %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource<tensor<4xf32>>>) -> tensor<4xf32>
```

Then that `ReadVariableOp` will be replaced by:

```mlir
  %0 = "tf_device.launch"() ( {
    %2 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource<tensor<4xf32>>>) -> tensor<4xf32>
    tf_device.return %2 : tensor<4xf32>
  }) {...} : () -> tensor<4xf32>
```
### `-tf-tpu-device-propagation`: Propagates TPU devices from ops to users {: .hide-from-toc }
### `-tf-tpu-dynamic-layout-pass`: Inserts TPU layout ops to determine layout at run time. {: .hide-from-toc }
A pass that allows the TPU input layout to be determined after JIT compilation.
This is done by adding run-time ops that interpret the compilation result and
copy the input to the device with that layout.

For example, given the original program:

```mlir
  %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
  %compile:2 = "tf._TPUCompileMlir"(...)
  %execute = "tf.TPUExecute"(%input, ..., %compile#1) {device = "/TPU:0"}
```

Without this pass, later TF graph partitioning passes will insert send/recv ops
between `%input` and `%execute`, and the data will be copied to the device in a fixed
layout. With this pass, the program will be transformed into:

```mlir
  %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
  %compile:2 = "tf._TPUCompileMlir"(...)
  %get_layout = "tf.TPUGetLayoutOp"(%compile#1) {...}
  %copy_to_device = "tf.TPUCopyWithLayout"(%input, %get_layout)
      {device = "/TPU:0"}
  %execute = "tf.TPUExecute"(%copy_to_device, ..., %compile#1)
      {device = "/TPU:0"}
```

This way, `%compile` determines the layout, which is respected by
`%copy_to_device`. Later passes will not add send/recv ops,
because `tf.TPUCopyWithLayout` accepts a host input and produces a device
output.
### `-tf-tpu-host-computation-expansion`: Expands host computation before and after TPU computation. {: .hide-from-toc }
This pass expands outside compilation attributes to Identity/Cast ops
at the head of the TPU computation if those ops are only used by outside-compiled ops.
### `-tf-tpu-identity-pruning`: Removes Identity/IdentityN ops from the TPU computation {: .hide-from-toc }
### `-tf-tpu-merge-variables-with-execute`: Merges device variable reads and updates into TPU execute ops {: .hide-from-toc }
This pass finds on-device resource variable reads and updates surrounding a
`tf.TPUExecute` op and merges them into a `tf.TPUExecuteAndUpdateVariables`
op. This allows the TPU execution to perform more efficient in-place
variable updates.

For example,

```mlir
  %0 = "tf.ReadVariableOp"(%arg0)
  %1 = "tf.ReadVariableOp"(%arg1)
  %2 = "tf.TPUExecute"(%0, %1, %compile)
  "tf.AssignVariableOp"(%arg0, %2)
```

will be transformed into

```mlir
  %2 = "tf.TPUExecuteAndUpdateVariables"(%arg0, %arg1, %compile)
    { device_var_reads_indices = [0, 1],
      device_var_updates_indices = [0, -1] }
```

The transformation happens only for on-device variables. The above
transformation requires `%arg0`, `%arg1` to have the same device assignment
as the `TPUExecute` op.
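The bookkeeping behind the two index attributes in the example can be sketched as follows. This is a hedged reading of the example rather than the pass's code: `device_var_reads_indices` is taken to list the `tf.TPUExecute` operand positions that were resource reads, and `device_var_updates_indices` to give, for each such variable, the result index written back to it (`-1` if none). Device-assignment checks are omitted, and the function name is hypothetical.

```python
from typing import Dict, List, Tuple


def merge_variables_with_execute(
    operands: List[str],
    reads: Dict[str, str],
    assigns: Dict[str, int],
) -> Tuple[List[str], List[int], List[int]]:
    """operands: tf.TPUExecute data operands (program key excluded).
    reads: value name -> resource it was read from by tf.ReadVariableOp.
    assigns: resource -> tf.TPUExecute result index stored by tf.AssignVariableOp."""
    new_operands, read_indices, update_indices = [], [], []
    for index, value in enumerate(operands):
        if value in reads:
            resource = reads[value]
            new_operands.append(resource)   # pass the resource handle directly
            read_indices.append(index)
            update_indices.append(assigns.get(resource, -1))
        else:
            new_operands.append(value)      # ordinary operand, unchanged
    return new_operands, read_indices, update_indices


print(merge_variables_with_execute(
    ["%0", "%1"], {"%0": "%arg0", "%1": "%arg1"}, {"%arg0": 0}))
# (['%arg0', '%arg1'], [0, 1], [0, -1])
```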
### `-tf-tpu-parallel-execute-sink-resource-write`: Moves tf.AssignVariableOp consumers of tf_device.parallel_execute into tf_device.parallel_execute regions {: .hide-from-toc }
### `-tf-tpu-partitioned-op-conversion`: Rewrite all TPU Partitioned ops into their V2 counterparts. {: .hide-from-toc }
### `-tf-tpu-reorder-replicate-partitioned-inputs`: Reorder replicated and partitioned input ops. {: .hide-from-toc }
This pass rewrites how data parallelism and model parallelism are expressed for
inputs. It reorders `tf.TPUPartitionedInput` (model parallelism) and
`tf.TPUReplicatedInput` (data parallelism) ops. It transforms a DAG where
multiple `tf.TPUPartitionedInput` ops feed into a single
`tf.TPUReplicatedInput` into a DAG where multiple `tf.TPUReplicatedInput` ops
feed into a single `tf.TPUPartitionedInput`. Transforming the IR in this
manner allows the subsequent cluster formation pass to handle IR with both
data and model parallelism more easily.

For example, the following:

```mlir
!rtype = type tensor<!tf_type.resource<tensor<10x3xf32>>>
func @data_and_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg2: !rtype, %arg3: !rtype) -> !rtype {
  %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype
  %pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype
  %ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1) : (!rtype, !rtype) -> !rtype
  return %ri : !rtype
}
```

will be transformed into:

```mlir
!rtype = type tensor<!tf_type.resource<tensor<10x3xf32>>>
func @data_and_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg2: !rtype, %arg3: !rtype) -> !rtype {
  %ri_0 = "tf.TPUReplicatedInput"(%arg0, %arg2) : (!rtype, !rtype) -> !rtype
  %ri_1 = "tf.TPUReplicatedInput"(%arg1, %arg3) : (!rtype, !rtype) -> !rtype
  %pi = "tf.TPUPartitionedInput"(%ri_0, %ri_1) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype
  return %pi : !rtype
}
```
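Viewed as data, the rewrite is a matrix transpose: operand `p` of the `r`-th `tf.TPUPartitionedInput` becomes operand `r` of the `p`-th new `tf.TPUReplicatedInput`. The hypothetical helper below shows only this operand reordering on value names; attributes such as `partition_dim` are carried over unchanged and are not modeled.

```python
from typing import List


def reorder_inputs(partitioned: List[List[str]]) -> List[List[str]]:
    """partitioned[r][p]: operand p of the r-th tf.TPUPartitionedInput that feeds
    tf.TPUReplicatedInput. Returns replicated[p][r]: operand r of the p-th new
    tf.TPUReplicatedInput that feeds the single new tf.TPUPartitionedInput."""
    num_partitions = len(partitioned[0])
    if any(len(ops) != num_partitions for ops in partitioned):
        raise ValueError("all tf.TPUPartitionedInput ops must have the same arity")
    return [list(column) for column in zip(*partitioned)]


print(reorder_inputs([["%arg0", "%arg1"], ["%arg2", "%arg3"]]))
# [['%arg0', '%arg2'], ['%arg1', '%arg3']]
```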
### `-tf-tpu-resource-partition`: Partitions unpartitioned resource read/write to partitioned resource variables. {: .hide-from-toc }
This pass creates individual resource reads/writes from the unpartitioned
resource variable (from `tf.TPUPartitionedInput`) to individual partitioned
resource variables (`tf.TPUPartitionedInput` operands). As resource op
decomposition/lifting occurs with the unpartitioned resource variables,
transforming the IR in such a manner will allow for subsequent passes to operate
on individual resource variable handles per core/device.

For example, the following:

```mlir
func @cluster(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) {
  %partitioned_variable = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<!tf_type.resource<tensor<i32>>>, tensor<!tf_type.resource<tensor<i32>>>) -> tensor<!tf_type.resource<tensor<i32>>>
  %read = "tf.ReadVariableOp"(%partitioned_variable) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
  %computation = "tf_device.cluster_func"(%read) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
  "tf.AssignVariableOp"(%partitioned_variable, %computation) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
  return
}

func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0: tensor<i32>
}
```

will be transformed into:

```mlir
func @cluster(%arg0: tensor<!tf_type.resource<tensor<i32>>>, %arg1: tensor<!tf_type.resource<tensor<i32>>>) {
  %read0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
  %read1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf_type.resource<tensor<i32>>>) -> tensor<i32>
  %partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  %computation = "tf_device.cluster_func"(%partitioned_input) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor<i32>) -> tensor<i32>
  %partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor<i32>) -> (tensor<i32>, tensor<i32>)
  "tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
  "tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor<!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
  return
}

func @computation(%arg0: tensor<i32>) -> tensor<i32> {
  return %arg0: tensor<i32>
}
```
### `-tf-tpu-resource-read-for-write`: Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes with no reads {: .hide-from-toc }
This pass materializes `tf.ReadVariableOp` inputs to an outlined TPU computation
for resource variables that are only written, so that later in the pipeline
such resource variables can be fused with the generated `tf.TPUExecute` ops, which
only support resource variable reads or reads + writes. For all TPU computations,
resource variables are required to be initialized prior to execution. Currently,
write-only resource variable uses can be generated via packed tensor uses.

For example, the following:

```mlir
func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf_type.resource<tensor<i32>>>) {
  %0 = "tf_device.cluster_func"(%value) {func = @cluster} : (tensor<i32>) -> tensor<i32>
  "tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
  return
}

func @cluster(%arg0: tensor<i32>) -> tensor<i32> {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  return %identity : tensor<i32>
}
```

will be transformed into:

```mlir
func @write_only_resource(%value: tensor<i32>, %resource: tensor<*x!tf_type.resource<tensor<i32>>>) {
  %resource_read = "tf.ReadVariableOp"(%resource) : (tensor<*x!tf_type.resource<tensor<i32>>>) -> tensor<i32>
  %0 = "tf_device.cluster_func"(%value, %resource_read) {func = @cluster} : (tensor<i32>, tensor<i32>) -> tensor<i32>
  "tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf_type.resource<tensor<i32>>>, tensor<i32>) -> ()
  return
}

func @cluster(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
  %identity = "tf.Identity"(%arg0) : (tensor<i32>) -> tensor<i32>
  return %identity : tensor<i32>
}
```
### `-tf-tpu-rewrite`: Rewrites a `tf_device.cluster_func` on TPUs into TPU runtime operations. {: .hide-from-toc }
This pass rewrites a `tf_device.cluster_func` operation into a sequence of `tf._TPUCompileMlir`
and `tf.TPUExecute` operations. `tf._TPUCompileMlir` contains an MLIR module that is
functionally equivalent to the function referenced by `tf_device.cluster_func`.
This allows the module to be JIT-compiled and executed on TPU.
If the operation cannot be rewritten or device assignment fails,
the pass returns a failure.

Note that many parameters to the `tf_device.cluster_func` are omitted in this
and the following examples.
For example, a non-replicated `tf_device.cluster_func`:

```mlir
func @tf_tpu_rewrite(%arg0: tensor<i8>) {
  %0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
  return
}
```

will be rewritten as:

```mlir
func @tf_tpu_rewrite(%arg0: tensor<i8>) {
  %0:2 = "tf_device.launch"() ( {
    %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
    tf_device.return %compilation_status, %program : tensor<!tf_type.string>, tensor<3x!tf_type.string>
  }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
  "tf_device.launch"() ( {
    "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf_type.string>) -> ()
    tf_device.return
  }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
  %1 = "tf_device.launch"() ( {
    %2 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<i8>, tensor<3x!tf_type.string>) -> tensor<i8>
    tf_device.return %2 : tensor<i8>
  }) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : () -> tensor<i8>
  return
}
```

A replicated `tf_device.cluster_func`:

```mlir
func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
  %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i8>) {n = 2 : i32} {
    %1 = "tf_device.cluster_func"(%ri) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @func} : (tensor<i8>) -> tensor<i8>
    tf_device.return %1 : tensor<i8>
  }
  return
}
```

will be rewritten as:

```mlir
func @tf_tpu_rewrite(%arg0: tensor<i8>, %arg1: tensor<i8>) {
  %0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor<i8>) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} {
    %1:2 = "tf_device.launch"() ( {
      %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
      tf_device.return %compilation_status, %program : tensor<!tf_type.string>, tensor<3x!tf_type.string>
    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>)
    "tf_device.launch"() ( {
      "tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf_type.string>) -> ()
      tf_device.return
    }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> ()
    %2 = "tf_device.launch"() ( {
      %3 = "tf.TPUExecute"(%arg2, %1#1) : (tensor<i8>, tensor<3x!tf_type.string>) -> tensor<i8>
      tf_device.return %3 : tensor<i8>
    }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i8>
    tf_device.return %2 : tensor<i8>
  }
  return
}
```

A non-replicated `tf_device.cluster_func` with model parallelism:

```mlir
func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
  %0 = "tf_device.cluster_func"(%arg0) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @func, num_cores_per_replica = 2, input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32>
  return %0 : tensor<8xi32>
}
```

will be rewritten as:

```mlir
func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> {
  %0:3 = "tf_device.launch"() ( {
    %compilation_status, %program:2 = "tf._TPUCompileMlir"() {mlir_module = "<serialized func>"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>, tensor<3x!tf_type.string>)
    tf_device.return %compilation_status, %program#0, %program#1 : tensor<!tf_type.string>, tensor<3x!tf_type.string>, tensor<3x!tf_type.string>
  }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf_type.string>, tensor<3x!tf_type.string>, tensor<3x!tf_type.string>)
  "tf_device.launch"() ( {
    "tf.TPUCompileSucceededAssert"(%0#0) : (tensor<!tf_type.string>) -> ()
    tf_device.return
  }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> ()
  %1 = "tf_device.parallel_execute"() ( {
    %2 = "tf_device.launch"() ( {
      %3 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<8xi32>, tensor<3x!tf_type.string>) -> tensor<8xi32>
      tf_device.return %3 : tensor<8xi32>
    }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<8xi32>
    tf_device.return %2 : tensor<8xi32>
  },  {
    "tf_device.launch"() ( {
      "tf.TPUExecute"(%0#2) : (tensor<3x!tf_type.string>) -> ()
      tf_device.return
    }) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> ()
    tf_device.return
  }) : () -> tensor<8xi32>
  return %1 : tensor<8xi32>
}
```

#### Options {: .hide-from-toc }
```
-tpu-compile-metadata-debug : Whether to serialize TPUCompileMetadataProto metadata in 'tf._TPUCompileMlir' op as a proto debug string
```
### `-tf-tpu-sharding-identification`: Identifies and handles inputs/outputs of TPU computation that is sharded across logical cores. {: .hide-from-toc }
Bubbles up sharding configuration from `cluster_func` regions into
the attributes of `cluster_func`. This is done by parsing the
`XlaSharding` / `TPUPartitionedOutput` / `TPUPartitionedInput` ops inside
`cluster_func`.

For example, given the following `cluster_func` wrapping `func`:

```mlir
  func @test(%arg0: tensor<*xi32>) {
    "tf_device.cluster_func"(%arg0) {
        func = @func,
        step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32>
    return
  }

  func @func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
    %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "\01\02\03",
                                  sharding = "\01\02\03"} : (tensor<*xi32>) -> tensor<*xi32>
    %1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
    return %1 : tensor<*xi32>
  }
```

After the pass, `cluster_func` receives the following `*_sharding_configuration`
attributes, and `func` receives `mhlo.sharding` attributes on its arguments and results:

```mlir
  func @test(%arg0: tensor<*xi32>) {
    %0 = "tf_device.cluster_func"(%arg0) {
        func = @func,
        input_sharding_configuration = ["\01\02\03"],
        output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"],
        step_marker_location = ""} : (tensor<*xi32>) -> tensor<*xi32>
    return
  }
  func @func(%arg0: tensor<*xi32> {mhlo.sharding = "\01\02\03"}) ->
            (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) {
    %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "\01\02\03", sharding = "\01\02\03"} : (tensor<*xi32>) -> tensor<*xi32>
    %1 = "tf.A"(%0) : (tensor<*xi32>) -> tensor<*xi32>
    return %1 : tensor<*xi32>
  }
```
### `-tf-tpu-space-to-depth-pass`: Applies automatic space to depth transform for the first or frontier convolutions that consume host inputs on TPU. {: .hide-from-toc }
The automatic space-to-depth transform adds a space-to-depth op after the host input
and applies the corresponding transform to the first convolution and its backprop filter on TPU.

For example, given the original program:

```mlir
module {
  func @while_body {
    %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}: -> tensor<2x224x224x3xf32>
    %device_launch = "tf_device.cluster_func"(%input,...) {func = @_func,...}
    return ...
  }
  func @_func(%input: tensor<2x224x224x3xf32>, %filter: tensor<7x7x3x64xf32>) {
    %6 = "tf.Conv2D"(%input, %filter)  {strides = [1, 2, 2, 1]}: (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) -> tensor<2x112x112x64xf32>
  }
}
```

The program will be transformed into:

```mlir
module {
  func @while_body {
    %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}: -> tensor<2x224x224x3xf32>
    %space_to_depth = "tf.SpaceToDepth"(%input) {block_size = 2, ...}: (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
    %device_launch = "tf_device.cluster_func"(%space_to_depth,...) {func = @_func,...}
    return ...
  }
  func @_func(%input: tensor<2x112x112x12xf32>, %filter: tensor<7x7x3x64xf32>) {
    %filter_transform = "tf.Pad/tf.Transpose/tf.Reshape"(%filter): (tensor<7x7x3x64xf32>) -> tensor<4x4x12x64xf32>
    %conv = "tf.Conv2D"(%input, %filter_transform) {strides = [1, 1, 1, 1]}: (tensor<2x112x112x12xf32>, tensor<4x4x12x64xf32>) -> tensor<2x112x112x64xf32>
  }
}
```

This way, the first convolution, whose input has a feature (channel) dimension of 3, is
transformed to operate on a feature dimension of 12, which performs better on TPU.
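The shape arithmetic in the example can be checked with a short Python sketch. `space_to_depth` follows the NHWC element ordering of `tf.nn.space_to_depth` (the position inside each block becomes the high-order part of the output channel index). The filter and stride helpers are hypothetical and encode only what the example implies: the 7x7 kernel is padded up to a multiple of the block size before being folded, and the stride is divided by the block size.

```python
from typing import List


def space_to_depth(x: List, block: int) -> List:
    """NHWC nested lists -> [N, H/block, W/block, C*block*block]."""
    out = []
    for image in x:
        height, width = len(image), len(image[0])
        if height % block or width % block:
            raise ValueError("H and W must be divisible by block")
        new_image = []
        for i in range(0, height, block):
            row = []
            for j in range(0, width, block):
                depth = []
                # Offset within the block is the high-order part of the channel.
                for di in range(block):
                    for dj in range(block):
                        depth.extend(image[i + di][j + dj])
                row.append(depth)
            new_image.append(row)
        out.append(new_image)
    return out


def space_to_depth_shape(shape: List[int], block: int) -> List[int]:
    n, h, w, c = shape
    return [n, h // block, w // block, c * block * block]


def transformed_filter_shape(shape: List[int], block: int) -> List[int]:
    # Hypothetical: pad kh, kw up to a multiple of block, then fold into depth.
    kh, kw, cin, cout = shape
    return [-(-kh // block), -(-kw // block), cin * block * block, cout]


def transformed_stride(stride: int, block: int) -> int:
    # Hypothetical: the stride is divided by the block size (2 -> 1 here).
    return stride // block


print(space_to_depth_shape([2, 224, 224, 3], 2))    # [2, 112, 112, 12]
print(transformed_filter_shape([7, 7, 3, 64], 2))   # [4, 4, 12, 64]
print(transformed_stride(2, 2))                     # 1
```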
### `-tf-tpu-update-embedding-enqueue-op-inputs`: Updates inputs to TPU embedding enqueue ops depending on whether graph is in training mode or in evaluation mode. {: .hide-from-toc }
Updates inputs to TPU embedding enqueue ops depending on whether graph
is in training mode or in evaluation mode.
### `-tf-tpu-validate-inputs`: Validates inputs to the TPU TF/XLA bridge {: .hide-from-toc }
This pass checks that the IR is a valid input to the TPU TF/XLA bridge.
It checks relations between multiple ops; properties of single ops are
checked by each op's `verify` method.
### `-tf-tpu-variable-runtime-reformatting`: Adds device variable formatting op to allow compilation-guided variable formatting. {: .hide-from-toc }
A pass that takes advantage of a loop to add ops that allow the execution to
avoid repeatedly formatting variables back and forth. The desired formatting
is determined by TPU program compilation, so this pass does not specify how
to reformat the variables; it only inserts general TPUReshardVariablesOps in
the proper places, and these TPUReshardVariablesOps interpret the compilation result.

The core idea of this optimization is to keep track of the formatting state
of variables, and when the next desired state does not change, it can avoid
reformatting. We associate a set of variables on a device with a formatting
state, and TPUReshardVariablesOp compares the current state with a desired
state (which can be the compilation result). If they mismatch,
TPUReshardVariablesOp reformats the variables to the desired state; if they
match, TPUReshardVariablesOp is a no-op.
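The state-tracking idea can be illustrated with a toy model in which a formatting state is just a string. This is a sketch, not TensorFlow's implementation, and the class name is hypothetical: resharding reformats only on a mismatch, so a loop that requests the same compiled format on every iteration pays for a single reformat, plus one more to restore the default format after the loop.

```python
class ReshardingVariables:
    """Hypothetical model of TPUReshardVariablesOp as compare-and-reformat."""

    def __init__(self, initial_state: str):
        self.state = initial_state
        self.reformats = 0

    def reshard(self, desired_state: str) -> None:
        if self.state != desired_state:   # mismatch: reformat the variables
            self.state = desired_state
            self.reformats += 1
        # match: no-op


vars_on_device = ReshardingVariables("default")
for _ in range(100):                      # loop body: reshard before each execute
    vars_on_device.reshard("compiled")
vars_on_device.reshard("default")         # after the loop: restore default format
print(vars_on_device.reformats)  # 2
```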

A major use of this pass is weight-update sharding in data parallelism, so the
pass requires a tf_device.replicate in the loop.

For example, suppose we have a training loop (for simplicity we write the
loop body inline):

```mlir
  %var0 = ...
  %var1 = ...
  tf.while (..., %var0, %var1) {
    tf_device.replicate ([%var0, %var1] as %rvar) {
      %compile:2 = "tf._TPUCompileMlir"()
      tf.TPUExecuteAndUpdateVariablesOp(%rvar, %compile#1)
    }
  }
```

This pass will transform it into

```mlir
  %var0 = ...
  %var1 = ...
  %state_var0 = ...
  %state_var1 = ...
  tf.while (..., %var0, %var1, %state_var0, %state_var1) {
    tf_device.replicate ([%var0, %var1] as %rvar,
                         [%state_var0, %state_var1] as %rstate) {
      %compile:2 = "tf._TPUCompileMlir"()
      tf.TPUReshardVariablesOp(%rvar, %compile#1, %rstate)
      tf.TPUExecuteAndUpdateVariablesOp(%rvar, %compile#1)
    }
  }
  %default_format = tf.constant()
  tf_device.replicate ([%var0, %var1] as %rvar,
                       [%state_var0, %state_var1] as %rstate) {
    tf.TPUReshardVariablesOp(%rvar, %default_format, %rstate)
  }
```
### `-tf-unroll-batch-matmul`: Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops. {: .hide-from-toc }
### `-tf-verify-for-export`: Verify module is suitable for export back to TF Graph {: .hide-from-toc }
Verifies that every function in the module consists of a single tf_executor.graph and that
each tf_executor.island in tf_executor.graph only has a single op.
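The two structural conditions can be expressed as a simple check over a toy representation, where each function body is a list of top-level ops and each `tf_executor.graph` lists its islands as lists of wrapped op names. The representation and the function name are hypothetical; the real pass walks MLIR operations, and terminators such as `tf_executor.fetch` are ignored here.

```python
from typing import Dict, List


def verify_for_export(functions: Dict[str, List[dict]]) -> List[str]:
    """Return a list of error messages; an empty list means the module passes."""
    errors = []
    for fname, body in functions.items():
        # Each function must consist of exactly one tf_executor.graph.
        if len(body) != 1 or body[0]["op"] != "tf_executor.graph":
            errors.append(f"{fname}: expected a single tf_executor.graph")
            continue
        # Each island in that graph must wrap exactly one op.
        for index, island in enumerate(body[0]["islands"]):
            if len(island) != 1:
                errors.append(f"{fname}: island {index} wraps {len(island)} ops")
    return errors


module = {"main": [{"op": "tf_executor.graph", "islands": [["tf.Const"], ["tf.AddV2"]]}]}
print(verify_for_export(module))  # []
```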
### `-tfe-legalize-tfg`: Legalize from TFG to the TFE dialect {: .hide-from-toc }