1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements conversion of `gml_st.loop` to buffer form.
17
18 #include <utility>
19
20 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
21 #include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h"
22 #include "mlir-hlo/Dialect/gml_st/transforms/passes.h"
23 #include "mlir-hlo/Dialect/gml_st/transforms/rewriters.h"
24 #include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
25 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
26 #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
27 #include "mlir/Dialect/Affine/IR/AffineOps.h"
28 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
29 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
30 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
31 #include "mlir/Dialect/Complex/IR/Complex.h"
32 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
33 #include "mlir/Dialect/Func/IR/FuncOps.h"
34 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
35 #include "mlir/Dialect/Func/Transforms/Passes.h"
36 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
37 #include "mlir/Dialect/Linalg/IR/Linalg.h"
38 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
39 #include "mlir/Dialect/Math/IR/Math.h"
40 #include "mlir/Dialect/MemRef/IR/MemRef.h"
41 #include "mlir/Dialect/SCF/IR/SCF.h"
42 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
43 #include "mlir/Dialect/Shape/IR/Shape.h"
44 #include "mlir/Dialect/Tensor/IR/Tensor.h"
45 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
46 #include "mlir/Dialect/Vector/IR/VectorOps.h"
47 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
48 #include "mlir/IR/Attributes.h"
49 #include "mlir/IR/BlockAndValueMapping.h"
50 #include "mlir/IR/BuiltinOps.h"
51 #include "mlir/IR/BuiltinTypes.h"
52 #include "mlir/IR/ImplicitLocOpBuilder.h"
53 #include "mlir/IR/MLIRContext.h"
54 #include "mlir/IR/Operation.h"
55 #include "mlir/IR/PatternMatch.h"
56 #include "mlir/IR/Visitors.h"
57 #include "mlir/Transforms/DialectConversion.h"
58 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
59
60 namespace mlir {
61 namespace {
62
63 using bufferization::ToMemrefOp;
64 using bufferization::ToTensorOp;
65 using gml_st::LoopOp;
66 using linalg::InitTensorOp;
67 using memref::SubViewOp;
68 using tensor::ExtractSliceOp;
69 using tensor::InsertSliceOp;
70 using vector::TransferReadOp;
71 using vector::TransferWriteOp;
72
materializeToTensor(OpBuilder & builder,TensorType type,ValueRange inputs,Location loc)73 static Value materializeToTensor(OpBuilder &builder, TensorType type,
74 ValueRange inputs, Location loc) {
75 assert(inputs.size() == 1);
76 assert(inputs[0].getType().isa<BaseMemRefType>());
77 return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
78 }
79
80 // TODO(pifon): Remove as soon as https://reviews.llvm.org/D93126 is landed.
81 class CustomBufferizeTypeConverter
82 : public bufferization::BufferizeTypeConverter {
83 public:
CustomBufferizeTypeConverter()84 CustomBufferizeTypeConverter() {
85 // Keep all types unchanged.
86 addConversion([](Type type) { return type; });
87 // Convert RankedTensorType to MemRefType.
88 addConversion([](RankedTensorType type) -> Type {
89 return MemRefType::get(type.getShape(), type.getElementType());
90 });
91 // Convert UnrankedTensorType to UnrankedMemRefType.
92 addConversion([](UnrankedTensorType type) -> Type {
93 return UnrankedMemRefType::get(type.getElementType(), 0);
94 });
95 addArgumentMaterialization(materializeToTensor);
96 addSourceMaterialization(materializeToTensor);
97 addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
98 ValueRange inputs, Location loc) -> Value {
99 assert(inputs.size() == 1);
100 // Target materialization is invoked if the new operand type does not
101 // match the expected type. A special case is when the new operand type is
102 // a memref with a specified layout, i.e. non-empty affine map.
103 // TODO(pifon) : Change how target materialization is invoked in dialect
104 // conversion.
105 if (auto memrefType = inputs[0].getType().dyn_cast<MemRefType>()) {
106 assert(!memrefType.getLayout().isIdentity());
107 return inputs[0];
108 }
109 assert(inputs[0].getType().isa<TensorType>());
110 return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
111 });
112 }
113 };
114
115 /// Convert `tensor.extract_slice` to `memref.subview` in-place.
116 struct BufferizeExtractSliceOp : public OpConversionPattern<ExtractSliceOp> {
117 using OpConversionPattern<ExtractSliceOp>::OpConversionPattern;
118
matchAndRewritemlir::__anon293c5c8c0111::BufferizeExtractSliceOp119 LogicalResult matchAndRewrite(
120 ExtractSliceOp op, OpAdaptor adaptor,
121 ConversionPatternRewriter &rewriter) const final {
122 if (!op->getParentOfType<LoopOp>()) return failure();
123
124 rewriter.replaceOpWithNewOp<SubViewOp>(
125 op, adaptor.getSource(), op.getMixedOffsets(), op.getMixedSizes(),
126 op.getMixedStrides());
127 return success();
128 }
129 };
130
131 /// Convert `linalg.init_tensor` of `memref.alloc`.
132 struct BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> {
133 using OpConversionPattern<InitTensorOp>::OpConversionPattern;
134
matchAndRewritemlir::__anon293c5c8c0111::BufferizeInitTensorOp135 LogicalResult matchAndRewrite(
136 InitTensorOp op, OpAdaptor adaptor,
137 ConversionPatternRewriter &rewriter) const final {
138 if (!op->getParentOfType<LoopOp>()) return failure();
139
140 rewriter.replaceOpWithNewOp<memref::AllocOp>(
141 op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
142 adaptor.sizes());
143 return success();
144 }
145 };
146
isBlockArgOfTiledLoop(Value value)147 bool isBlockArgOfTiledLoop(Value value) {
148 if (auto blockArg = value.dyn_cast<BlockArgument>())
149 return isa<LoopOp>(blockArg.getOwner()->getParentOp());
150 return false;
151 }
152
153 // Attempts to find an existing `memref.subview` of `destMemRef` in the tiled
154 // loop. The assumption is that in `gml_st.loop` the tile of the output
155 // tensor that we read and the tile that we write to are the same.
findExistingSubview(Value destMemRef)156 Value findExistingSubview(Value destMemRef) {
157 if (auto toMemref = destMemRef.getDefiningOp<ToMemrefOp>()) {
158 if (auto toTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>()) {
159 if (!isBlockArgOfTiledLoop(toTensor.getMemref())) return Value{};
160 // Scan through users of the block argument to find `subview` op.
161 for (Operation *tensorUser : toMemref.getTensor().getUsers()) {
162 if (auto anotherCast = mlir::dyn_cast<ToMemrefOp>(tensorUser)) {
163 for (Operation *memrefUser : anotherCast.getMemref().getUsers()) {
164 if (auto subview = mlir::dyn_cast<SubViewOp>(memrefUser)) {
165 if (subview.getSource() == destMemRef) return subview;
166 }
167 }
168 }
169 }
170 }
171 }
172 return Value{};
173 }
174
175 /// Convert `tensor.insert_slice` to `memref.subview` in-place.
176 struct BufferizeInsertSliceOp : public OpConversionPattern<InsertSliceOp> {
177 public:
178 using OpConversionPattern<InsertSliceOp>::OpConversionPattern;
179
matchAndRewritemlir::__anon293c5c8c0111::BufferizeInsertSliceOp180 LogicalResult matchAndRewrite(
181 InsertSliceOp op, OpAdaptor adaptor,
182 ConversionPatternRewriter &rewriter) const final {
183 Value sourceMemRef = adaptor.getSource();
184 assert(sourceMemRef.getType().isa<MemRefType>());
185
186 Value destMemRef = adaptor.getDest();
187 assert(destMemRef.getType().isa<MemRefType>());
188
189 if (!op->getParentOfType<LoopOp>()) return failure();
190
191 Value subview = findExistingSubview(destMemRef);
192 if (!subview) {
193 subview = rewriter.create<SubViewOp>(
194 op.getLoc(), destMemRef, op.getMixedOffsets(), op.getMixedSizes(),
195 op.getMixedStrides());
196 }
197 rewriter.create<memref::CopyOp>(op.getLoc(), sourceMemRef, subview);
198 rewriter.replaceOp(op, destMemRef);
199 return success();
200 }
201 };
202
203 /// Create linalg op on buffers given the original tensor-based operation and
204 /// the buffers for the outputs.
createLinalgOpOnBuffers(ConversionPatternRewriter & rewriter,linalg::LinalgOp linalgOp,ValueRange inputs,ValueRange outputs)205 linalg::LinalgOp createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
206 linalg::LinalgOp linalgOp,
207 ValueRange inputs,
208 ValueRange outputs) {
209 SmallVector<Value, 8> newOperands = inputs;
210 newOperands.append(outputs.begin(), outputs.end());
211 auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(),
212 /*resultTypes=*/ArrayRef<Type>{},
213 newOperands);
214 for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) {
215 auto &oldRegion = std::get<0>(regions);
216 auto &newRegion = std::get<1>(regions);
217 rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
218 }
219 return newOp;
220 }
221
222 /// Get a variadic operand segment.
getVariadicOperands(DenseI32ArrayAttr sizeAttr,ValueRange operands,unsigned index)223 ValueRange getVariadicOperands(DenseI32ArrayAttr sizeAttr,
224 ValueRange operands, unsigned index) {
225 const int32_t *sizeIt = &*sizeAttr.value_begin<int32_t>();
226 if (sizeAttr.isSplat()) return operands.slice(*sizeIt * index, *sizeIt);
227
228 unsigned start = 0;
229 for (unsigned i = 0; i < index; ++i) start += sizeIt[i];
230 return operands.slice(start, sizeIt[index]);
231 }
232
233 // Bufferize LinalgOps in-place.
234 struct BufferizeLinalgOp
235 : public OpInterfaceConversionPattern<linalg::LinalgOp> {
236 using OpInterfaceConversionPattern<
237 linalg::LinalgOp>::OpInterfaceConversionPattern;
238
matchAndRewritemlir::__anon293c5c8c0111::BufferizeLinalgOp239 LogicalResult matchAndRewrite(
240 linalg::LinalgOp op, ArrayRef<Value> operands,
241 ConversionPatternRewriter &rewriter) const final {
242 if (!op->getParentOfType<LoopOp>()) return failure();
243
244 // An op with two variadic operand groups expects a segment size attribute.
245 auto operandSegments =
246 op->getAttrOfType<DenseI32ArrayAttr>("operand_segment_sizes");
247 if (!operandSegments) return failure();
248
249 const auto getOperands = [&](unsigned index) {
250 return getVariadicOperands(operandSegments, operands, index);
251 };
252 createLinalgOpOnBuffers(rewriter, op, getOperands(0), getOperands(1));
253 rewriter.replaceOp(op, getOperands(1));
254 return success();
255 }
256 };
257
258 // Convert `gml_st.yield` terminator of `gml_st.loop` to `gml_st.yield` with no
259 // arguments.
260 struct BufferizeLinalgYieldOp : public OpConversionPattern<gml_st::YieldOp> {
261 using OpConversionPattern<gml_st::YieldOp>::OpConversionPattern;
262
matchAndRewritemlir::__anon293c5c8c0111::BufferizeLinalgYieldOp263 LogicalResult matchAndRewrite(
264 gml_st::YieldOp op, OpAdaptor adaptor,
265 ConversionPatternRewriter &rewriter) const final {
266 if (!mlir::dyn_cast<LoopOp>(op->getParentOp()) ||
267 adaptor.getOperands().empty())
268 return failure();
269
270 rewriter.replaceOpWithNewOp<gml_st::YieldOp>(op);
271 return success();
272 }
273 };
274
275 // FuncOp-like bufferization pattern for `gml_st.loop` that inserts
276 // `memref.tensor_load` ops for every memref block argument.
277 //
278 // TODO(b/230082413): This code has to go away if we migrate to one-shot
279 // bufferization.
280 struct BufferizeLoopOp : public OpConversionPattern<LoopOp> {
281 using OpConversionPattern::OpConversionPattern;
282
matchAndRewritemlir::__anon293c5c8c0111::BufferizeLoopOp283 LogicalResult matchAndRewrite(
284 LoopOp op, OpAdaptor adaptor,
285 ConversionPatternRewriter &rewriter) const override {
286 if (op.getNumResults() == 0) return failure();
287
288 // Allocate new buffers for results if it is used by multiple uses.
289 SmallVector<Value, 4> operands = adaptor.getOperands();
290 for (auto &en : llvm::enumerate(op.outputs())) {
291 Value output = en.value();
292
293 auto toTensor = output.getDefiningOp<bufferization::ToTensorOp>();
294 if (!toTensor || toTensor->hasOneUse()) continue;
295
296 auto alloc = toTensor.getMemref().getDefiningOp<memref::AllocOp>();
297 if (!alloc) continue;
298
299 OpBuilder::InsertionGuard g(rewriter);
300 rewriter.setInsertionPoint(op);
301 auto *newAlloc = rewriter.clone(*alloc.getOperation());
302 operands[op.getNumControlOperands() + op.getNumInputs() + en.index()] =
303 newAlloc->getResult(0);
304 }
305
306 SmallVector<NamedAttribute> attrList;
307 for (auto &item : adaptor.getAttributes()) {
308 attrList.push_back(item);
309 }
310 auto newOp = rewriter.create<LoopOp>(op.getLoc(), mlir::TypeRange{},
311 operands, attrList);
312 // Take the region from the old op and put it in the new op.
313 rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
314 newOp.getLoopBody().end());
315
316 // Convert the type of the entry block of the LoopOp's body.
317 if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
318 *getTypeConverter()))) {
319 return rewriter.notifyMatchFailure(op, "could not convert body types");
320 }
321
322 rewriter.replaceOp(op, newOp.outputs());
323 return success();
324 }
325 };
326
327 // TODO(b/199045477): The pattern for vector.transfer_read/write have to be
328 // moved out of Linalg bufferization to a VectorOps bufferization pass.
329 struct BufferizeVectorTransferReadOp
330 : public OpConversionPattern<vector::TransferReadOp> {
331 using OpConversionPattern<vector::TransferReadOp>::OpConversionPattern;
332
matchAndRewritemlir::__anon293c5c8c0111::BufferizeVectorTransferReadOp333 LogicalResult matchAndRewrite(
334 vector::TransferReadOp readOp, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter) const final {
336 if (readOp.getShapedType().isa<MemRefType>()) return failure();
337 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
338 readOp, readOp.getType(), adaptor.getSource(), adaptor.getIndices(),
339 adaptor.getPermutationMapAttr(), adaptor.getPadding(),
340 adaptor.getMask(),
341 adaptor.getInBounds() ? adaptor.getInBoundsAttr() : ArrayAttr());
342 return success();
343 }
344 };
345
346 struct BufferizeVectorTransferWriteOp
347 : public OpConversionPattern<vector::TransferWriteOp> {
348 using OpConversionPattern<vector::TransferWriteOp>::OpConversionPattern;
349
matchAndRewritemlir::__anon293c5c8c0111::BufferizeVectorTransferWriteOp350 LogicalResult matchAndRewrite(
351 vector::TransferWriteOp writeOp, OpAdaptor adaptor,
352 ConversionPatternRewriter &rewriter) const final {
353 if (writeOp.getShapedType().isa<MemRefType>()) return failure();
354 rewriter.create<vector::TransferWriteOp>(
355 writeOp.getLoc(), adaptor.getVector(), adaptor.getSource(),
356 adaptor.getIndices(), adaptor.getPermutationMapAttr(),
357 adaptor.getInBounds() ? adaptor.getInBoundsAttr() : ArrayAttr());
358 rewriter.replaceOp(writeOp, adaptor.getSource());
359 return success();
360 }
361 };
362
363 } // namespace
364
365 namespace gml_st {
366 struct TiledLoopBufferizePass
367 : public TiledLoopBufferizePassBase<TiledLoopBufferizePass> {
getDependentDialectsmlir::gml_st::TiledLoopBufferizePass368 void getDependentDialects(DialectRegistry ®istry) const override {
369 registry.insert<memref::MemRefDialect>();
370 }
371
runOnOperationmlir::gml_st::TiledLoopBufferizePass372 void runOnOperation() override {
373 // Bufferize ops using BufferizableOpInterface. This could be switched to
374 // One-Shot Bufferize in the future.
375 mlir::RewritePatternSet patterns(&getContext());
376 mlir::bufferization::BufferizationOptions options =
377 mlir::bufferization::getPartialBufferizationOptions();
378 // TODO(springerm): Add dialects to this filter as more and more dialects
379 // will be migrated to BufferizableOpInterface-based bufferization.
380 options.opFilter.allowDialect<shape::ShapeDialect>();
381 if (failed(mlir::bufferization::bufferizeOp(getOperation(), options))) {
382 signalPassFailure();
383 return;
384 }
385
386 // Bufferize the remaining IR with dialect conversion. This will disappear
387 // eventually once all bufferization is done via BufferizableOpInterface.
388 if (failed(runDialectConversionBasedBufferization())) signalPassFailure();
389 }
390
391 private:
runDialectConversionBasedBufferizationmlir::gml_st::TiledLoopBufferizePass392 LogicalResult runDialectConversionBasedBufferization() {
393 mlir::RewritePatternSet patterns(&getContext());
394 auto &context = getContext();
395 ConversionTarget target(context);
396 target.addLegalDialect<
397 mlir::arith::ArithmeticDialect,
398 mlir::bufferization::BufferizationDialect,
399 mlir::complex::ComplexDialect, mlir::lmhlo::LmhloDialect,
400 mlir::AffineDialect, mlir::vector::VectorDialect,
401 mlir::memref::MemRefDialect, mlir::func::FuncDialect,
402 mlir::tensor::TensorDialect, mlir::math::MathDialect>();
403 target.addLegalOp<UnrealizedConversionCastOp>();
404 target.addIllegalDialect<mhlo::MhloDialect>();
405 target.addIllegalOp<tensor::ExtractSliceOp, tensor::InsertSliceOp>();
406
407 CustomBufferizeTypeConverter converter;
408 mlir::mhlo::RemoveSignTypeConverter removeSignConverter;
409
410 // Configure bufferize pattern.
411 populateCallOpTypeConversionPattern(patterns, converter);
412 populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
413 populateReturnOpTypeConversionPattern(patterns, converter);
414 mlir::bufferization::populateBufferizeMaterializationLegality(target);
415 populateTiledLoopBufferizePattern(&getContext(), &converter, &patterns);
416 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
417 converter, patterns, target);
418 // Configure legality.
419 auto isLegalOp = [&](Operation *op) { return converter.isLegal(op); };
420 target.addDynamicallyLegalDialect<mlir::linalg::LinalgDialect>(isLegalOp);
421 target.addDynamicallyLegalOp<mlir::func::CallOp, gml_st::LoopOp,
422 gml_st::YieldOp, mlir::LLVM::InlineAsmOp,
423 mlir::vector::TransferWriteOp,
424 mlir::vector::TransferReadOp>(isLegalOp);
425
426 return applyPartialConversion(getOperation(), target, std::move(patterns));
427 }
428 };
429
populateTiledLoopBufferizePattern(mlir::MLIRContext * context,mlir::bufferization::BufferizeTypeConverter * converter,mlir::RewritePatternSet * patterns)430 void populateTiledLoopBufferizePattern(
431 mlir::MLIRContext *context,
432 mlir::bufferization::BufferizeTypeConverter *converter,
433 mlir::RewritePatternSet *patterns) {
434 // clang-format off
435 patterns->add<
436 BufferizeExtractSliceOp,
437 BufferizeInitTensorOp,
438 BufferizeInsertSliceOp,
439 BufferizeLinalgOp,
440 BufferizeLinalgYieldOp,
441 BufferizeLoopOp,
442 BufferizeVectorTransferReadOp,
443 BufferizeVectorTransferWriteOp
444 >(*converter, context);
445 // clang-format on
446 }
447
CreateTiledLoopBufferizePass()448 std::unique_ptr<OperationPass<func::FuncOp>> CreateTiledLoopBufferizePass() {
449 return std::make_unique<TiledLoopBufferizePass>();
450 }
451
452 } // namespace gml_st
453 } // namespace mlir
454