xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/bufferize_pass.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for translating mixed IR to buffer form.
17 // Currently it supports MHLO and ops from upstream dialects such as arith, func, tensor, shape, linalg and vector.
18 
19 #include <cstdint>
20 #include <functional>
21 #include <memory>
22 #include <utility>
23 
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
27 #include "mlir-hlo/Dialect/gml_st/transforms/bufferizable_op_interface_impl.h"
28 #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
29 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
30 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
31 #include "mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h"
32 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
33 #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
34 #include "mlir-hlo/Transforms/PassDetail.h"
35 #include "mlir-hlo/Transforms/passes.h"
36 #include "mlir-hlo/Transforms/rewriters.h"
37 #include "mlir/Dialect/Affine/IR/AffineOps.h"
38 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
39 #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
40 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
41 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
42 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
43 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
44 #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
45 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
46 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
47 #include "mlir/Dialect/Complex/IR/Complex.h"
48 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
49 #include "mlir/Dialect/Func/IR/FuncOps.h"
50 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
51 #include "mlir/Dialect/Func/Transforms/Passes.h"
52 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
53 #include "mlir/Dialect/Linalg/IR/Linalg.h"
54 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
55 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
56 #include "mlir/Dialect/Math/IR/Math.h"
57 #include "mlir/Dialect/MemRef/IR/MemRef.h"
58 #include "mlir/Dialect/SCF/IR/SCF.h"
59 #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
60 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
61 #include "mlir/Dialect/Shape/IR/Shape.h"
62 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
63 #include "mlir/Dialect/Shape/Transforms/Passes.h"
64 #include "mlir/Dialect/Tensor/IR/Tensor.h"
65 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
66 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
67 #include "mlir/Dialect/Vector/IR/VectorOps.h"
68 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
69 #include "mlir/IR/BuiltinOps.h"
70 #include "mlir/IR/BuiltinTypes.h"
71 #include "mlir/IR/MLIRContext.h"
72 #include "mlir/IR/Operation.h"
73 #include "mlir/IR/PatternMatch.h"
74 #include "mlir/IR/Visitors.h"
75 #include "mlir/Transforms/DialectConversion.h"
76 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
77 
78 namespace mlir {
79 namespace {
80 
81 /// A helper type converter class that automatically populates the relevant
82 /// materializations and type conversions for bufferization.
83 
materializeToTensor(OpBuilder & builder,TensorType type,ValueRange inputs,Location loc)84 static Value materializeToTensor(OpBuilder& builder, TensorType type,
85                                  ValueRange inputs, Location loc) {
86   assert(inputs.size() == 1);
87   assert(inputs[0].getType().isa<BaseMemRefType>());
88   return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
89 }
90 
91 // TODO(pifon): Remove as soon as https://reviews.llvm.org/D93126 is landed.
92 class CustomBufferizeTypeConverter
93     : public bufferization::BufferizeTypeConverter {
94  public:
CustomBufferizeTypeConverter()95   CustomBufferizeTypeConverter() {
96     // Keep all types unchanged.
97     addConversion([](Type type) { return type; });
98     // Convert RankedTensorType to MemRefType.
99     addConversion([](RankedTensorType type) -> Type {
100       return MemRefType::get(type.getShape(), type.getElementType());
101     });
102     // Convert UnrankedTensorType to UnrankedMemRefType.
103     addConversion([](UnrankedTensorType type) -> Type {
104       return UnrankedMemRefType::get(type.getElementType(), 0);
105     });
106     addArgumentMaterialization(materializeToTensor);
107     addSourceMaterialization(materializeToTensor);
108     addTargetMaterialization([](OpBuilder& builder, BaseMemRefType type,
109                                 ValueRange inputs, Location loc) -> Value {
110       assert(inputs.size() == 1);
111       // Target materialization is invoked if the new operand type does not
112       // match the expected type. A special case is when the new operand type is
113       // a memref with a specified layout, i.e. non-empty affine map.
114       // TODO(pifon) : Change how target materialization is invoked in dialect
115       // conversion.
116       if (auto memrefType = inputs[0].getType().dyn_cast<MemRefType>()) {
117         assert(!memrefType.getLayout().isIdentity());
118         return inputs[0];
119       }
120       assert(inputs[0].getType().isa<TensorType>());
121       return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
122     });
123   }
124 };
125 
126 struct ComputeOpAndFuncBufferizePass
127     : public ComputeOpAndFuncBufferizePassBase<ComputeOpAndFuncBufferizePass> {
getDependentDialectsmlir::__anone43020be0111::ComputeOpAndFuncBufferizePass128   void getDependentDialects(DialectRegistry& registry) const override {
129     registry
130         .insert<bufferization::BufferizationDialect, lmhlo::LmhloDialect,
131                 linalg::LinalgDialect, memref::MemRefDialect, mhlo::MhloDialect,
132                 shape::ShapeDialect, vector::VectorDialect>();
133     linalg::registerBufferizableOpInterfaceExternalModels(registry);
134     mhlo::registerBufferizableOpInterfaceExternalModels(registry);
135     shape::registerBufferizableOpInterfaceExternalModels(registry);
136     vector::registerBufferizableOpInterfaceExternalModels(registry);
137   }
138 
runOnOperationmlir::__anone43020be0111::ComputeOpAndFuncBufferizePass139   void runOnOperation() override {
140     // Bufferize ops using BufferizableOpInterface. This could be switched to
141     // One-Shot Bufferize in the future.
142     bufferization::BufferizationOptions options =
143         bufferization::getPartialBufferizationOptions();
144     // TODO(springerm): Add dialects to this filter as more and more dialects
145     // will be migrated to BufferizableOpInterface-based bufferization.
146     options.opFilter.allowDialect<bufferization::BufferizationDialect,
147                                   linalg::LinalgDialect, mhlo::MhloDialect,
148                                   shape::ShapeDialect, tensor::TensorDialect,
149                                   vector::VectorDialect>();
150     // Ops inside TiledLoopOps have special handling.
151     options.opFilter.denyOperation([](Operation* op) {
152       return mlir::isa<gml_st::LoopOp>(op->getParentOp());
153     });
154 
155     if (failed(bufferization::bufferizeOp(getOperation(), options))) {
156       signalPassFailure();
157       return;
158     }
159 
160     // Bufferize the remaining IR with dialect conversion. This will disappear
161     // eventually once all bufferization is done via BufferizableOpInterface.
162     if (failed(runDialectConversionBasedBufferization())) signalPassFailure();
163   }
164 
165  private:
runDialectConversionBasedBufferizationmlir::__anone43020be0111::ComputeOpAndFuncBufferizePass166   LogicalResult runDialectConversionBasedBufferization() {
167     RewritePatternSet patterns(&getContext());
168     auto& context = getContext();
169     ConversionTarget target(context);
170     target.addLegalDialect<
171         arith::ArithmeticDialect, complex::ComplexDialect, lmhlo::LmhloDialect,
172         AffineDialect, vector::VectorDialect, memref::MemRefDialect,
173         func::FuncDialect, tensor::TensorDialect, math::MathDialect>();
174     target.addLegalOp<UnrealizedConversionCastOp, gml_st::LoopOp>();
175     target.addIllegalDialect<mhlo::MhloDialect>();
176     target.addDynamicallyLegalOp<tensor::ExtractSliceOp, tensor::InsertSliceOp>(
177         [&](Operation* op) {
178           return mlir::isa<gml_st::LoopOp>(op->getParentOp());
179         });
180 
181     CustomBufferizeTypeConverter converter;
182     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
183                                                                    converter);
184     populateCallOpTypeConversionPattern(patterns, converter);
185     populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
186     populateReturnOpTypeConversionPattern(patterns, converter);
187 
188     // Configure legality and structural patterns.
189     bufferization::populateBufferizeMaterializationLegality(target);
190     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
191                                                          target);
192 
193     // TODO(herhut): Move this legality configuration to bufferize itself?
194     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
195       auto inputs = op.getFunctionType().getInputs();
196       auto results = op.getFunctionType().getResults();
197       return converter.isLegal(inputs) && converter.isLegal(results) &&
198              converter.isLegal(&op.getBody());
199     });
200     auto isLegalOp = [&](Operation* op) { return converter.isLegal(op); };
201     target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(isLegalOp);
202 
203     auto isLegalOrInsideTiledLoop = [&](Operation* op) {
204       return converter.isLegal(op) ||
205              mlir::isa<gml_st::LoopOp>(op->getParentOp());
206     };
207     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
208         isLegalOrInsideTiledLoop);
209     target
210         .addDynamicallyLegalOp<vector::TransferWriteOp, vector::TransferReadOp>(
211             isLegalOrInsideTiledLoop);
212 
213     return applyPartialConversion(getOperation(), target, std::move(patterns));
214   }
215 };
216 
217 struct OneShotBufferizePass
218     : public OneShotBufferizeBase<OneShotBufferizePass> {
219   // TODO(b/173201243): Move to tablegen.
getDependentDialectsmlir::__anone43020be0111::OneShotBufferizePass220   void getDependentDialects(DialectRegistry& registry) const override {
221     registry
222         .insert<bufferization::BufferizationDialect, lmhlo::LmhloDialect,
223                 linalg::LinalgDialect, memref::MemRefDialect, mhlo::MhloDialect,
224                 scf::SCFDialect, shape::ShapeDialect, vector::VectorDialect>();
225     arith::registerBufferizableOpInterfaceExternalModels(registry);
226     bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
227         registry);
228     gml_st::registerBufferizableOpInterfaceExternalModels(registry);
229     linalg::registerBufferizableOpInterfaceExternalModels(registry);
230     mhlo::registerBufferizableOpInterfaceExternalModels(registry);
231     scf::registerBufferizableOpInterfaceExternalModels(registry);
232     shape::registerBufferizableOpInterfaceExternalModels(registry);
233     tensor::registerBufferizableOpInterfaceExternalModels(registry);
234     vector::registerBufferizableOpInterfaceExternalModels(registry);
235   }
236 
runOnOperationmlir::__anone43020be0111::OneShotBufferizePass237   void runOnOperation() override {
238     bufferization::OneShotBufferizationOptions opts;
239     opts.allowReturnAllocs = true;
240     opts.bufferizeFunctionBoundaries = true;
241     opts.functionBoundaryTypeConversion =
242         bufferization::BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
243     opts.createDeallocs = false;
244     opts.bufferAlignment = 64;
245 
246     ModuleOp module = getOperation();
247     if (failed(bufferization::runOneShotModuleBufferize(module, opts))) {
248       signalPassFailure();
249     }
250   }
251 };
252 
253 struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
254  private:
255   BufferizeDialectsCallback dialectsCallback;
256   BufferizePatternsCallback patternsCallback;
257 
258  public:
getDependentDialectsmlir::__anone43020be0111::FinalBufferizePass259   void getDependentDialects(DialectRegistry& registry) const override {
260     registry
261         .insert<AffineDialect, bufferization::BufferizationDialect,
262                 linalg::LinalgDialect, memref::MemRefDialect, scf::SCFDialect,
263                 shape::ShapeDialect, tensor::TensorDialect, lmhlo::LmhloDialect,
264                 arith::ArithmeticDialect, vector::VectorDialect>();
265     arith::registerBufferizableOpInterfaceExternalModels(registry);
266     linalg::registerBufferizableOpInterfaceExternalModels(registry);
267     shape::registerBufferizableOpInterfaceExternalModels(registry);
268     tensor::registerBufferizableOpInterfaceExternalModels(registry);
269     vector::registerBufferizableOpInterfaceExternalModels(registry);
270     if (dialectsCallback) dialectsCallback(registry);
271   }
272   // Default alignment_ specified in passes.td
273   FinalBufferizePass() = default;
274 
FinalBufferizePassmlir::__anone43020be0111::FinalBufferizePass275   explicit FinalBufferizePass(uint64_t alignment) { alignment_ = alignment; }
276 
setCallbacksmlir::__anone43020be0111::FinalBufferizePass277   void setCallbacks(BufferizeDialectsCallback dc,
278                     BufferizePatternsCallback pc) {
279     dialectsCallback = std::move(dc);
280     patternsCallback = std::move(pc);
281   }
282 
runOnOperationmlir::__anone43020be0111::FinalBufferizePass283   void runOnOperation() override {
284     // Bufferize ops using BufferizableOpInterface. This could be switched to
285     // One-Shot Bufferize in the future.
286     bufferization::BufferizationOptions options =
287         bufferization::getPartialBufferizationOptions();
288     options.bufferAlignment = alignment_;
289     // TODO(springerm): Add dialects to this filter as more and more dialects
290     // will be migrated to BufferizableOpInterface-based bufferization.
291     options.opFilter.allowDialect<
292         arith::ArithmeticDialect, bufferization::BufferizationDialect,
293         linalg::LinalgDialect, func::FuncDialect, shape::ShapeDialect,
294         tensor::TensorDialect, vector::VectorDialect>();
295     if (failed(bufferization::bufferizeOp(getOperation(), options))) {
296       signalPassFailure();
297       return;
298     }
299 
300     // Bufferize the remaining IR with dialect conversion. This will disappear
301     // eventually once all bufferization is done via BufferizableOpInterface.
302     if (failed(runDialectConversionBasedBufferization())) signalPassFailure();
303   }
304 
305  private:
runDialectConversionBasedBufferizationmlir::__anone43020be0111::FinalBufferizePass306   LogicalResult runDialectConversionBasedBufferization() {
307     auto& context = getContext();
308     ConversionTarget target(context);
309     target.addLegalDialect<
310         arith::ArithmeticDialect, bufferization::BufferizationDialect,
311         cf::ControlFlowDialect, complex::ComplexDialect, memref::MemRefDialect,
312         func::FuncDialect, scf::SCFDialect, tensor::TensorDialect,
313         AffineDialect, shape::ShapeDialect, lmhlo::LmhloDialect,
314         linalg::LinalgDialect, math::MathDialect, vector::VectorDialect>();
315     target.addLegalOp<func::FuncOp, ModuleOp>();
316 
317     target.addIllegalDialect<mhlo::MhloDialect>();
318     target.addIllegalOp<tensor::GenerateOp, tensor::ExtractOp,
319                         tensor::FromElementsOp, tensor::CastOp, tensor::DimOp,
320                         tensor::RankOp, chlo::MinimumBroadcastShapesOp,
321                         bufferization::ToTensorOp, bufferization::ToMemrefOp,
322                         tensor::ExpandShapeOp, tensor::CollapseShapeOp>();
323     CustomBufferizeTypeConverter converter;
324     auto typesAreLegal = [&converter](Operation* op) {
325       return converter.isLegal(op->getOperandTypes()) &&
326              converter.isLegal(op->getResultTypes());
327     };
328     target.addDynamicallyLegalOp<func::ConstantOp, arith::ConstantOp,
329                                  arith::IndexCastOp, arith::SelectOp,
330                                  gml_st::LoopOp, gml_st::YieldOp>(
331         typesAreLegal);
332 
333     RewritePatternSet patterns(&getContext());
334     populateEliminateBufferizeMaterializationsPatterns(converter, patterns);
335     populateExtraBufferizePatterns(&getContext(), &converter, &patterns);
336     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
337                                                          target);
338     if (patternsCallback)
339       patternsCallback(target, &getContext(), &converter, &patterns);
340 
341     return applyFullConversion(getOperation(), target, std::move(patterns));
342   }
343 };
344 
345 }  // namespace
346 
347 namespace hlo {
createOneShotBufferizePass()348 std::unique_ptr<OperationPass<ModuleOp>> createOneShotBufferizePass() {
349   return std::make_unique<OneShotBufferizePass>();
350 }
351 }  // namespace hlo
352 
createComputeOpAndFuncBufferizePass()353 std::unique_ptr<OperationPass<ModuleOp>> createComputeOpAndFuncBufferizePass() {
354   return std::make_unique<ComputeOpAndFuncBufferizePass>();
355 }
356 
createFinalBufferizePass()357 std::unique_ptr<OperationPass<ModuleOp>> createFinalBufferizePass() {
358   return std::make_unique<FinalBufferizePass>();
359 }
360 
createFinalBufferizePass(uint64_t alignment,BufferizeDialectsCallback dc,BufferizePatternsCallback pc)361 std::unique_ptr<OperationPass<ModuleOp>> createFinalBufferizePass(
362     uint64_t alignment, BufferizeDialectsCallback dc,
363     BufferizePatternsCallback pc) {
364   auto pass = std::make_unique<FinalBufferizePass>(alignment);
365   pass->setCallbacks(std::move(dc), std::move(pc));
366   return pass;
367 }
368 
369 }  // namespace mlir
370