1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This file implements logic for translating mixed IR to buffer form.
// Currently it supports MHLO and some operations from upstream MLIR dialects
// such as func, arith, tensor, shape and vector.
18
19 #include <cstdint>
20 #include <functional>
21 #include <memory>
22 #include <utility>
23
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
27 #include "mlir-hlo/Dialect/gml_st/transforms/bufferizable_op_interface_impl.h"
28 #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
29 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
30 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
31 #include "mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h"
32 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
33 #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
34 #include "mlir-hlo/Transforms/PassDetail.h"
35 #include "mlir-hlo/Transforms/passes.h"
36 #include "mlir-hlo/Transforms/rewriters.h"
37 #include "mlir/Dialect/Affine/IR/AffineOps.h"
38 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
39 #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
40 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
41 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
42 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
43 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
44 #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
45 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
46 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
47 #include "mlir/Dialect/Complex/IR/Complex.h"
48 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
49 #include "mlir/Dialect/Func/IR/FuncOps.h"
50 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
51 #include "mlir/Dialect/Func/Transforms/Passes.h"
52 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
53 #include "mlir/Dialect/Linalg/IR/Linalg.h"
54 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
55 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
56 #include "mlir/Dialect/Math/IR/Math.h"
57 #include "mlir/Dialect/MemRef/IR/MemRef.h"
58 #include "mlir/Dialect/SCF/IR/SCF.h"
59 #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
60 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
61 #include "mlir/Dialect/Shape/IR/Shape.h"
62 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
63 #include "mlir/Dialect/Shape/Transforms/Passes.h"
64 #include "mlir/Dialect/Tensor/IR/Tensor.h"
65 #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
66 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
67 #include "mlir/Dialect/Vector/IR/VectorOps.h"
68 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
69 #include "mlir/IR/BuiltinOps.h"
70 #include "mlir/IR/BuiltinTypes.h"
71 #include "mlir/IR/MLIRContext.h"
72 #include "mlir/IR/Operation.h"
73 #include "mlir/IR/PatternMatch.h"
74 #include "mlir/IR/Visitors.h"
75 #include "mlir/Transforms/DialectConversion.h"
76 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
77
78 namespace mlir {
79 namespace {
80
// Type-conversion helpers for the dialect-conversion based bufferization:
// materializeToTensor and CustomBufferizeTypeConverter below populate the
// relevant materializations and type conversions for bufferization.
83
materializeToTensor(OpBuilder & builder,TensorType type,ValueRange inputs,Location loc)84 static Value materializeToTensor(OpBuilder& builder, TensorType type,
85 ValueRange inputs, Location loc) {
86 assert(inputs.size() == 1);
87 assert(inputs[0].getType().isa<BaseMemRefType>());
88 return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
89 }
90
91 // TODO(pifon): Remove as soon as https://reviews.llvm.org/D93126 is landed.
92 class CustomBufferizeTypeConverter
93 : public bufferization::BufferizeTypeConverter {
94 public:
CustomBufferizeTypeConverter()95 CustomBufferizeTypeConverter() {
96 // Keep all types unchanged.
97 addConversion([](Type type) { return type; });
98 // Convert RankedTensorType to MemRefType.
99 addConversion([](RankedTensorType type) -> Type {
100 return MemRefType::get(type.getShape(), type.getElementType());
101 });
102 // Convert UnrankedTensorType to UnrankedMemRefType.
103 addConversion([](UnrankedTensorType type) -> Type {
104 return UnrankedMemRefType::get(type.getElementType(), 0);
105 });
106 addArgumentMaterialization(materializeToTensor);
107 addSourceMaterialization(materializeToTensor);
108 addTargetMaterialization([](OpBuilder& builder, BaseMemRefType type,
109 ValueRange inputs, Location loc) -> Value {
110 assert(inputs.size() == 1);
111 // Target materialization is invoked if the new operand type does not
112 // match the expected type. A special case is when the new operand type is
113 // a memref with a specified layout, i.e. non-empty affine map.
114 // TODO(pifon) : Change how target materialization is invoked in dialect
115 // conversion.
116 if (auto memrefType = inputs[0].getType().dyn_cast<MemRefType>()) {
117 assert(!memrefType.getLayout().isIdentity());
118 return inputs[0];
119 }
120 assert(inputs[0].getType().isa<TensorType>());
121 return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
122 });
123 }
124 };
125
126 struct ComputeOpAndFuncBufferizePass
127 : public ComputeOpAndFuncBufferizePassBase<ComputeOpAndFuncBufferizePass> {
getDependentDialectsmlir::__anone43020be0111::ComputeOpAndFuncBufferizePass128 void getDependentDialects(DialectRegistry& registry) const override {
129 registry
130 .insert<bufferization::BufferizationDialect, lmhlo::LmhloDialect,
131 linalg::LinalgDialect, memref::MemRefDialect, mhlo::MhloDialect,
132 shape::ShapeDialect, vector::VectorDialect>();
133 linalg::registerBufferizableOpInterfaceExternalModels(registry);
134 mhlo::registerBufferizableOpInterfaceExternalModels(registry);
135 shape::registerBufferizableOpInterfaceExternalModels(registry);
136 vector::registerBufferizableOpInterfaceExternalModels(registry);
137 }
138
runOnOperationmlir::__anone43020be0111::ComputeOpAndFuncBufferizePass139 void runOnOperation() override {
140 // Bufferize ops using BufferizableOpInterface. This could be switched to
141 // One-Shot Bufferize in the future.
142 bufferization::BufferizationOptions options =
143 bufferization::getPartialBufferizationOptions();
144 // TODO(springerm): Add dialects to this filter as more and more dialects
145 // will be migrated to BufferizableOpInterface-based bufferization.
146 options.opFilter.allowDialect<bufferization::BufferizationDialect,
147 linalg::LinalgDialect, mhlo::MhloDialect,
148 shape::ShapeDialect, tensor::TensorDialect,
149 vector::VectorDialect>();
150 // Ops inside TiledLoopOps have special handling.
151 options.opFilter.denyOperation([](Operation* op) {
152 return mlir::isa<gml_st::LoopOp>(op->getParentOp());
153 });
154
155 if (failed(bufferization::bufferizeOp(getOperation(), options))) {
156 signalPassFailure();
157 return;
158 }
159
160 // Bufferize the remaining IR with dialect conversion. This will disappear
161 // eventually once all bufferization is done via BufferizableOpInterface.
162 if (failed(runDialectConversionBasedBufferization())) signalPassFailure();
163 }
164
165 private:
runDialectConversionBasedBufferizationmlir::__anone43020be0111::ComputeOpAndFuncBufferizePass166 LogicalResult runDialectConversionBasedBufferization() {
167 RewritePatternSet patterns(&getContext());
168 auto& context = getContext();
169 ConversionTarget target(context);
170 target.addLegalDialect<
171 arith::ArithmeticDialect, complex::ComplexDialect, lmhlo::LmhloDialect,
172 AffineDialect, vector::VectorDialect, memref::MemRefDialect,
173 func::FuncDialect, tensor::TensorDialect, math::MathDialect>();
174 target.addLegalOp<UnrealizedConversionCastOp, gml_st::LoopOp>();
175 target.addIllegalDialect<mhlo::MhloDialect>();
176 target.addDynamicallyLegalOp<tensor::ExtractSliceOp, tensor::InsertSliceOp>(
177 [&](Operation* op) {
178 return mlir::isa<gml_st::LoopOp>(op->getParentOp());
179 });
180
181 CustomBufferizeTypeConverter converter;
182 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
183 converter);
184 populateCallOpTypeConversionPattern(patterns, converter);
185 populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
186 populateReturnOpTypeConversionPattern(patterns, converter);
187
188 // Configure legality and structural patterns.
189 bufferization::populateBufferizeMaterializationLegality(target);
190 scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
191 target);
192
193 // TODO(herhut): Move this legality configuration to bufferize itself?
194 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
195 auto inputs = op.getFunctionType().getInputs();
196 auto results = op.getFunctionType().getResults();
197 return converter.isLegal(inputs) && converter.isLegal(results) &&
198 converter.isLegal(&op.getBody());
199 });
200 auto isLegalOp = [&](Operation* op) { return converter.isLegal(op); };
201 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(isLegalOp);
202
203 auto isLegalOrInsideTiledLoop = [&](Operation* op) {
204 return converter.isLegal(op) ||
205 mlir::isa<gml_st::LoopOp>(op->getParentOp());
206 };
207 target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
208 isLegalOrInsideTiledLoop);
209 target
210 .addDynamicallyLegalOp<vector::TransferWriteOp, vector::TransferReadOp>(
211 isLegalOrInsideTiledLoop);
212
213 return applyPartialConversion(getOperation(), target, std::move(patterns));
214 }
215 };
216
217 struct OneShotBufferizePass
218 : public OneShotBufferizeBase<OneShotBufferizePass> {
219 // TODO(b/173201243): Move to tablegen.
getDependentDialectsmlir::__anone43020be0111::OneShotBufferizePass220 void getDependentDialects(DialectRegistry& registry) const override {
221 registry
222 .insert<bufferization::BufferizationDialect, lmhlo::LmhloDialect,
223 linalg::LinalgDialect, memref::MemRefDialect, mhlo::MhloDialect,
224 scf::SCFDialect, shape::ShapeDialect, vector::VectorDialect>();
225 arith::registerBufferizableOpInterfaceExternalModels(registry);
226 bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
227 registry);
228 gml_st::registerBufferizableOpInterfaceExternalModels(registry);
229 linalg::registerBufferizableOpInterfaceExternalModels(registry);
230 mhlo::registerBufferizableOpInterfaceExternalModels(registry);
231 scf::registerBufferizableOpInterfaceExternalModels(registry);
232 shape::registerBufferizableOpInterfaceExternalModels(registry);
233 tensor::registerBufferizableOpInterfaceExternalModels(registry);
234 vector::registerBufferizableOpInterfaceExternalModels(registry);
235 }
236
runOnOperationmlir::__anone43020be0111::OneShotBufferizePass237 void runOnOperation() override {
238 bufferization::OneShotBufferizationOptions opts;
239 opts.allowReturnAllocs = true;
240 opts.bufferizeFunctionBoundaries = true;
241 opts.functionBoundaryTypeConversion =
242 bufferization::BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
243 opts.createDeallocs = false;
244 opts.bufferAlignment = 64;
245
246 ModuleOp module = getOperation();
247 if (failed(bufferization::runOneShotModuleBufferize(module, opts))) {
248 signalPassFailure();
249 }
250 }
251 };
252
253 struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
254 private:
255 BufferizeDialectsCallback dialectsCallback;
256 BufferizePatternsCallback patternsCallback;
257
258 public:
getDependentDialectsmlir::__anone43020be0111::FinalBufferizePass259 void getDependentDialects(DialectRegistry& registry) const override {
260 registry
261 .insert<AffineDialect, bufferization::BufferizationDialect,
262 linalg::LinalgDialect, memref::MemRefDialect, scf::SCFDialect,
263 shape::ShapeDialect, tensor::TensorDialect, lmhlo::LmhloDialect,
264 arith::ArithmeticDialect, vector::VectorDialect>();
265 arith::registerBufferizableOpInterfaceExternalModels(registry);
266 linalg::registerBufferizableOpInterfaceExternalModels(registry);
267 shape::registerBufferizableOpInterfaceExternalModels(registry);
268 tensor::registerBufferizableOpInterfaceExternalModels(registry);
269 vector::registerBufferizableOpInterfaceExternalModels(registry);
270 if (dialectsCallback) dialectsCallback(registry);
271 }
272 // Default alignment_ specified in passes.td
273 FinalBufferizePass() = default;
274
FinalBufferizePassmlir::__anone43020be0111::FinalBufferizePass275 explicit FinalBufferizePass(uint64_t alignment) { alignment_ = alignment; }
276
setCallbacksmlir::__anone43020be0111::FinalBufferizePass277 void setCallbacks(BufferizeDialectsCallback dc,
278 BufferizePatternsCallback pc) {
279 dialectsCallback = std::move(dc);
280 patternsCallback = std::move(pc);
281 }
282
runOnOperationmlir::__anone43020be0111::FinalBufferizePass283 void runOnOperation() override {
284 // Bufferize ops using BufferizableOpInterface. This could be switched to
285 // One-Shot Bufferize in the future.
286 bufferization::BufferizationOptions options =
287 bufferization::getPartialBufferizationOptions();
288 options.bufferAlignment = alignment_;
289 // TODO(springerm): Add dialects to this filter as more and more dialects
290 // will be migrated to BufferizableOpInterface-based bufferization.
291 options.opFilter.allowDialect<
292 arith::ArithmeticDialect, bufferization::BufferizationDialect,
293 linalg::LinalgDialect, func::FuncDialect, shape::ShapeDialect,
294 tensor::TensorDialect, vector::VectorDialect>();
295 if (failed(bufferization::bufferizeOp(getOperation(), options))) {
296 signalPassFailure();
297 return;
298 }
299
300 // Bufferize the remaining IR with dialect conversion. This will disappear
301 // eventually once all bufferization is done via BufferizableOpInterface.
302 if (failed(runDialectConversionBasedBufferization())) signalPassFailure();
303 }
304
305 private:
runDialectConversionBasedBufferizationmlir::__anone43020be0111::FinalBufferizePass306 LogicalResult runDialectConversionBasedBufferization() {
307 auto& context = getContext();
308 ConversionTarget target(context);
309 target.addLegalDialect<
310 arith::ArithmeticDialect, bufferization::BufferizationDialect,
311 cf::ControlFlowDialect, complex::ComplexDialect, memref::MemRefDialect,
312 func::FuncDialect, scf::SCFDialect, tensor::TensorDialect,
313 AffineDialect, shape::ShapeDialect, lmhlo::LmhloDialect,
314 linalg::LinalgDialect, math::MathDialect, vector::VectorDialect>();
315 target.addLegalOp<func::FuncOp, ModuleOp>();
316
317 target.addIllegalDialect<mhlo::MhloDialect>();
318 target.addIllegalOp<tensor::GenerateOp, tensor::ExtractOp,
319 tensor::FromElementsOp, tensor::CastOp, tensor::DimOp,
320 tensor::RankOp, chlo::MinimumBroadcastShapesOp,
321 bufferization::ToTensorOp, bufferization::ToMemrefOp,
322 tensor::ExpandShapeOp, tensor::CollapseShapeOp>();
323 CustomBufferizeTypeConverter converter;
324 auto typesAreLegal = [&converter](Operation* op) {
325 return converter.isLegal(op->getOperandTypes()) &&
326 converter.isLegal(op->getResultTypes());
327 };
328 target.addDynamicallyLegalOp<func::ConstantOp, arith::ConstantOp,
329 arith::IndexCastOp, arith::SelectOp,
330 gml_st::LoopOp, gml_st::YieldOp>(
331 typesAreLegal);
332
333 RewritePatternSet patterns(&getContext());
334 populateEliminateBufferizeMaterializationsPatterns(converter, patterns);
335 populateExtraBufferizePatterns(&getContext(), &converter, &patterns);
336 scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
337 target);
338 if (patternsCallback)
339 patternsCallback(target, &getContext(), &converter, &patterns);
340
341 return applyFullConversion(getOperation(), target, std::move(patterns));
342 }
343 };
344
345 } // namespace
346
347 namespace hlo {
createOneShotBufferizePass()348 std::unique_ptr<OperationPass<ModuleOp>> createOneShotBufferizePass() {
349 return std::make_unique<OneShotBufferizePass>();
350 }
351 } // namespace hlo
352
createComputeOpAndFuncBufferizePass()353 std::unique_ptr<OperationPass<ModuleOp>> createComputeOpAndFuncBufferizePass() {
354 return std::make_unique<ComputeOpAndFuncBufferizePass>();
355 }
356
createFinalBufferizePass()357 std::unique_ptr<OperationPass<ModuleOp>> createFinalBufferizePass() {
358 return std::make_unique<FinalBufferizePass>();
359 }
360
createFinalBufferizePass(uint64_t alignment,BufferizeDialectsCallback dc,BufferizePatternsCallback pc)361 std::unique_ptr<OperationPass<ModuleOp>> createFinalBufferizePass(
362 uint64_t alignment, BufferizeDialectsCallback dc,
363 BufferizePatternsCallback pc) {
364 auto pass = std::make_unique<FinalBufferizePass>(alignment);
365 pass->setCallbacks(std::move(dc), std::move(pc));
366 return pass;
367 }
368
369 } // namespace mlir
370