/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h"

#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/passes.h"

// -------------------------------------------------------------------------- //
// Custom passes that are missing upstream.
// -------------------------------------------------------------------------- //

namespace tensorflow {
namespace {

using mlir::OpPassManager;
using mlir::func::FuncOp;

// Adds a Tensorflow producer version to the module to enable shape inference.
struct AddTensorflowProducerVersion
    : public mlir::PassWrapper<AddTensorflowProducerVersion,
                               mlir::OperationPass<mlir::ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddTensorflowProducerVersion)

  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();

    // The Tensorflow producer version does not affect shape inference. Set it
    // to `0` (any value would work) to satisfy the attribute checks.
    mlir::Builder builder(module);
    auto version =
        builder.getNamedAttr("producer", builder.getI32IntegerAttr(0));
    module->setAttr("tf.versions", builder.getDictionaryAttr({version}));
  }
};
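
// After this pass the module carries the version attribute, in MLIR textual
// form (illustrative):
//
//   module attributes {tf.versions = {producer = 0 : i32}} { ... }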

// Adds Linalg passes to perform fusion, tiling, peeling and vectorization.
void AddLinalgTransformations(OpPassManager& pm,
                              const TfJitRtPipelineOptions& options) {
  pm.addNestedPass<FuncOp>(CreateFusionPass());

  if (!options.vectorize) return;

  pm.addNestedPass<FuncOp>(CreateDetensorizeLinalgPass());

  // Unfortunately, there is currently no way to provide default values for a
  // ListOption, so we have to provide them here. Once the feature request
  // https://github.com/llvm/llvm-project/issues/52667 is accepted and
  // implemented, these lines can be removed.
  mlir::SmallVector<int64_t, 2> reduction_2d_tile_sizes = {4, 4};
  if (options.reduction_2d_tile_sizes.hasValue()) {
    reduction_2d_tile_sizes.assign(options.reduction_2d_tile_sizes.begin(),
                                   options.reduction_2d_tile_sizes.end());
  }
  pm.addNestedPass<FuncOp>(CreateTileReductionPass(
      options.vector_size, options.reduction_1d_tile_size,
      reduction_2d_tile_sizes));

  if (options.vectorize && options.codegen_transpose)
    pm.addNestedPass<FuncOp>(CreateTileTransposePass());
  pm.addNestedPass<FuncOp>(CreateTileCWisePass(options.vector_size));
  if (options.peel) {
    pm.addNestedPass<FuncOp>(CreatePeelTiledLoopsPass());
  }
  pm.addNestedPass<FuncOp>(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());
  if (options.fuse_fill) {
    pm.addNestedPass<FuncOp>(CreateFuseFillIntoTiledReductionPass());
  }
  pm.addNestedPass<FuncOp>(CreateTileFillPass(options.vector_size));
  pm.addNestedPass<FuncOp>(CreateVectorizeTiledOpsPass());
}
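
// Schematically, after tiling and vectorization an elementwise operation is
// carried by a loop over `vector_size`-wide tiles whose body reads, computes
// and writes vectors (illustrative IR, not exact syntax):
//
//   %r = vector.transfer_read %arg[%i], %pad : memref<?xf32>, vector<8xf32>
//   %a = arith.addf %r, %r : vector<8xf32>
//   vector.transfer_write %a, %out[%i] : vector<8xf32>, memref<?xf32>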

void AddBufferizationPasses(OpPassManager& pm, bool one_shot_bufferize) {
  // Rewrite init_tensor ops to alloc_tensor ops.
  pm.addNestedPass<FuncOp>(mlir::createLinalgInitTensorToAllocTensorPass());
  // Run One-Shot Bufferize.
  if (one_shot_bufferize) {
    pm.addPass(mlir::hlo::createOneShotBufferizePass());
    return;
  }
  // Now bufferize all the compute operations (hlo + linalg) and the function
  // signatures.
  pm.addPass(mlir::createComputeOpAndFuncBufferizePass());
  pm.addNestedPass<FuncOp>(mlir::gml_st::CreateTiledLoopBufferizePass());
  // Always run CSE and the canonicalizer (which does dead code removal) before
  // bufferizing anything.
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createFinalBufferizePass(/*alignment=*/64));
}
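
// Bufferization replaces value-semantics tensors with explicit buffers, e.g.
// (illustrative):
//
//   %t = linalg.init_tensor [%d] : tensor<?xf32>
//
// eventually becomes, roughly:
//
//   %b = memref.alloc(%d) : memref<?xf32>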

}  // namespace

// -------------------------------------------------------------------------- //
// Assemble a TF JitRt pipeline to lower from Tensorflow dialects to Linalg on
// buffers via progressive lowering to MHLO and Linalg.
// -------------------------------------------------------------------------- //
void CreateTfJitRtPipeline(OpPassManager& pm,
                           const TfJitRtPipelineOptions& options) {
  // Break Tensorflow fused operations into primitive operations before
  // lowering to HLO.
  pm.addNestedPass<FuncOp>(CreateFissionPass());

  // Run shape inference to propagate potentially specialized input shapes.
  pm.addPass(std::make_unique<AddTensorflowProducerVersion>());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Transform TF operations to HLO.
  pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
  pm.addNestedPass<FuncOp>(mlir::mhlo::createLegalizeTFPass());

  if (options.legalize_i1_tensors) {
    // Convert 'i1' tensors into 'i8' tensors (e.g. tensor<4xi1> becomes
    // tensor<4xi8>).
    pm.addPass(CreateJitRtLegalizeI1TypesPass());
  }

  // Remove redundant shape operations left after legalizing to HLO.
  pm.addPass(mlir::createCSEPass());

  // Resolve all shape constraints early (e.g. broadcast constraints that can
  // be proved statically and replaced with a const witness) to allow more
  // efficient movement of broadcast operations.
  pm.addNestedPass<FuncOp>(
      CreateSymbolicShapeOptimizationPass(/*constraints_only=*/true));

  // Analyze shapes and try to simplify the IR as early as possible.
  pm.addNestedPass<FuncOp>(mlir::createSymbolicShapeOptimizationPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Move broadcasting operations up to allow for more fusion opportunities.
  // The broadcast propagation pass helps to avoid the exponential complexity
  // of the EarlyBroadcastInDimOp pattern used by the merge-assuming-ops pass.
  pm.addNestedPass<FuncOp>(mlir::mhlo::createMergeAssumingOpsPass());
  pm.addNestedPass<FuncOp>(mlir::mhlo::createBroadcastPropagationPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // After all shape constraints have been removed and broadcasts moved to the
  // top, try to resolve broadcasts that can be converted to linalg generic
  // operations.
  pm.addNestedPass<FuncOp>(CreateSymbolicShapeOptimizationPass());

  // Group reduction and parallel dimensions of reduction operations and
  // realize them through equivalent 1D or 2D reductions, if possible.
  pm.addNestedPass<FuncOp>(mlir::mhlo::createGroupReductionDimensionsPass());

  // Also, try to simplify reshape operations.
  pm.addNestedPass<FuncOp>(mlir::createSymbolicShapeOptimizationPass());

  // Transform HLO operations to Linalg and Standard.
  pm.addNestedPass<FuncOp>(mlir::mhlo::createLegalizeControlFlowPass());
  pm.addNestedPass<FuncOp>(mlir::mhlo::createLegalizeSortPass());
  pm.addNestedPass<FuncOp>(mlir::mhlo::createLegalizeHloToLinalgPass());
  pm.addPass(mlir::mhlo::createLegalizeToArithmeticPass());
  pm.addNestedPass<FuncOp>(
      mlir::mhlo::createLegalizeHloShapeOpsToStandardPass());

  // Now that all compute operations are converted to standard (as a side
  // effect of bufferizing to the memref dialect), we can remove the remaining
  // references to unsigned types.
  pm.addPass(mlir::mhlo::createConvertToSignlessPass());

  // Lower the shape dialect to standard to enable linalg canonicalizations
  // (e.g. use linalg inputs instead of outputs for memref.dim operations).
  pm.addNestedPass<FuncOp>(mlir::createShapeSimplification());
  pm.addNestedPass<FuncOp>(mlir::createShapeToShapeLowering());
  pm.addPass(mlir::createConvertShapeToStandardPass());
  pm.addNestedPass<FuncOp>(mlir::createConvertShapeConstraintsPass());

  // Fuse Linalg-on-tensors operations.
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::memref::createResolveShapedTypeResultDimsPass());
  // Lower index cast on tensors to tensor.generate.
  pm.addNestedPass<FuncOp>(mlir::createLowerIndexCastPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Convert complex types.
  pm.addPass(mlir::createConvertComplexToStandardPass());
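  // E.g. complex.abs is expanded into arith/math operations over the real and
  // imaginary components, roughly sqrt(re * re + im * im) (illustrative).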

  // Add linalg passes to perform fusion, tiling, peeling and vectorization.
  AddLinalgTransformations(pm, options);

  // Inline everything; bufferization doesn't model ownership across calls.
  pm.addPass(mlir::createInlinerPass());

  // Always run the canonicalizer (which does dead code removal) before
  // bufferizing anything.
  pm.addPass(mlir::createCanonicalizerPass());

  AddBufferizationPasses(pm, options.one_shot_bufferize || options.vectorize);

  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Deallocate all temporary buffers.
  pm.addNestedPass<FuncOp>(
      mlir::bufferization::createBufferDeallocationPass());

  // Do trivial buffer forwarding across linalg.generic operations.
  pm.addNestedPass<FuncOp>(CreateLinalgTrivialBufferForwardingPass());

  // Remove trivial copy operations.
  pm.addNestedPass<FuncOp>(CreateLinalgTrivialCopyRemovalPass());

  if (options.vectorize)
    pm.addNestedPass<FuncOp>(mlir::gml_st::createGmlStToScfPass());

  pm.addPass(mlir::createBufferizationToMemRefPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  if (options.vectorize && options.codegen_transpose)
    pm.addNestedPass<FuncOp>(CreateLowerVectorTransposePass());

  // Lower vector transfer operations to SCF, unrolling the generated loops.
  mlir::VectorTransferToSCFOptions vec_to_scf_options;
  vec_to_scf_options.unroll = true;
  pm.addNestedPass<FuncOp>(
      mlir::createConvertVectorToSCFPass(vec_to_scf_options));
  pm.addNestedPass<FuncOp>(createRewriteVectorMultiReductionPass());

  // Approximate math functions; "all" enables every available approximation.
  pm.addNestedPass<FuncOp>(CreateMathApproximationPass({"all"}));
}
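
// Usage sketch (hypothetical driver code, not part of this file), where
// `module` is a parsed mlir::ModuleOp:
//
//   mlir::PassManager pm(module.getContext());
//   TfJitRtPipelineOptions options;
//   options.vectorize = true;
//   CreateTfJitRtPipeline(pm, options);
//   if (mlir::failed(pm.run(module))) { /* handle the error */ }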

void CreateDefaultTfJitRtPipeline(OpPassManager& pm) {
  TfJitRtPipelineOptions options;
  options.vectorize = tensorflow::GetJitRtFlags().vectorize;
  CreateTfJitRtPipeline(pm, options);
}

void CreateJitRtSpecializationPipeline(mlir::OpPassManager& pm) {
  pm.addPass(std::make_unique<AddTensorflowProducerVersion>());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  pm.addPass(mlir::createCanonicalizerPass());
}

static mlir::PassPipelineRegistration<TfJitRtPipelineOptions> tf_jitrt_pipeline(
    "tf-jitrt-pipeline",
    "Convert Tensorflow dialect to TFRT's JitRt compatible dialects",
    CreateTfJitRtPipeline);
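
// With the registration above, the pipeline can be exercised from any
// *-opt style tool that links this library (option spelling assumed to match
// the TfJitRtPipelineOptions definitions):
//
//   some-opt input.mlir --tf-jitrt-pipeline="vectorize=true"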

}  // namespace tensorflow