/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_

#include <memory>

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h"

namespace mlir {
class PassManager;
}

namespace tensorflow {

namespace tfrt_compiler {

// Create a pass to insert kernels that copy fallback tensors when they are
// passed to multiple threads, to avoid atomic contention on their refcounts.
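//
// A minimal usage sketch (hypothetical driver code, not part of this API;
// `module` is assumed to be an mlir::ModuleOp set up by the caller):
//
//   mlir::PassManager pm(module.getContext());
//   pm.addNestedPass<mlir::func::FuncOp>(
//       tensorflow::tfrt_compiler::CreateInsertFallbackTensorCopyPass());
//   if (mlir::failed(pm.run(module))) { /* handle failure */ }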
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateInsertFallbackTensorCopyPass();

// Create a pass to move tf.Assert ops, and tf.If ops that contain only
// tf.Assert ops, to the end of the function, to avoid unnecessary control
// dependencies on other ops.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateReorderTfAssertPass();

// Create a pass to optimize the side effects of control flow ops. E.g., if
// both branches of a tf.If op contain only non-side-effecting ops, its
// `is_stateless` attribute will be set to true.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateOptimizeTfControlFlowSideEffectPass();

// Create a pass to remove tf.If ops' operands that are produced by tf.Const
// ops.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateRemoveTfIfConstArgsPass();

// Create a pass to merge non-side-effecting tf.If ops that have the same
// operands.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> CreateMergeTfIfOpsPass();

// Create a pass to deduplicate the functions invoked by tf.BatchFunction ops
// with the same shared_name.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDeduplicateFunctionsInovkedByBatchFunctionPass();

// Create a pass to fuse the TPU ops for TFRT.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateFuseTpuCompileAndExecutePass();

// Create a pass to optimize TF dialect for TFRT workflow.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateOptimizeTfForTfrtPass();

}  // namespace tfrt_compiler

class CoreRTConverter;

// Create a pass that rewrites tf_saved_model dialect's ops according to TFRT's
// requirements.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateLowerTFSavedModelPass(bool hoist_invariant_ops);

// Create a pass that converts ref variables to resource variables in a limited
// number of cases.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateConvertReferenceVariableToResourceVariablePass();

// Runs the *ToCoreRTConversionPass logic as a free function. This is useful
// for reusing the pass logic in a custom pass with additional conversions.
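//
// A minimal sketch of calling it from a custom pass's runOnOperation()
// (hypothetical; the CoreRTConverter constructor arguments are elided here,
// see the converter's header for the actual signature):
//
//   mlir::ConversionTarget target(*context);
//   mlir::RewritePatternSet patterns(context);
//   CoreRTConverter corert_converter(/*...*/);
//   if (mlir::failed(TFSavedModelToCoreRTConversionPassRun(
//           context, func, &target, &patterns, &corert_converter)))
//     signalPassFailure();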
mlir::LogicalResult TFSavedModelToCoreRTConversionPassRun(
    mlir::MLIRContext* context, mlir::func::FuncOp func,
    mlir::ConversionTarget* target, mlir::RewritePatternSet* patterns,
    CoreRTConverter* corert_converter);

// Create an operation pass that converts each tfrt_dist.remote_execute_func op
// into a combination of tfrt_dist.register_tfrt_function op and
// tfrt_dist.remote_execute op.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDistRemoteRunEncapsulatePass();

// Create an operation pass that removes the device attribute from every
// corert.executeop.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateRemoveDeviceAttributePass();

// Create an operation pass that inserts corert.transfer ops to make sure every
// argument of an op is on the same device as the op itself.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateCrossDeviceTransferPass();

struct TfrtPipelineOptions
    : public mlir::PassPipelineOptions<TfrtPipelineOptions> {
  Option<std::string> default_device{
      *this, "default-device", llvm::cl::desc("default device assignment"),
      llvm::cl::init("/job:localhost/replica:0/task:0/device:CPU:0")};
  Option<bool> enable_optimizer{
      *this, "enable-optimizer",
      llvm::cl::desc("run optimization passes on corert dialect"),
      llvm::cl::init(false)};
  Option<bool> decompose_resource_ops{
      *this, "decompose-resource-ops",
      llvm::cl::desc("decompose composite resource ops into ReadVariableOp and "
                     "non-resource ops. This is currently used in TFRT "
                     "savedmodel pipeline."),
      llvm::cl::init(false)};
  Option<std::string> force_data_format{
      *this, "force-data-format",
      llvm::cl::desc("force data format for all layout sensitive operations")};
  // TODO(tfrt-devs): Consider having the compiler figure out whether to fold
  // transposes instead of exposing this specific option.
  Option<bool> skip_fold_transpose_in_ops{
      *this, "skip-fold-transpose-in-ops",
      llvm::cl::desc("Skip folding transpose operands in ops that can support "
                     "different layouts.")};
  Option<bool> target_tpurt{*this, "target-tpurt",
                            llvm::cl::desc("target TPURT dialect if true"),
                            llvm::cl::init(false)};
  Option<bool> tpu_use_core_selector{
      *this, "tpu-use-core-selector",
      llvm::cl::desc("If true, use ServingCoreSelector to pick TPU core. "
                     "Otherwise, use the assigned core. Currently we use "
                     "core selector for Servo serving use cases."),
      llvm::cl::init(true)};
  Option<bool> tpu_use_bundled_transfer{
      *this, "tpu-use-bundled-transfer",
      llvm::cl::desc("If true, use BundledTransferToTpuOp to transfer "
                     "variables and input tensors to TPU."),
      llvm::cl::init(true)};
  Option<bool> tpu_lower_to_fallback{
      *this, "tpu-lower-to-fallback",
      llvm::cl::desc("If true, lower a TF op that's placed on the TPU device "
                     "to be executed by tfrt_fallback.execute."),
      llvm::cl::init(true)};
  Option<bool> tpu_fuse_ops{
      *this, "tpu-fuse-ops",
      llvm::cl::desc("If true, use the TPU fused compile_and_execute kernel"),
      llvm::cl::init(false)};
  // TODO(b/194081364): remove this option once we unify servo TPU serving
  // result transfer behavior.
  Option<bool> tpu_transfer_result_to_host{
      *this, "tpu-transfer-result-to-host",
      llvm::cl::desc("If true, transfer the result of tpurt.execute from TPU "
                     "to host."),
      llvm::cl::init(true)};
  Option<bool> use_tpu_host_allocator_for_inputs{
      *this, "use-tpu-host-allocator-for-inputs",
      llvm::cl::desc("If true, fallback executeops that produce inputs to the "
                     "TPU program will use the TPU host allocator."),
      llvm::cl::init(false)};
  Option<bool> enable_native_ops{
      *this, "enable-native-ops",
      llvm::cl::desc(
          "If true, native ops will be used on an opt-in basis instead of "
          "fallback ops. If false, no native ops are used."),
      llvm::cl::init(true)};
  Option<bool> func_use_fallback_tensor{
      *this, "func-use-fallback-tensor",
      llvm::cl::desc(
          "If true, use TF tensor as input/output types in func (and other "
          "control flow) ops."),
      llvm::cl::init(false)};

  Option<bool> enable_while_parallel_iterations{
      *this, "enable-while-parallel-iterations",
      llvm::cl::desc("If true, tf.While op will be parallelized. This is "
                     "currently experimental."),
      llvm::cl::init(false)};

  Option<bool> hoist_invariant_ops{
      *this, "hoist-invariant-ops",
      llvm::cl::desc("If true, invariant ops in savedmodels will be hoisted "
                     "out to run during loading."),
      llvm::cl::init(false)};

  Option<uint64_t> cost_threshold{
      *this, "tfrt-cost-threshold",
      llvm::cl::desc(
          "The cost threshold to decide whether a sequence of operations is "
          "cheap, and then whether it can be executed inline."),
      llvm::cl::init(1)};

  Option<int64_t> upper_cost_threshold{
      *this, "tfrt-upper-cost-threshold",
      llvm::cl::desc(
          "The threshold to limit the merging of dependent sequences."),
      llvm::cl::init(-1)};

  Option<bool> merge_inter_dependent_streams{
      *this, "tfrt-merge-inter-dependent-streams",
      llvm::cl::desc("If true, streams with inter-data dependencies will be "
                     "preferred to be merged for inline execution."),
      llvm::cl::init(false)};

  // A set of flags to control auto-fusion: automatic clustering of TensorFlow
  // operations and compiling the outlined regions using the MLIR-based
  // compilation stack.
  //
  // WARNING: These flags are experimental and are intended for manual testing
  // of different auto-fusion strategies. They will be removed in the future.

  ListOption<std::string> auto_fusion_oplist{
      *this, "auto-fusion-oplist",
      llvm::cl::desc("A list of TensorFlow operations to cluster together for "
                     "JIT compilation. Alternatively, use 'tier1', ..., 'all' "
                     "to allow clustering for all operations included in the "
                     "given clustering tier.")};

  Option<int> auto_fusion_min_cluster_size{
      *this, "auto-fusion-min-cluster-size",
      llvm::cl::desc("Minimum size of the cluster that should be outlined for "
                     "compilation."),
      llvm::cl::init(2)};
};
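
// A sketch of setting these options when invoking the pipeline from the
// command line (the tool and registered pipeline names below are assumptions
// for illustration; check the PassPipelineRegistration in the implementation
// for the actual names):
//
//   tf-tfrt-opt input.mlir \
//     --tf-executor-to-tfrt-pipeline="enable-optimizer=true tpu-fuse-ops=true"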

// Create a pass that converts MLIR TF dialect to MLIR TFRT dialect.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateTfToTfrtConversionPass(const TfrtPipelineOptions& options);

// Creates a pipeline of passes that lowers MLIR TF Executor dialect to TF
// dialect for CoreRT purposes.
void CreateTFExecutorToTFPipeline(mlir::OpPassManager& pm,
                                  const TfrtPipelineOptions& options);

// Creates a pipeline of passes that lowers MLIR TF dialect from tf.function to
// TFRT dialect. SavedModel-related conversions are not included.
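//
// A minimal usage sketch (hypothetical driver code; `module` is assumed to be
// an mlir::ModuleOp produced by an earlier import step):
//
//   TfrtPipelineOptions options;
//   options.default_device = "/job:localhost/replica:0/task:0/device:CPU:0";
//   mlir::PassManager pm(module.getContext());
//   CreateTfExecutorToTfrtPipeline(pm, options);
//   if (mlir::failed(pm.run(module))) { /* handle failure */ }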
void CreateTfExecutorToTfrtPipeline(mlir::PassManager& pm,
                                    const TfrtPipelineOptions& options);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_