/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements conversion of `gml_st.loop` to buffer form.

#include <utility>

#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
#include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h"
#include "mlir-hlo/Dialect/gml_st/transforms/passes.h"
#include "mlir-hlo/Dialect/gml_st/transforms/rewriters.h"
#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace {

using bufferization::ToMemrefOp;
using bufferization::ToTensorOp;
using gml_st::LoopOp;
using linalg::InitTensorOp;
using memref::SubViewOp;
using tensor::ExtractSliceOp;
using tensor::InsertSliceOp;
using vector::TransferReadOp;
using vector::TransferWriteOp;

static Value materializeToTensor(OpBuilder &builder, TensorType type,
                                 ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  assert(inputs[0].getType().isa<BaseMemRefType>());
  return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

// TODO(pifon): Remove as soon as https://reviews.llvm.org/D93126 is landed.
class CustomBufferizeTypeConverter
    : public bufferization::BufferizeTypeConverter {
 public:
  CustomBufferizeTypeConverter() {
    // Keep all types unchanged.
    addConversion([](Type type) { return type; });
    // Convert RankedTensorType to MemRefType.
    addConversion([](RankedTensorType type) -> Type {
      return MemRefType::get(type.getShape(), type.getElementType());
    });
    // Convert UnrankedTensorType to UnrankedMemRefType.
    addConversion([](UnrankedTensorType type) -> Type {
      return UnrankedMemRefType::get(type.getElementType(), 0);
    });
    addArgumentMaterialization(materializeToTensor);
    addSourceMaterialization(materializeToTensor);
    addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
                                ValueRange inputs, Location loc) -> Value {
      assert(inputs.size() == 1);
      // Target materialization is invoked if the new operand type does not
      // match the expected type. A special case is when the new operand type
      // is a memref with a specified layout, i.e. a non-identity affine map.
      // TODO(pifon): Change how target materialization is invoked in dialect
      // conversion.
      if (auto memrefType = inputs[0].getType().dyn_cast<MemRefType>()) {
        assert(!memrefType.getLayout().isIdentity());
        return inputs[0];
      }
      assert(inputs[0].getType().isa<TensorType>());
      return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
    });
  }
};
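
// For illustration, the converter above maps types like these (hypothetical
// examples, identity layout and memory space 0):
//   tensor<4x8xf32> -> memref<4x8xf32>
//   tensor<*xf32>   -> memref<*xf32>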

/// Convert `tensor.extract_slice` to `memref.subview` in-place.
struct BufferizeExtractSliceOp : public OpConversionPattern<ExtractSliceOp> {
  using OpConversionPattern<ExtractSliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ExtractSliceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    if (!op->getParentOfType<LoopOp>()) return failure();

    rewriter.replaceOpWithNewOp<SubViewOp>(
        op, adaptor.getSource(), op.getMixedOffsets(), op.getMixedSizes(),
        op.getMixedStrides());
    return success();
  }
};
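
// Illustrative rewrite (hypothetical shapes; the strided result layout map
// #strided is elided):
//   %slice = tensor.extract_slice %t[%i] [4] [1]
//       : tensor<8xf32> to tensor<4xf32>
// becomes
//   %view = memref.subview %t_buf[%i] [4] [1]
//       : memref<8xf32> to memref<4xf32, #strided>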

/// Convert `linalg.init_tensor` to `memref.alloc`.
struct BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
  using OpConversionPattern<InitTensorOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      InitTensorOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    if (!op->getParentOfType<LoopOp>()) return failure();

    rewriter.replaceOpWithNewOp<memref::AllocOp>(
        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
        adaptor.sizes());
    return success();
  }
};
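
// Illustrative rewrite (hypothetical shapes):
//   %init = linalg.init_tensor [%d0, 16] : tensor<?x16xf32>
// becomes
//   %buf = memref.alloc(%d0) : memref<?x16xf32>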

bool isBlockArgOfTiledLoop(Value value) {
  if (auto blockArg = value.dyn_cast<BlockArgument>())
    return isa<LoopOp>(blockArg.getOwner()->getParentOp());
  return false;
}

// Attempts to find an existing `memref.subview` of `destMemRef` in the tiled
// loop. The assumption is that in `gml_st.loop` the tile of the output
// tensor that we read and the tile that we write to are the same.
Value findExistingSubview(Value destMemRef) {
  if (auto toMemref = destMemRef.getDefiningOp<ToMemrefOp>()) {
    if (auto toTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>()) {
      if (!isBlockArgOfTiledLoop(toTensor.getMemref())) return Value{};
      // Scan through users of the block argument to find `subview` op.
      for (Operation *tensorUser : toMemref.getTensor().getUsers()) {
        if (auto anotherCast = mlir::dyn_cast<ToMemrefOp>(tensorUser)) {
          for (Operation *memrefUser : anotherCast.getMemref().getUsers()) {
            if (auto subview = mlir::dyn_cast<SubViewOp>(memrefUser)) {
              if (subview.getSource() == destMemRef) return subview;
            }
          }
        }
      }
    }
  }
  return Value{};
}
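
// The matched chain looks roughly like this (illustrative, names invented):
//   %t   = bufferization.to_tensor %bbarg  // %bbarg is a gml_st.loop argument
//   %dst = bufferization.to_memref %t
//   %sv  = memref.subview %dst [...]       // the subview that gets reused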

/// Convert `tensor.insert_slice` to `memref.subview` in-place.
struct BufferizeInsertSliceOp : public OpConversionPattern<InsertSliceOp> {
 public:
  using OpConversionPattern<InsertSliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      InsertSliceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    Value sourceMemRef = adaptor.getSource();
    assert(sourceMemRef.getType().isa<MemRefType>());

    Value destMemRef = adaptor.getDest();
    assert(destMemRef.getType().isa<MemRefType>());

    if (!op->getParentOfType<LoopOp>()) return failure();

    Value subview = findExistingSubview(destMemRef);
    if (!subview) {
      subview = rewriter.create<SubViewOp>(
          op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
          op.getMixedStrides());
    }
    rewriter.create<memref::CopyOp>(op.getLoc(), sourceMemRef, subview);
    rewriter.replaceOp(op, destMemRef);
    return success();
  }
};
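
// Illustrative rewrite (hypothetical shapes):
//   %r = tensor.insert_slice %src into %dst[%i] [4] [1]
//       : tensor<4xf32> into tensor<8xf32>
// becomes (reusing an existing subview when one is found)
//   %sv = memref.subview %dst_buf[%i] [4] [1]
//   memref.copy %src_buf, %sv
// with uses of %r replaced by %dst_buf.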

/// Create linalg op on buffers given the original tensor-based operation and
/// the buffers for the outputs.
linalg::LinalgOp createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
                                         linalg::LinalgOp linalgOp,
                                         ValueRange inputs,
                                         ValueRange outputs) {
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
                                             /*resultTypes=*/ArrayRef<Type>{},
                                             newOperands);
  for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
    auto &oldRegion = std::get<0>(regions);
    auto &newRegion = std::get<1>(regions);
    rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
  }
  return newOp;
}

/// Get a variadic operand segment.
ValueRange getVariadicOperands(DenseI32ArrayAttr sizeAttr,
                               ValueRange operands, unsigned index) {
  const int32_t *sizeIt = &*sizeAttr.value_begin<int32_t>();
  if (sizeAttr.isSplat()) return operands.slice(*sizeIt * index, *sizeIt);

  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i) start += sizeIt[i];
  return operands.slice(start, sizeIt[index]);
}
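
// For example (hypothetical attribute and operands), given
// operand_segment_sizes = [2, 1] and operands (%a, %b, %c):
//   getVariadicOperands(attr, operands, 0) == (%a, %b)  // first segment
//   getVariadicOperands(attr, operands, 1) == (%c)      // second segment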

// Bufferize LinalgOps in-place.
struct BufferizeLinalgOp
    : public OpInterfaceConversionPattern<linalg::LinalgOp> {
  using OpInterfaceConversionPattern<
      linalg::LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult matchAndRewrite(
      linalg::LinalgOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    if (!op->getParentOfType<LoopOp>()) return failure();

    // An op with two variadic operand groups expects a segment size attribute.
    auto operandSegments =
        op->getAttrOfType<DenseI32ArrayAttr>("operand_segment_sizes");
    if (!operandSegments) return failure();

    const auto getOperands = [&](unsigned index) {
      return getVariadicOperands(operandSegments, operands, index);
    };
    createLinalgOpOnBuffers(rewriter, op, getOperands(0), getOperands(1));
    rewriter.replaceOp(op, getOperands(1));
    return success();
  }
};
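
// Illustrative rewrite (hypothetical op): a linalg op on tensors such as
//   %r = linalg.generic ... ins(%in : tensor<8xf32>)
//                           outs(%out : tensor<8xf32>) {...}
// becomes the same op on buffers, with %r replaced by the output buffer:
//   linalg.generic ... ins(%in_buf : memref<8xf32>)
//                      outs(%out_buf : memref<8xf32>) {...}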

// Convert `gml_st.yield` terminator of `gml_st.loop` to `gml_st.yield` with no
// arguments.
struct BufferizeLinalgYieldOp : public OpConversionPattern<gml_st::YieldOp> {
  using OpConversionPattern<gml_st::YieldOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      gml_st::YieldOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    if (!mlir::dyn_cast<LoopOp>(op->getParentOp()) ||
        adaptor.getOperands().empty())
      return failure();

    rewriter.replaceOpWithNewOp<gml_st::YieldOp>(op);
    return success();
  }
};
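
// Illustrative rewrite:
//   gml_st.yield %updated : tensor<8xf32>
// becomes
//   gml_st.yield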

// FuncOp-like bufferization pattern for `gml_st.loop` that inserts
// `bufferization.to_tensor` ops for every memref block argument.
//
// TODO(b/230082413): This code has to go away if we migrate to one-shot
// bufferization.
struct BufferizeLoopOp : public OpConversionPattern<LoopOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      LoopOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    if (op.getNumResults() == 0) return failure();

    // Allocate new buffers for outputs whose defining `to_tensor` op has more
    // than one use.
    SmallVector<Value, 4> operands = adaptor.getOperands();
    for (auto &en : llvm::enumerate(op.outputs())) {
      Value output = en.value();

      auto toTensor = output.getDefiningOp<bufferization::ToTensorOp>();
      if (!toTensor || toTensor->hasOneUse()) continue;

      auto alloc = toTensor.getMemref().getDefiningOp<memref::AllocOp>();
      if (!alloc) continue;

      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(op);
      auto *newAlloc = rewriter.clone(*alloc.getOperation());
      operands[op.getNumControlOperands() + op.getNumInputs() + en.index()] =
          newAlloc->getResult(0);
    }

    SmallVector<NamedAttribute> attrList;
    for (auto &item : adaptor.getAttributes()) {
      attrList.push_back(item);
    }
    auto newOp = rewriter.create<LoopOp>(op.getLoc(), mlir::TypeRange{},
                                         operands, attrList);
    // Take the region from the old op and put it in the new op.
    rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
                                newOp.getLoopBody().end());

    // Convert the types of the entry block of the LoopOp's body.
    if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
                                           *getTypeConverter()))) {
      return rewriter.notifyMatchFailure(op, "could not convert body types");
    }

    rewriter.replaceOp(op, newOp.outputs());
    return success();
  }
};
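
// Illustrative rewrite (hypothetical loop): a gml_st.loop yielding a tensor,
//   %r = gml_st.loop (%i) = (%c0) to (%dim) step (%c1)
//            outs (%out = %init : tensor<8xf32>) { ... }
// becomes a loop with no results whose output operand is a memref; uses of %r
// are replaced by that output buffer.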

// TODO(b/199045477): The patterns for vector.transfer_read/write have to be
// moved out of Linalg bufferization to a VectorOps bufferization pass.
struct BufferizeVectorTransferReadOp
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      vector::TransferReadOp readOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    if (readOp.getShapedType().isa<MemRefType>()) return failure();
    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
        readOp, readOp.getType(), adaptor.getSource(), adaptor.getIndices(),
        adaptor.getPermutationMapAttr(), adaptor.getPadding(),
        adaptor.getMask(),
        adaptor.getInBounds() ? adaptor.getInBoundsAttr() : ArrayAttr());
    return success();
  }
};

struct BufferizeVectorTransferWriteOp
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      vector::TransferWriteOp writeOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    if (writeOp.getShapedType().isa<MemRefType>()) return failure();
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), adaptor.getVector(), adaptor.getSource(),
        adaptor.getIndices(), adaptor.getPermutationMapAttr(),
        adaptor.getInBounds() ? adaptor.getInBoundsAttr() : ArrayAttr());
    rewriter.replaceOp(writeOp, adaptor.getSource());
    return success();
  }
};
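
// Illustrative rewrite (tensor form to buffer form, hypothetical shapes):
//   %v = vector.transfer_read %t[%i], %pad : tensor<8xf32>, vector<4xf32>
//   %u = vector.transfer_write %v, %t[%i] : vector<4xf32>, tensor<8xf32>
// becomes
//   %v = vector.transfer_read %buf[%i], %pad : memref<8xf32>, vector<4xf32>
//   vector.transfer_write %v, %buf[%i] : vector<4xf32>, memref<8xf32>
// with uses of %u replaced by %buf.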

}  // namespace

namespace gml_st {
struct TiledLoopBufferizePass
    : public TiledLoopBufferizePassBase<TiledLoopBufferizePass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }

  void runOnOperation() override {
    // Bufferize ops using BufferizableOpInterface. This could be switched to
    // One-Shot Bufferize in the future.
    mlir::RewritePatternSet patterns(&getContext());
    mlir::bufferization::BufferizationOptions options =
        mlir::bufferization::getPartialBufferizationOptions();
    // TODO(springerm): Add dialects to this filter as more dialects are
    // migrated to BufferizableOpInterface-based bufferization.
    options.opFilter.allowDialect<shape::ShapeDialect>();
    if (failed(mlir::bufferization::bufferizeOp(getOperation(), options))) {
      signalPassFailure();
      return;
    }

    // Bufferize the remaining IR with dialect conversion. This will disappear
    // eventually once all bufferization is done via BufferizableOpInterface.
    if (failed(runDialectConversionBasedBufferization())) signalPassFailure();
  }

 private:
  LogicalResult runDialectConversionBasedBufferization() {
    mlir::RewritePatternSet patterns(&getContext());
    auto &context = getContext();
    ConversionTarget target(context);
    target.addLegalDialect<
        mlir::arith::ArithmeticDialect,
        mlir::bufferization::BufferizationDialect,
        mlir::complex::ComplexDialect, mlir::lmhlo::LmhloDialect,
        mlir::AffineDialect, mlir::vector::VectorDialect,
        mlir::memref::MemRefDialect, mlir::func::FuncDialect,
        mlir::tensor::TensorDialect, mlir::math::MathDialect>();
    target.addLegalOp<UnrealizedConversionCastOp>();
    target.addIllegalDialect<mhlo::MhloDialect>();
    target.addIllegalOp<tensor::ExtractSliceOp, tensor::InsertSliceOp>();

    CustomBufferizeTypeConverter converter;
    mlir::mhlo::RemoveSignTypeConverter removeSignConverter;

    // Configure bufferization patterns.
    populateCallOpTypeConversionPattern(patterns, converter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    mlir::bufferization::populateBufferizeMaterializationLegality(target);
    populateTiledLoopBufferizePattern(&getContext(), &converter, &patterns);
    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
        converter, patterns, target);
    // Configure legality.
    auto isLegalOp = [&](Operation *op) { return converter.isLegal(op); };
    target.addDynamicallyLegalDialect<mlir::linalg::LinalgDialect>(isLegalOp);
    target.addDynamicallyLegalOp<mlir::func::CallOp, gml_st::LoopOp,
                                 gml_st::YieldOp, mlir::LLVM::InlineAsmOp,
                                 mlir::vector::TransferWriteOp,
                                 mlir::vector::TransferReadOp>(isLegalOp);

    return applyPartialConversion(getOperation(), target, std::move(patterns));
  }
};

void populateTiledLoopBufferizePattern(
    mlir::MLIRContext *context,
    mlir::bufferization::BufferizeTypeConverter *converter,
    mlir::RewritePatternSet *patterns) {
  // clang-format off
  patterns->add<
    BufferizeExtractSliceOp,
    BufferizeInitTensorOp,
    BufferizeInsertSliceOp,
    BufferizeLinalgOp,
    BufferizeLinalgYieldOp,
    BufferizeLoopOp,
    BufferizeVectorTransferReadOp,
    BufferizeVectorTransferWriteOp
  >(*converter, context);
  // clang-format on
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateTiledLoopBufferizePass() {
  return std::make_unique<TiledLoopBufferizePass>();
}

}  // namespace gml_st
}  // namespace mlir