/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h"

#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/passes.h"

// -------------------------------------------------------------------------- //
// Custom passes that are missing upstream.
// -------------------------------------------------------------------------- //

namespace tensorflow {
namespace {

using mlir::OpPassManager;
using mlir::func::FuncOp;

// Adds a Tensorflow producer version to the module to enable shape inference.
struct AddTensorflowProducerVersion
    : public mlir::PassWrapper<AddTensorflowProducerVersion,
                               mlir::OperationPass<mlir::ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddTensorflowProducerVersion)

  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();

    // The Tensorflow producer version does not really impact anything during
    // shape inference. Set it to `0` (any value will do) to satisfy the
    // attribute checks.
    mlir::Builder builder(module);
    auto version =
        builder.getNamedAttr("producer", builder.getI32IntegerAttr(0));
    module->setAttr("tf.versions", builder.getDictionaryAttr({version}));
  }
};

// Adds Linalg passes to perform fusion, tiling, peeling and vectorization.
void AddLinalgTransformations(OpPassManager& pm,
                              const TfJitRtPipelineOptions& options) {
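  // Fuse Linalg operations while still operating on tensors.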
  pm.addNestedPass<FuncOp>(CreateFusionPass());

  if (!options.vectorize) return;

  pm.addNestedPass<FuncOp>(CreateDetensorizeLinalgPass());

  // Unfortunately, there is currently no way to provide default values for a
  // ListOption, so we have to provide them here. Once the feature request at
  // https://github.com/llvm/llvm-project/issues/52667 is accepted and
  // implemented, this code can be removed.
  mlir::SmallVector<int64_t, 2> reduction_2d_tile_sizes = {4, 4};
  if (options.reduction_2d_tile_sizes.hasValue()) {
    reduction_2d_tile_sizes.assign(options.reduction_2d_tile_sizes.begin(),
                                   options.reduction_2d_tile_sizes.end());
  }
  pm.addNestedPass<FuncOp>(CreateTileReductionPass(
      options.vector_size, options.reduction_1d_tile_size,
      reduction_2d_tile_sizes));

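  // Tile transpose, elementwise (cwise) and fill operations so that they can
  // be vectorized below.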
  if (options.vectorize && options.codegen_transpose)
    pm.addNestedPass<FuncOp>(CreateTileTransposePass());
  pm.addNestedPass<FuncOp>(CreateTileCWisePass(options.vector_size));
  if (options.peel) {
    pm.addNestedPass<FuncOp>(CreatePeelTiledLoopsPass());
  }
  pm.addNestedPass<FuncOp>(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());
  if (options.fuse_fill) {
    pm.addNestedPass<FuncOp>(CreateFuseFillIntoTiledReductionPass());
  }
  pm.addNestedPass<FuncOp>(CreateTileFillPass(options.vector_size));
  pm.addNestedPass<FuncOp>(CreateVectorizeTiledOpsPass());
}

void AddBufferizationPasses(OpPassManager& pm, bool one_shot_bufferize) {
  // Rewrite init_tensor ops to alloc_tensor ops.
  pm.addNestedPass<FuncOp>(mlir::createLinalgInitTensorToAllocTensorPass());
  // Run One-Shot Bufferize.
  if (one_shot_bufferize) {
    pm.addPass(mlir::hlo::createOneShotBufferizePass());
    return;
  }
  // Otherwise bufferize all the compute operations (hlo + linalg) and function
  // signatures.
  pm.addPass(mlir::createComputeOpAndFuncBufferizePass());
  pm.addNestedPass<FuncOp>(mlir::gml_st::CreateTiledLoopBufferizePass());
  // Always run CSE and the canonicalizer (which does dead code removal) before
  // the final bufferization.
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createFinalBufferizePass(/*alignment=*/64));
}

}  // namespace

// -------------------------------------------------------------------------- //
// Assemble a TF JitRt pipeline to lower from Tensorflow dialects to Linalg on
// buffers via progressive lowering to MHLO and Linalg.
// -------------------------------------------------------------------------- //
void CreateTfJitRtPipeline(OpPassManager& pm,
                           const TfJitRtPipelineOptions& options) {
  // Break Tensorflow fused operations into primitive operations before
  // lowering to HLO.
  pm.addNestedPass<FuncOp>(CreateFissionPass());

  // Run shape inference to propagate potentially specialized input shapes.
  pm.addPass(std::make_unique<AddTensorflowProducerVersion>());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Transform TF operations to HLO.
  pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
  pm.addNestedPass<FuncOp>(mlir::mhlo::createLegalizeTFPass());

  if (options.legalize_i1_tensors) {
    // Convert 'i1' tensors into 'i8' tensors.
    pm.addPass(CreateJitRtLegalizeI1TypesPass());
  }

  // Remove redundant shape operations left after legalizing to HLO.
  pm.addPass(mlir::createCSEPass());

  // Resolve all shape constraints early (e.g. broadcast constraints that can
  // be proved statically and replaced with a const witness) to make it easier
  // to move broadcast operations around.
  pm.addNestedPass<FuncOp>(
      CreateSymbolicShapeOptimizationPass(/*constraints_only=*/true));

  // Analyze shapes and try to simplify the IR as early as possible.
  pm.addNestedPass<FuncOp>(mlir::createSymbolicShapeOptimizationPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Move up broadcasting operations to allow for more fusion opportunities.
  // Add the broadcast propagation pass first, because it can help to avoid
  // exponential complexity from the EarlyBroadcastInDimOp pattern, which is
  // used in the merge-assuming-ops pass further down.
  pm.addNestedPass<FuncOp>(mlir::mhlo::createMergeAssumingOpsPass());
  pm.addNestedPass<FuncOp>(mlir::mhlo::createBroadcastPropagationPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // After all shape constraints have been removed and broadcasts moved to the
  // top, try to resolve broadcasts that can be converted to linalg generic
  // operations.
  pm.addNestedPass<FuncOp>(CreateSymbolicShapeOptimizationPass());

  // Group reduction and parallel dimensions of reduction operations and
  // realize them through equivalent 1D or 2D reductions, if possible.
  pm.addNestedPass<FuncOp>(mlir::mhlo::createGroupReductionDimensionsPass());

  // Also, try to simplify reshape operations.
  pm.addNestedPass<FuncOp>(mlir::createSymbolicShapeOptimizationPass());

  // Transform HLO operations to Linalg and Standard.
  pm.addNestedPass<FuncOp>(mlir::mhlo::createLegalizeControlFlowPass());
  pm.addNestedPass<FuncOp>(mlir::mhlo::createLegalizeSortPass());
  pm.addNestedPass<FuncOp>(mlir::mhlo::createLegalizeHloToLinalgPass());
  pm.addPass(mlir::mhlo::createLegalizeToArithmeticPass());
  pm.addNestedPass<FuncOp>(
      mlir::mhlo::createLegalizeHloShapeOpsToStandardPass());

  // Now that all compute operations are converted to standard (as a side
  // effect of bufferizing to the memref dialect), we can remove the remaining
  // references to unsigned types.
  pm.addPass(mlir::mhlo::createConvertToSignlessPass());

  // Lower the shape dialect to standard to enable linalg canonicalizations
  // (e.g. use linalg inputs instead of outputs for memref.dim operations).
  pm.addNestedPass<FuncOp>(mlir::createShapeSimplification());
  pm.addNestedPass<FuncOp>(mlir::createShapeToShapeLowering());
  pm.addPass(mlir::createConvertShapeToStandardPass());
  pm.addNestedPass<FuncOp>(mlir::createConvertShapeConstraintsPass());

  // Fuse Linalg-on-tensors operations.
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::memref::createResolveShapedTypeResultDimsPass());
  // Lower index cast on tensors to tensor.generate.
  pm.addNestedPass<FuncOp>(mlir::createLowerIndexCastPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Convert complex types.
  pm.addPass(mlir::createConvertComplexToStandardPass());

  // Add linalg passes to perform fusion, tiling, peeling and vectorization.
  AddLinalgTransformations(pm, options);

  // Inline everything; bufferization doesn't model ownership across calls.
  pm.addPass(mlir::createInlinerPass());

  // Always run the canonicalizer (which does dead code removal) before
  // bufferizing anything.
  pm.addPass(mlir::createCanonicalizerPass());

  AddBufferizationPasses(pm, options.one_shot_bufferize || options.vectorize);

  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Deallocate all temporary buffers.
  pm.addNestedPass<FuncOp>(mlir::bufferization::createBufferDeallocationPass());

  // Do trivial buffer forwarding across linalg.generic operations.
  pm.addNestedPass<FuncOp>(CreateLinalgTrivialBufferForwardingPass());

  // Remove trivial copy operations.
  pm.addNestedPass<FuncOp>(CreateLinalgTrivialCopyRemovalPass());

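  // Lower the gml_st loops produced by tiling to SCF loops.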
  if (options.vectorize)
    pm.addNestedPass<FuncOp>(mlir::gml_st::createGmlStToScfPass());

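  // Lower remaining bufferization dialect operations (e.g. bufferization.clone)
  // to the memref dialect.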
  pm.addPass(mlir::createBufferizationToMemRefPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  if (options.vectorize && options.codegen_transpose)
    pm.addNestedPass<FuncOp>(CreateLowerVectorTransposePass());

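  // Lower vector transfer operations to SCF; with `unroll` set, the generated
  // loops are fully unrolled.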
  mlir::VectorTransferToSCFOptions vec_to_scf_options;
  vec_to_scf_options.unroll = true;
  pm.addNestedPass<FuncOp>(
      mlir::createConvertVectorToSCFPass(vec_to_scf_options));
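  // Lower vector.multi_reduction operations into simpler vector reductions.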
  pm.addNestedPass<FuncOp>(createRewriteVectorMultiReductionPass());

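  // Replace supported math functions with polynomial approximations ("all"
  // enables every available approximation).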
  pm.addNestedPass<FuncOp>(CreateMathApproximationPass({"all"}));
}

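// Creates the default TF JitRt pipeline; vectorization is controlled by the
// runtime TF JitRt flags.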
void CreateDefaultTfJitRtPipeline(OpPassManager& pm) {
  TfJitRtPipelineOptions options;
  options.vectorize = tensorflow::GetJitRtFlags().vectorize;
  CreateTfJitRtPipeline(pm, options);
}

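// Lightweight pipeline that re-runs shape inference and canonicalization when
// specializing a program to concrete operands.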
void CreateJitRtSpecializationPipeline(mlir::OpPassManager& pm) {
  pm.addPass(std::make_unique<AddTensorflowProducerVersion>());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  pm.addPass(mlir::createCanonicalizerPass());
}

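// Registers the pipeline under the `tf-jitrt-pipeline` name so that it can be
// invoked from textual pass pipeline specifications (e.g. via an *-opt tool's
// --pass-pipeline flag).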
static mlir::PassPipelineRegistration<TfJitRtPipelineOptions> tf_jitrt_pipeline(
    "tf-jitrt-pipeline",
    "Convert Tensorflow dialect to TFRT's JitRt compatible dialects",
    CreateTfJitRtPipeline);

}  // namespace tensorflow