/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

namespace mlir {
namespace {

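// Prints a list of dimension sizes in shaped-type notation, i.e. sizes
// separated by `x` with `?` for dynamic dimensions. For example, the sizes
// {2, ShapedType::kDynamicSize, 4} print as "2x?x4".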
void printShapeTypeDimensionsList(AsmPrinter &printer,
                                  ArrayRef<int64_t> integers) {
  llvm::interleave(
      integers, printer,
      [&](int64_t val) {
        if (val == ShapedType::kDynamicSize)
          printer << '?';
        else
          printer << val;
      },
      "x");
}

ParseResult parseShapeTypeDimensionsList(
    AsmParser &parser, FailureOr<SmallVector<int64_t>> &dims) {
  SmallVector<int64_t> vals;
  if (failed(parser.parseDimensionList(vals, /*allowDynamic=*/true,
                                       /*withTrailingX=*/false))) {
    return failure();
  }
  dims = vals;
  return success();
}

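// Parses a parenthesized list of `%region_arg = %operand : type` assignments,
// e.g. `(%in_ = %in: tensor<8xf32>, %out_ = %out: memref<8xf32>)`. The region
// arguments land in `lhs`, the operands in `rhs`, and the annotated types in
// `types`.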
ParseResult parseAssignmentListWithTypes(
    OpAsmParser &parser, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lhs,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &rhs,
    SmallVectorImpl<Type> &types) {
  auto parseElt = [&]() -> ParseResult {
    if (parser.parseOperand(lhs.emplace_back(), /*allowResultNumber=*/false) ||
        parser.parseEqual() || parser.parseOperand(rhs.emplace_back()) ||
        parser.parseColon() || parser.parseType(types.emplace_back())) {
      return failure();
    }
    return success();
  };
  return parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, parseElt);
}

}  // namespace
}  // namespace mlir

// Generated dialect definitions.
#include "mlir-hlo/Dialect/gml_st/IR/gml_st_dialect.cc.inc"

// Generated type classes.
#define GET_TYPEDEF_CLASSES
#include "mlir-hlo/Dialect/gml_st/IR/gml_st_types.cc.inc"

namespace mlir {
namespace gml_st {

//===----------------------------------------------------------------------===//
// GmlStDialect
//===----------------------------------------------------------------------===//

void GmlStDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.cc.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir-hlo/Dialect/gml_st/IR/gml_st_types.cc.inc"
      >();
}

// Helper function to ensure index types for some attributes when folding.
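// For example, the attribute `42 : i64` becomes `42 : index`; attributes that
// already have index type, as well as SSA values, pass through unchanged.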
static OpFoldResult ensureIndexTypeForAttribute(OpFoldResult foldResult) {
  if (foldResult.is<Attribute>()) {
    // `dyn_cast` yields a null attribute for non-integer attributes; guard
    // against dereferencing it.
    auto attr = foldResult.get<Attribute>().dyn_cast<IntegerAttr>();
    if (attr && !attr.getType().isa<IndexType>()) {
      Builder b(attr.getContext());
      return b.getIndexAttr(attr.getInt());
    }
  }
  return foldResult;
}

Operation *GmlStDialect::materializeConstant(OpBuilder &builder, Attribute attr,
                                             Type type, Location loc) {
  if (type.isa<IndexType>()) {
    int64_t intValue = attr.cast<IntegerAttr>().getInt();
    return builder.create<arith::ConstantIndexOp>(loc, intValue);
  }
  return {};
}

//===----------------------------------------------------------------------===//
// MaterializeOp
//===----------------------------------------------------------------------===//

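// The result type mirrors the set argument: materializing a tile from a
// tensor or memref yields a tensor or memref of the tile's shape, while
// materializing a point yields a single element. A sketch of the typing
// (assembly syntax abridged):
//
//   source: tensor<8x8xf32>, set: !gml_st.tile<4x4> -> tensor<4x4xf32>
//   source: tensor<8x8xf32>, set: !gml_st.point     -> f32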
LogicalResult MaterializeOp::inferReturnTypes(
    MLIRContext *, Optional<Location>, ValueRange operands,
    DictionaryAttr attributes, RegionRange,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  MaterializeOp::Adaptor adaptor(operands, attributes);

  ShapedType sourceType = adaptor.source().getType().cast<ShapedType>();
  Type setType = adaptor.set().getType();

  if (auto tileType = setType.dyn_cast<TileType>()) {
    if (auto memrefType = sourceType.dyn_cast<MemRefType>()) {
      inferredReturnTypes.push_back(
          MemRefType::get(tileType.getShape(), sourceType.getElementType()));
    } else if (auto tensorType = sourceType.dyn_cast<RankedTensorType>()) {
      inferredReturnTypes.push_back(RankedTensorType::get(
          tileType.getShape(), sourceType.getElementType()));
    } else {
      return failure();
    }
  } else if (setType.isa<PointType>()) {
    inferredReturnTypes.push_back(sourceType.getElementType());
  } else {
    return failure();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// LoopOp
//===----------------------------------------------------------------------===//

void LoopOp::build(OpBuilder &builder, OperationState &result,
                   ValueRange lowerBounds, ValueRange upperBounds,
                   ValueRange steps, ValueRange inputs, ValueRange outputs,
                   ArrayAttr iteratorTypes,
                   function_ref<void(OpBuilder &, Location, ValueRange,
                                     ValueRange, ValueRange)>
                       bodyBuilderFn) {
  build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs,
        iteratorTypes, llvm::None, bodyBuilderFn);
}

void LoopOp::build(OpBuilder &builder, OperationState &result,
                   ValueRange lowerBounds, ValueRange upperBounds,
                   ValueRange steps, ValueRange inputs, ValueRange outputs,
                   ArrayAttr iteratorTypes,
                   Optional<ArrayAttr> distributionTypes,
                   function_ref<void(OpBuilder &, Location, ValueRange,
                                     ValueRange, ValueRange)>
                       bodyBuilderFn) {
  result.addOperands(lowerBounds);
  result.addOperands(upperBounds);
  result.addOperands(steps);
  result.addOperands(inputs);
  result.addOperands(outputs);
  result.addAttribute(
      LoopOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
                                    static_cast<int32_t>(upperBounds.size()),
                                    static_cast<int32_t>(steps.size()),
                                    static_cast<int32_t>(inputs.size()),
                                    static_cast<int32_t>(outputs.size())}));
  result.addAttribute(getIteratorTypesAttrStrName(), iteratorTypes);

  if (distributionTypes.has_value())
    result.addAttribute(getDistributionTypesAttrStrName(),
                        distributionTypes.getValue());

  // Add output types for `RankedTensorType` output arguments.
  for (Value output : outputs) {
    Type outputType = output.getType();
    if (outputType.isa<RankedTensorType>()) result.addTypes(outputType);
  }

  OpBuilder::InsertionGuard guard(builder);
  unsigned numIVs = steps.size();
  SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
  SmallVector<Location, 8> argLocs(numIVs, result.location);
  for (Value input : inputs) {
    argTypes.push_back(input.getType());
    argLocs.push_back(input.getLoc());
  }
  for (Value output : outputs) {
    argTypes.push_back(output.getType());
    argLocs.push_back(output.getLoc());
  }
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);

  if (bodyBuilderFn) {
    builder.setInsertionPointToStart(bodyBlock);
    bodyBuilderFn(builder, result.location,
                  bodyBlock->getArguments().take_front(numIVs),
                  bodyBlock->getArguments().slice(numIVs, inputs.size()),
                  bodyBlock->getArguments().take_back(outputs.size()));
    LoopOp::ensureTerminator(*bodyRegion, builder, result.location);
  }
}

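// Prints the loop in its custom assembly form, e.g. (sketch):
//
//   gml_st.loop (%i) = (%c0) to (%c8) step (%c1)
//       ins (%in_ = %in: tensor<8xf32>)
//       outs (%out_ = %out: tensor<8xf32>) { ... }
//
// The `iterators` and `distribution` lists are printed only when they carry
// non-default information.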
void LoopOp::print(OpAsmPrinter &p) {
  p << " (" << getInductionVars() << ") = (" << lowerBound() << ") to ("
    << upperBound() << ") step (" << step() << ")";

  if (!inputs().empty()) {
    p << " ins (";
    llvm::interleaveComma(llvm::zip(getRegionInputArgs(), inputs()), p,
                          [&](auto it) {
                            p << std::get<0>(it) << " = " << std::get<1>(it)
                              << ": " << std::get<1>(it).getType();
                          });
    p << ")";
  }
  if (!outputs().empty()) {
    p << " outs (";
    llvm::interleaveComma(llvm::zip(getRegionOutputArgs(), outputs()), p,
                          [&](auto it) {
                            p << std::get<0>(it) << " = " << std::get<1>(it)
                              << ": " << std::get<1>(it).getType();
                          });
    p << ")";
  }

  if (llvm::any_of(iterator_types(), [](Attribute attr) {
        return attr.cast<StringAttr>().getValue() !=
               LoopOp::getParallelIteratorTypeName();
      }))
    p << " iterators" << iterator_types();

  if (distribution_types().has_value())
    p << " distribution" << distribution_types().getValue();

  p << ' ';
  p.printRegion(region(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(
      getOperation()->getAttrs(),
      /*elidedAttrs=*/{LoopOp::getOperandSegmentSizeAttr(),
                       LoopOp::getIteratorTypesAttrName(),
                       LoopOp::getDistributionTypesAttrName()});
}

ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  // Parse an opening `(` followed by induction variables followed by `)`.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
  if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false))
    return failure();

  // Parse loop bounds.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
  if (parser.parseEqual() ||
      parser.parseOperandList(lower, ivs.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(lower, builder.getIndexType(), result.operands))
    return failure();

  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
  if (parser.parseKeyword("to") ||
      parser.parseOperandList(upper, ivs.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(upper, builder.getIndexType(), result.operands))
    return failure();

  // Parse step values.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(steps, builder.getIndexType(), result.operands))
    return failure();

  // Parse input tensors.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputs, inputRegionArgs;
  SmallVector<Type, 4> inputTypes;
  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    SMLoc inputsOperandsLoc = parser.getCurrentLocation();

    if (parseAssignmentListWithTypes(parser, inputRegionArgs, inputs,
                                     inputTypes))
      return failure();

    if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc,
                               result.operands))
      return failure();
  }

  // Parse output tensors.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> outputs, outputRegionArgs;
  SmallVector<Type, 4> outputTypes;
  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    SMLoc outputsOperandsLoc = parser.getCurrentLocation();

    if (parseAssignmentListWithTypes(parser, outputRegionArgs, outputs,
                                     outputTypes))
      return failure();

    if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
                               result.operands))
      return failure();
    for (Type outputType : outputTypes)
      if (outputType.isa<RankedTensorType>()) result.addTypes(outputType);
  }

  // Parse attributes.
  SmallVector<Attribute, 4> iterTypes, distributionTypes;
  auto parseAttr = [&](StringRef keyword, SmallVector<Attribute, 4> *attrs) {
    if (succeeded(parser.parseOptionalKeyword(keyword))) {
      StringAttr attr;

      if (parser.parseLSquare() || parser.parseAttribute(attr))
        return failure();
      attrs->push_back(attr);
      for (int i = 1, e = ivs.size(); i < e; ++i) {
        if (parser.parseComma() || parser.parseAttribute(attr))
          return failure();
        attrs->push_back(attr);
      }
      if (parser.parseRSquare()) return failure();
    }
    return success();
  };
  if (failed(parseAttr("iterators", &iterTypes)) ||
      failed(parseAttr("distribution", &distributionTypes)))
    return failure();

  // Set all loop iterator types to "parallel" if they are not printed in IR.
  if (iterTypes.empty()) {
    auto parallelIter =
        builder.getStringAttr(LoopOp::getParallelIteratorTypeName());
    iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter);
  }
  result.addAttribute(LoopOp::getIteratorTypesAttrStrName(),
                      builder.getArrayAttr(iterTypes));
  if (!distributionTypes.empty())
    result.addAttribute(LoopOp::getDistributionTypesAttrStrName(),
                        builder.getArrayAttr(distributionTypes));
  result.addAttribute(
      LoopOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
                                    static_cast<int32_t>(upper.size()),
                                    static_cast<int32_t>(steps.size()),
                                    static_cast<int32_t>(inputs.size()),
                                    static_cast<int32_t>(outputs.size())}));

  // Parse the body.
  Region *body = result.addRegion();

  SmallVector<Type, 4> regionTypes(ivs.size(), builder.getIndexType());
  regionTypes.append(inputTypes);
  regionTypes.append(outputTypes);

  SmallVector<OpAsmParser::UnresolvedOperand, 4> regionOperands(ivs);
  regionOperands.append(inputRegionArgs);
  regionOperands.append(outputRegionArgs);

  SmallVector<OpAsmParser::Argument, 4> regionArgs;

  for (auto argAndType : llvm::zip(regionOperands, regionTypes)) {
    auto &arg = regionArgs.emplace_back();
    arg.ssaName = std::get<0>(argAndType);
    arg.type = std::get<1>(argAndType);
  }

  if (parser.parseRegion(*body, regionArgs)) return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();

  return success();
}

Region &LoopOp::getLoopBody() { return region(); }

LogicalResult LoopOp::verify() {
  // Check if iterator types are provided for every loop dimension.
  if (iterator_types().size() != getNumLoops())
    return emitOpError("expected iterator types array attribute size = ")
           << iterator_types().size()
           << " to match the number of loops = " << getNumLoops();

  // Check if types of input arguments match region args types.
  for (auto &item :
       llvm::enumerate(llvm::zip(inputs(), getRegionInputArgs()))) {
    Value input, inputRegionArg;
    unsigned index = item.index();
    std::tie(input, inputRegionArg) = item.value();
    if (input.getType() != inputRegionArg.getType())
      return emitOpError("expected input arg ")
             << index << " with type = " << input.getType()
             << " to match region arg " << index + getNumLoops()
             << " type = " << inputRegionArg.getType();
  }

  // Check if types of output arguments match region args types.
  for (auto &item :
       llvm::enumerate(llvm::zip(outputs(), getRegionOutputArgs()))) {
    Value output, outputRegionArg;
    unsigned index = item.index();
    std::tie(output, outputRegionArg) = item.value();
    if (output.getType() != outputRegionArg.getType())
      return emitOpError("expected output arg ")
             << index << " with type = " << output.getType()
             << " to match region arg "
             << index + getNumLoops() + inputs().size()
             << " type = " << outputRegionArg.getType();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// LoopLikeOp
//===----------------------------------------------------------------------===//

namespace {

ParseResult parseForOpOutputArgs(
    OpAsmParser &parser, OperationState &result,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regionOperands,
    SmallVectorImpl<Type> &regionTypes, int32_t *outputCount) {
  SmallVector<OpAsmParser::UnresolvedOperand, 4> outputs, outputRegionArgs;
  SmallVector<Type, 4> outputTypes;

  auto parseElt = [&]() -> ParseResult {
    if (parser.parseOperand(outputRegionArgs.emplace_back(),
                            /*allowResultNumber=*/false) ||
        parser.parseEqual()) {
      return failure();
    }
    if (parser.parseOperand(outputs.emplace_back()) || parser.parseColon() ||
        parser.parseType(outputTypes.emplace_back())) {
      return failure();
    }
    *outputCount = outputs.size();
    return success();
  };
  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    SMLoc loc = parser.getCurrentLocation();

    if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, parseElt))
      return failure();
    if (parser.resolveOperands(outputs, outputTypes, loc, result.operands))
      return failure();
  }
  regionOperands.append(outputRegionArgs);
  regionTypes.append(outputTypes);
  return success();
}

}  // namespace

template <typename LoopTy>
ParseResult parseLoopLikeOp(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  // Parse an opening `(` followed by induction variables followed by `)`.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
  if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
                              /*allowResultNumber=*/false))
    return failure();

  // Parse loop bounds.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
  if (parser.parseEqual() ||
      parser.parseOperandList(lower, ivs.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(lower, builder.getIndexType(), result.operands))
    return failure();

  SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
  if (parser.parseKeyword("to") ||
      parser.parseOperandList(upper, ivs.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(upper, builder.getIndexType(), result.operands))
    return failure();

  // Parse step values.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(steps, builder.getIndexType(), result.operands))
    return failure();

  SmallVector<int32_t> segmentSizes{static_cast<int32_t>(lower.size()),
                                    static_cast<int32_t>(upper.size()),
                                    static_cast<int32_t>(steps.size())};

  // Parse the output tensors (only for ForOp) and the body.
  SmallVector<OpAsmParser::UnresolvedOperand, 4> regionOperands(ivs);
  SmallVector<Type, 4> regionTypes(ivs.size(), builder.getIndexType());

  if (std::is_same<LoopTy, ForOp>::value) {
    int32_t outputCount = 0;
    if (parseForOpOutputArgs(parser, result, regionOperands, regionTypes,
                             &outputCount))
      return failure();
    segmentSizes.push_back(outputCount);
  }

  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  for (auto argAndType : llvm::zip(regionOperands, regionTypes)) {
    auto &arg = regionArgs.emplace_back();
    std::tie(arg.ssaName, arg.type) = argAndType;
  }
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, regionArgs)) return failure();

  // Parse attributes.
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();

  // Parse result types.
  if (parser.parseOptionalColonTypeList(result.types)) return failure();

  // Add segment sizes.
  result.addAttribute(LoopTy::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));

  return success();
}

//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//

Region &ParallelOp::getLoopBody() { return region(); }

SetYieldOp ParallelOp::getTerminator() {
  return cast<SetYieldOp>(getBody()->getTerminator());
}

LogicalResult ParallelOp::verify() { return success(); }

void ParallelOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  result.addOperands(lowerBounds);
  result.addOperands(upperBounds);
  result.addOperands(steps);
  result.addTypes(resultTypes);
  result.addAttribute(
      LoopOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
                                    static_cast<int32_t>(upperBounds.size()),
                                    static_cast<int32_t>(steps.size())}));

  OpBuilder::InsertionGuard guard(builder);
  unsigned numIvs = steps.size();
  SmallVector<Type, 8> argTypes(numIvs, builder.getIndexType());
  SmallVector<Location, 8> argLocs(numIvs, result.location);
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);

  if (bodyBuilderFn) {
    builder.setInsertionPointToStart(bodyBlock);
    bodyBuilderFn(builder, result.location,
                  bodyBlock->getArguments().take_front(numIvs));
    ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
  }
}

void ParallelOp::print(OpAsmPrinter &p) {
  p << " (" << getInductionVars() << ") = (" << lowerBound() << ") to ("
    << upperBound() << ") step (" << step() << ") ";

  p.printRegion(region(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(
      getOperation()->getAttrs(),
      /*elidedAttrs=*/{ParallelOp::getOperandSegmentSizeAttr()});

  if (!getResultTypes().empty()) {
    p << " : ";
    llvm::interleave(getResultTypes(), p, ", ");
  }
}

ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseLoopLikeOp<ParallelOp>(parser, result);
}

//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//

Region &ForOp::getLoopBody() { return region(); }

SetYieldOp ForOp::getTerminator() {
  return cast<SetYieldOp>(getBody()->getTerminator());
}

LogicalResult ForOp::verify() {
  // Check if types of output arguments match region args types.
  for (auto &item :
       llvm::enumerate(llvm::zip(outputs(), getRegionOutputArgs()))) {
    Value output, outputRegionArg;
    unsigned index = item.index();
    std::tie(output, outputRegionArg) = item.value();
    if (output.getType() != outputRegionArg.getType()) {
      return emitOpError("expected output arg ")
             << index << " with type = " << output.getType()
             << " to match region arg " << index + getNumLoops()
             << " type = " << outputRegionArg.getType();
    }
    if (getTerminator().getDstOperand(index)->get() != outputRegionArg) {
      return getTerminator().emitOpError("expected output block argument ")
             << index << " to match set_yield destination";
    }
  }
  return success();
}

void ForOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
        bodyBuilderFn) {
  result.addOperands(lowerBounds);
  result.addOperands(upperBounds);
  result.addOperands(steps);
  result.addOperands(outputs);
  result.addTypes(resultTypes);
  result.addAttribute(
      LoopOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
                                    static_cast<int32_t>(upperBounds.size()),
                                    static_cast<int32_t>(steps.size()),
                                    static_cast<int32_t>(outputs.size())}));

  OpBuilder::InsertionGuard guard(builder);
  unsigned numIvs = steps.size();
  SmallVector<Type, 8> argTypes(numIvs, builder.getIndexType());
  SmallVector<Location, 8> argLocs(numIvs, result.location);
  for (Value output : outputs) {
    argTypes.push_back(output.getType());
    argLocs.push_back(output.getLoc());
  }
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);

  if (bodyBuilderFn) {
    builder.setInsertionPointToStart(bodyBlock);
    bodyBuilderFn(builder, result.location,
                  bodyBlock->getArguments().take_front(numIvs),
                  bodyBlock->getArguments().take_back(outputs.size()));
    ForOp::ensureTerminator(*bodyRegion, builder, result.location);
  }
}

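// Prints the loop in its custom assembly form, e.g. (sketch):
//
//   %res = gml_st.for (%i) = (%c0) to (%c8) step (%c1)
//       outs (%out_ = %out: tensor<8xf32>) { ... } : tensor<8xf32>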
void ForOp::print(OpAsmPrinter &p) {
  p << " (" << getInductionVars() << ") = (" << lowerBound() << ") to ("
    << upperBound() << ") step (" << step() << ")";

  if (!outputs().empty()) {
    p << " outs (";
    llvm::interleaveComma(
        llvm::zip(getRegionOutputArgs(), outputs()), p, [&](auto it) {
          Value outputRegionArg, output;
          std::tie(outputRegionArg, output) = it;
          p << outputRegionArg << " = " << output << ": " << output.getType();
        });
    p << ")";
  }

  p << ' ';
  p.printRegion(region(), /*printEntryBlockArgs=*/false);
  p.printOptionalAttrDict(getOperation()->getAttrs(),
                          /*elidedAttrs=*/{ForOp::getOperandSegmentSizeAttr()});

  if (!getResultTypes().empty()) {
    p << " : ";
    llvm::interleave(getResultTypes(), p, ", ");
  }
}

ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseLoopLikeOp<ForOp>(parser, result);
}

namespace {

static constexpr int64_t kNoMatch = -1;

// Folds away LoopOp inputs if they have no uses within the body.
//
// Example:
//
// %0 = gml_st.loop ...  ins (%in_ = %in: tensor<...>,
//                            %in_buf_ = %in_buf: memref<...>) {...}
// Becomes
//
// gml_st.loop ...  ins (%in_buf_ = %in_buf: memref<...>) {...}
struct LoopInputsFolder : public OpRewritePattern<LoopOp> {
  using OpRewritePattern<LoopOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoopOp loop,
                                PatternRewriter &rewriter) const final {
    SmallVector<Value, 2> newInputs, regionInputTensorArgs;
    // Store ids of the corresponding old and new input operands.
    SmallVector<int64_t, 2> oldInputIdToNew(loop.inputs().size(), kNoMatch);
    for (const auto &en :
         llvm::enumerate(llvm::zip(loop.inputs(), loop.getRegionInputArgs()))) {
      Value in, bbArg;
      size_t index = en.index();
      std::tie(in, bbArg) = en.value();
      if (!bbArg.use_empty()) {
        oldInputIdToNew[index] = newInputs.size();
        newInputs.push_back(in);
      }
    }
    if (newInputs.size() == loop.inputs().size()) return failure();
    Location loc = loop.getLoc();
    auto newLoop = rewriter.create<LoopOp>(
        loc, loop.lowerBound(), loop.upperBound(), loop.step(), newInputs,
        loop.outputs(), loop.iterator_types(), loop.distribution_types());

    // Clone the region.
    BlockAndValueMapping bvm;
    bvm.map(loop.getInductionVars(), newLoop.getInductionVars());
    bvm.map(loop.getRegionOutputArgs(), newLoop.getRegionOutputArgs());
    for (const auto &en : llvm::enumerate(oldInputIdToNew))
      if (en.value() != kNoMatch)
        bvm.map(loop.getRegionInputArgs()[en.index()],
                newLoop.getRegionInputArgs()[en.value()]);
    OpBuilder innerBuilder =
        OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener());
    for (auto &op : *loop.getBody()) innerBuilder.clone(op, bvm);
    rewriter.replaceOp(loop, newLoop.getResults());

    return success();
  }
};

}  // namespace


/// A simple, conservative analysis to determine if the loop is shape
/// preserving, i.e., the type of the `arg`-th yielded value is the same as
/// the type of the corresponding basic block argument of the loop.
/// Note: This function handles only simple cases. Expand as needed.
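/// For example, yielding the result of a `tensor.insert_slice` whose
/// destination chains back to the corresponding output block argument is
/// shape preserving, since insert_slice returns a value of its destination's
/// type.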
static bool isShapePreserving(LoopOp loopOp, int64_t arg) {
  auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
  if (yieldOp.values().empty())
    // Loop either has no outputs or is a "memref-based version". In either
    // case, the loop is shape preserving.
    return true;
  assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
         "arg is out of bounds");
  Value value = yieldOp.values()[arg];
  while (value) {
    if (value == loopOp.getRegionOutputArgs()[arg]) return true;
    OpResult opResult = value.dyn_cast<OpResult>();
    if (!opResult) return false;

    using tensor::InsertSliceOp;
    value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
                .template Case<InsertSliceOp>(
                    [&](InsertSliceOp op) { return op.getDest(); })
                .template Case<LoopOp>([&](LoopOp loopOp) {
                  return isShapePreserving(loopOp, opResult.getResultNumber())
                             ? loopOp.outputs()[opResult.getResultNumber()]
                             : Value();
                })
                .Default([&](auto /*op*/) { return Value(); });
  }
  return false;
}

namespace {

/// Fold dim(x) where `x` is an input/output argument of a LoopOp block
/// to dim(y) where `y` is the initial input/output value of the argument.
///
/// E.g.:
/// %y = ... : tensor<...>
/// gml_st.loop ... ins(%x = %y : tensor<...>) {
///   tensor.dim %x, %c0 : tensor<...>
/// }
///
/// is folded to:
/// %y = ... : tensor<...>
/// gml_st.loop ... ins(%x = %y : tensor<...>) {
///   tensor.dim %y, %c0 : tensor<...>
/// }
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the yielded value (in case of outputs) does not change with loop iterations.
template <typename OpTy>
struct DimOfLoopInsOutsFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const final {
    auto src = dimOp.getSource().template dyn_cast<BlockArgument>();
    if (!src) return failure();
    auto loopOp = dyn_cast<LoopOp>(src.getOwner()->getParent()->getParentOp());
    if (!loopOp) return failure();
    unsigned numLoops = loopOp.getNumLoops();
    unsigned numInputArgs = loopOp.getRegionInputArgs().size();
    if (src.getArgNumber() >= numInputArgs + numLoops &&
        !isShapePreserving(loopOp,
                           src.getArgNumber() - numInputArgs - numLoops))
      return failure();

    auto inputArgs = loopOp.getRegionInputArgs();
    auto it1 = llvm::find(inputArgs, src);
    if (it1 != inputArgs.end()) {
      rewriter.updateRootInPlace(dimOp, [&] {
        dimOp.getSourceMutable().assign(
            loopOp.inputs()[it1 - inputArgs.begin()]);
      });
      return success();
    }

    auto outputArgs = loopOp.getRegionOutputArgs();
    auto it2 = llvm::find(outputArgs, src);
    if (it2 != outputArgs.end()) {
      rewriter.updateRootInPlace(dimOp, [&] {
        dimOp.getSourceMutable().assign(
            loopOp.outputs()[it2 - outputArgs.begin()]);
      });
      return success();
    }

    return failure();
  }
};

/// Fold dim(r) where `r` is the result of a LoopOp to dim(y) where `y`
/// is the initial output value of the loop.
///
/// E.g.:
/// %y = ... : tensor<...>
/// %r = gml_st.loop ... outs(%i = %y : tensor<...>) {
///   ...
/// }
/// %0 = tensor.dim %r, %c0 : tensor<...>
///
/// is folded to:
/// %y = ... : tensor<...>
/// gml_st.loop ... outs(%i = %y : tensor<...>) {
///   ...
/// }
/// %0 = tensor.dim %y, %c0 : tensor<...>
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the yielded value (in case of outputs) does not change with loop iterations.
template <typename OpTy>
struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const final {
    auto loopOp = dimOp.getSource().template getDefiningOp<LoopOp>();
    if (!loopOp) return failure();
    auto opResult = dimOp.getSource().template cast<OpResult>();
    unsigned resultNumber = opResult.getResultNumber();
    if (!isShapePreserving(loopOp, resultNumber)) return failure();
    rewriter.updateRootInPlace(dimOp, [&]() {
      dimOp.getSourceMutable().assign(loopOp.outputs()[resultNumber]);
    });
    return success();
  }
};

// Folds away LoopOp output tensors when the following conditions are met:
// * result of `gml_st.loop` has no uses
// * output tensor is the argument of `gml_st.yield`
//
// Example:
//
// %0 = gml_st.loop ...  outs (%o_ = %out: tensor<...>,
//                             %obuf_ = %out_buf: memref<...>) {
//   ...
//   gml_st.yield %o_ : tensor ...
// }
//
// Becomes
//
// gml_st.loop ...  outs (%obuf_ = %out_buf: memref<...>) {
//   ...
//   gml_st.yield
// }
struct LoopResultsFolder : public OpRewritePattern<LoopOp> {
  using OpRewritePattern<LoopOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoopOp loop,
                                PatternRewriter &rewriter) const final {
    if (loop.getNumResults() == 0) return failure();

    Block *block = loop.getBody();
    auto yieldOp = cast<YieldOp>(block->getTerminator());

    // Match the pattern and collect the output buffers that will replace the
    // output tensors, as well as the ops that will be ignored when cloning
    // the body.
    SmallVector<Value, 2> newOutputOperands, newYieldArgs;
    int resultId = 0;
    // Store ids of the corresponding old and new output operands.
    SmallVector<int64_t, 2> oldOutputIdToNew(loop.outputs().size(), kNoMatch);
    // Store ids of the corresponding old and new results.
    SmallVector<int64_t, 2> oldResultIdToNew(loop.getNumResults(), kNoMatch);
    SmallVector<Value, 2> resultReplacement(loop.getNumResults());
    for (const auto &en : llvm::enumerate(
             llvm::zip(loop.outputs(), loop.getRegionOutputArgs()))) {
      size_t index = en.index();
      Value out = std::get<0>(en.value());
      Value outRegionArg = std::get<1>(en.value());

      if (!out.getType().isa<RankedTensorType>()) {
        oldOutputIdToNew[index] = newOutputOperands.size();
        newOutputOperands.push_back(out);
        continue;
      }
      Value result = loop.getResult(resultId);
      Value yieldArg = yieldOp.getOperand(resultId);
      if (yieldArg != outRegionArg || !result.use_empty()) {
        oldOutputIdToNew[index] = newOutputOperands.size();
        oldResultIdToNew[resultId] = newYieldArgs.size();
        resultReplacement[resultId] = out;
        newOutputOperands.push_back(out);
        newYieldArgs.push_back(yieldArg);
      }
      ++resultId;
    }
    if (newOutputOperands.size() == loop.outputs().size()) return failure();

    Location loc = loop.getLoc();
    auto newLoop = rewriter.create<LoopOp>(
        loc, loop.lowerBound(), loop.upperBound(), loop.step(), loop.inputs(),
        newOutputOperands, loop.iterator_types(), loop.distribution_types());

    // Clone the region.
    BlockAndValueMapping bvm;
    bvm.map(loop.getInductionVars(), newLoop.getInductionVars());
    bvm.map(loop.getRegionInputArgs(), newLoop.getRegionInputArgs());
    for (const auto &en : llvm::enumerate(oldOutputIdToNew)) {
      if (en.value() != kNoMatch)
        bvm.map(loop.getRegionOutputArgs()[en.index()],
                newLoop.getRegionOutputArgs()[en.value()]);
      else
        bvm.map(loop.getRegionOutputArgs()[en.index()],
                loop.outputs()[en.index()]);
    }
    OpBuilder innerBuilder =
        OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener());
    for (auto &op : loop.getBody()->without_terminator())
      innerBuilder.clone(op, bvm);
    innerBuilder.create<YieldOp>(
        loc, llvm::to_vector<2>(llvm::map_range(
                 newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));

    for (const auto &en : llvm::enumerate(oldResultIdToNew))
      if (en.value() != kNoMatch)
        resultReplacement[en.index()] = newLoop.getResult(en.value());
    rewriter.replaceOp(loop, resultReplacement);

    return success();
  }
};

/// Pulls `tensor.cast` ops that produce `gml_st.loop` input/output arguments
/// inside the loop:
///
/// ```
///   %in = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
///   %out = tensor.cast %t1 : tensor<32x1024xf32> to tensor<?x?xf32>
///   %result = gml_st.loop %i = %c0 to %c1024 step %c32
///       ins (%in_ = %in: tensor<?x?xf32>)
///       outs (%out_ = %out: tensor<?x?xf32>) {
///     %0 = call @do(%in_, %out_)
///       : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
///     gml_st.yield %0 : tensor<?x?xf32>
///   }
///   %result_cast = tensor.cast %result
///     : tensor<?x?xf32> to tensor<32x1024xf32>
///   use_of(%result_cast)
/// ```
///
/// folds into:
///
/// ```
///   %result = gml_st.loop %i = %c0 to %c1024 step %c32
///       ins (%in_ = %t0: tensor<32x1024xf32>)
///       outs (%out_ = %t1: tensor<32x1024xf32>) {
///     %in_cast = tensor.cast %in_ : tensor<32x1024xf32> to tensor<?x?xf32>
///     %out_cast = tensor.cast %out_ : tensor<32x1024xf32> to tensor<?x?xf32>
///     %0 = call @do(%in_cast, %out_cast)
///       : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
///     %0_cast = tensor.cast %0 : tensor<?x?xf32> to tensor<32x1024xf32>
///     gml_st.yield %0_cast : tensor<32x1024xf32>
///   }
///   use_of(%result)
/// ```
struct TensorCastOfLoopInsOutsFolder : public OpRewritePattern<LoopOp> {
  using OpRewritePattern<LoopOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoopOp loop,
                                PatternRewriter &rewriter) const override {
    CastOpsOfArgs inputCasts = findTensorCastOps(loop.inputs());
    CastOpsOfArgs outputCasts = findTensorCastOps(loop.outputs());
    if (!inputCasts.castFound && !outputCasts.castFound) return failure();

    auto newLoop = rewriter.create<LoopOp>(
        loop.getLoc(), loop.lowerBound(), loop.upperBound(), loop.step(),
        inputCasts.updatedArgs, outputCasts.updatedArgs, loop.iterator_types(),
        loop.distribution_types());

    rewriter.replaceOp(loop, insertCastsAndCloneBody(inputCasts, outputCasts,
                                                     loop, newLoop, rewriter));
    return success();
  }

 private:
  struct CastOpsOfArgs {
    SmallVector<tensor::CastOp, 4> ops;
    // Contains either old arguments or arguments of `tensor.cast`.
    SmallVector<Value, 4> updatedArgs;
    bool castFound = false;
  };

  // Scans through args to find what args are produced by `tensor.cast` ops.
  CastOpsOfArgs findTensorCastOps(ValueRange args) const {
    CastOpsOfArgs result;
    for (auto arg : args) {
      if (auto cast = arg.getDefiningOp<tensor::CastOp>()) {
        result.ops.push_back(cast);
        result.updatedArgs.push_back(cast.getSource());
        result.castFound = true;
        continue;
      }
      result.ops.push_back(nullptr);
      result.updatedArgs.push_back(arg);
    }
    return result;
  }

  SmallVector<Value, 4> insertCastsAndCloneBody(
      const CastOpsOfArgs &inputCasts, const CastOpsOfArgs &outputCasts,
      LoopOp loop, LoopOp newLoop, PatternRewriter &rewriter) const {
    auto loc = newLoop.getLoc();
    BlockAndValueMapping bvm;
    bvm.map(loop.getInductionVars(), newLoop.getInductionVars());

    auto innerBuilder =
        OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener());

    Value oldArg, newArg, yieldArg, result;
    tensor::CastOp argCast;

    // Map inputs, insert `tensor.cast` if necessary.
    for (auto item : llvm::zip(loop.getRegionInputArgs(),
                               newLoop.getRegionInputArgs(), inputCasts.ops)) {
      std::tie(oldArg, newArg, argCast) = item;
      if (!argCast) {
        bvm.map(oldArg, newArg);
        continue;
      }
      Value newCast =
          innerBuilder.create<tensor::CastOp>(loc, argCast.getType(), newArg);
      bvm.map(oldArg, newCast);
    }

    // Map outputs, insert `tensor.cast` and cast the loop results if
    // necessary.
    SmallVector<Value, 4> newResults;
    rewriter.setInsertionPointAfter(newLoop);
    for (auto item :
         llvm::zip(loop.getRegionOutputArgs(), newLoop.getRegionOutputArgs(),
                   outputCasts.ops, newLoop.getResults())) {
      std::tie(oldArg, newArg, argCast, result) = item;
      if (!argCast) {
        bvm.map(oldArg, newArg);
        newResults.push_back(result);
        continue;
      }
      Value newCast =
          innerBuilder.create<tensor::CastOp>(loc, argCast.getType(), newArg);
      bvm.map(oldArg, newCast);

      newResults.push_back(
          rewriter.create<tensor::CastOp>(loc, argCast.getType(), result));
    }

    // Clone loop body.
    for (auto &op : loop.getBody()->without_terminator())
      innerBuilder.clone(op, bvm);

    // Cast yield arguments to the new type.
    SmallVector<Value, 4> yieldArgs =
        loop.getBody()->getTerminator()->getOperands();
    SmallVector<Value, 4> newYieldArgs;
    for (auto item : llvm::zip(yieldArgs, outputCasts.ops)) {
      std::tie(yieldArg, argCast) = item;
      if (!argCast) {
        newYieldArgs.push_back(bvm.lookup(yieldArg));
        continue;
      }
      newYieldArgs.push_back(innerBuilder.create<tensor::CastOp>(
          loc, argCast.getSource().getType(), bvm.lookup(yieldArg)));
    }
    innerBuilder.create<YieldOp>(loc, newYieldArgs);
    return newResults;
  }
};

/// Removes loops in which at least one lower/upper bound pair consists of the
/// same values; such loops have an empty iteration domain and never execute
/// their body.
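/// E.g., a loop `(%i) = (%c4) to (%c4) step (%c1)` never executes its body,
/// so it can be replaced by its tensor output operands directly.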
struct FoldEmptyLoops : public OpRewritePattern<LoopOp> {
  using OpRewritePattern<LoopOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoopOp op,
                                PatternRewriter &rewriter) const override {
    for (auto dim : llvm::zip(op.lowerBound(), op.upperBound())) {
      if (std::get<0>(dim) != std::get<1>(dim)) continue;
      SmallVector<Value> tensorOutputs;
      for (Value out : op.outputs()) {
        if (out.getType().isa<RankedTensorType>()) tensorOutputs.push_back(out);
      }
      rewriter.replaceOp(op, tensorOutputs);
      return success();
    }
    return failure();
  }
};

}  // namespace

void LoopOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
      .add<FoldEmptyLoops, LoopInputsFolder, LoopResultsFolder,
           DimOfLoopInsOutsFolder<tensor::DimOp>,
           DimOfLoopInsOutsFolder<memref::DimOp>,
           DimOfLoopResultFolder<tensor::DimOp>,
           DimOfLoopResultFolder<memref::DimOp>, TensorCastOfLoopInsOutsFolder>(
          context);
}

/// This is used for patterns of the form
/// ```
///    gml_st.loop(memrefcast(%src)) -> gml_st.loop(%src)
/// ```
/// It folds the source of the memref.cast into the root operation directly.
LogicalResult LoopOp::fold(ArrayRef<Attribute>,
                           SmallVectorImpl<OpFoldResult> &) {
  LoopOp op = *this;
  bool folded = false;
  Location loc = op->getLoc();

  Block *body = op.getBody();
  OpBuilder b = OpBuilder::atBlockBegin(body);

  // Update `input` and `output` operands and block arguments if necessary.
  // Operands list: [lbs, ubs, steps, inputs, outputs].
  // Block args list: [ivs, inputs, outputs].
  for (size_t operandIndex = op.getNumControlOperands(),
              bbArgIndex = op.getNumLoops(), e = op.getNumOperands();
       operandIndex < e; ++operandIndex, ++bbArgIndex) {
    OpOperand &operand = op->getOpOperand(operandIndex);

    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      BlockArgument newBbArg = body->insertArgument(
          bbArgIndex, castOp.getOperand().getType(), op.getLoc());
      BlockArgument oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1);

      // Insert memref.cast back to the original type.
      oldBbArg.replaceAllUsesWith(
          b.create<memref::CastOp>(loc, oldBbArg.getType(), newBbArg));
      body->eraseArgument(oldBbArg.getArgNumber());

      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

LogicalResult YieldOp::verify() {
  auto *parentOp = getOperation()->getParentOp();

  if (auto setYield = dyn_cast<SetYieldOp>(parentOp)) {
    if (values().size() != 1)
      return emitOpError(
          "expected a single argument for the terminator of accumulator "
          "region");
    return success();
  }
  auto loopOp = cast<LoopOp>(parentOp);
  // Check if output args with tensor types match results types.
  SmallVector<Value, 2> tensorOuts;
  llvm::copy_if(
      loopOp.outputs(), std::back_inserter(tensorOuts),
      [&](Value out) { return out.getType().isa<RankedTensorType>(); });
  if (tensorOuts.size() != values().size())
    return emitOpError("expected number of tensor output args = ")
           << tensorOuts.size()
           << " to match the number of yield operands = " << values().size();

  TypeRange tensorTypes{ValueRange{tensorOuts}};
  for (auto &item :
       llvm::enumerate(llvm::zip(tensorTypes, getOperandTypes()))) {
    Type outType, resultType;
    unsigned index = item.index();
    std::tie(outType, resultType) = item.value();
    if (outType != resultType)
      return emitOpError("expected yield operand ")
             << index << " with type = " << resultType
             << " to match output arg type = " << outType;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// SpaceOp
//===----------------------------------------------------------------------===//

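// Splits the mixed static/dynamic sizes into the `static_sizes` attribute and
// the `dynamic_sizes` operands. For example, the sizes {%d, 8} (a Value and
// an IntegerAttr) become static_sizes = [kDynamicSize, 8] with %d as the
// single dynamic operand, giving a !gml_st.tile<?x8> result type.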
void SpaceOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes,
                    ArrayRef<NamedAttribute> attrs) {
  SmallVector<Value> dynamicSizes;
  SmallVector<int64_t> staticSizes;
  for (OpFoldResult size : sizes)
    dispatchIndexOpFoldResult(size, dynamicSizes, staticSizes,
                              ShapedType::kDynamicSize);
  build(builder, result, TileType::get(builder.getContext(), staticSizes),
        dynamicSizes, builder.getI64ArrayAttr(staticSizes));
  result.addAttributes(attrs);
}

LogicalResult SpaceOp::inferReturnTypes(
    MLIRContext *ctx, Optional<Location> /*loc*/, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  SpaceOp::Adaptor adaptor(operands, attributes, regions);
  SmallVector<int64_t> shape = llvm::to_vector(
      llvm::map_range(adaptor.static_sizes(), [&](const Attribute &val) {
        return val.cast<IntegerAttr>().getValue().getSExtValue();
      }));
  auto resultTy = TileType::get(ctx, shape);
  inferredReturnTypes.push_back(resultTy);
  return success();
}

LogicalResult SpaceOp::verify() {
  auto resultTy = getType().cast<TileType>();
  return mlir::verifyListOfOperandsOrIntegers(
      getOperation(), "size", resultTy.getShape().size(), static_sizes(),
      dynamic_sizes(), ShapedType::isDynamic);
}

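// For a space with static sizes [2, ?, 4, ?], getNumDynamicEntriesUpToIdx(3)
// returns 1 (only the `?` at position 1 precedes index 3), so
// getDynamicSize(3) returns the second dynamic size operand.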
unsigned SpaceOp::getNumDynamicEntriesUpToIdx(unsigned idx) {
  return std::count_if(static_sizes().begin(), static_sizes().begin() + idx,
                       [&](const mlir::Attribute size) {
                         return mlir::ShapedType::isDynamic(
                             size.cast<mlir::IntegerAttr>().getInt());
                       });
}

mlir::Value SpaceOp::getDynamicSize(unsigned idx) {
  auto numDynamic = getNumDynamicEntriesUpToIdx(idx);
  return dynamic_sizes()[numDynamic];
}

//===----------------------------------------------------------------------===//
// PointOp
//===----------------------------------------------------------------------===//

void PointOp::build(OpBuilder &builder, OperationState &result, Value superset,
                    ArrayRef<OpFoldResult> offsets,
                    ArrayRef<NamedAttribute> attrs) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  for (OpFoldResult offset : offsets)
    dispatchIndexOpFoldResult(offset, dynamicOffsets, staticOffsets,
                              ShapedType::kDynamicStrideOrOffset);
  build(builder, result, PointType::get(builder.getContext()), superset,
        dynamicOffsets, builder.getI64ArrayAttr(staticOffsets));
  result.addAttributes(attrs);
}

LogicalResult PointOp::verify() {
  auto tileShape = superset().getType().cast<TileType>().getShape();
  if (failed(mlir::verifyListOfOperandsOrIntegers(
          getOperation(), "index", tileShape.size(), static_indices(),
          dynamic_indices(), ShapedType::isDynamicStrideOrOffset))) {
    return failure();
  }
  // Check whether the known indices are in-bounds of known dimension sizes.
  for (auto dimAndIndex : llvm::zip(tileShape, static_indices())) {
    auto dimSize = std::get<0>(dimAndIndex);
    auto index =
        std::get<1>(dimAndIndex).dyn_cast<mlir::IntegerAttr>().getInt();
    if (index == ShapedType::kDynamicStrideOrOffset) continue;
    if (index < 0) {
      return emitOpError("expected index = ") << index << " to be non-negative";
    }
    if (dimSize != ShapedType::kDynamicSize && index >= dimSize) {
      return emitOpError("expected index = ")
             << index << " to be between 0 and " << (dimSize - 1);
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//

void TileOp::build(OpBuilder &b, OperationState &result, Value superset,
                   ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
                   ArrayRef<OpFoldResult> strides,
                   ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto tileType = TileType::get(b.getContext(), staticSizes);
  build(b, result, tileType, superset, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

LogicalResult TileOp::inferReturnTypes(
    MLIRContext *ctx, Optional<Location> /*loc*/, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Derive result shape.
  TileOp::Adaptor adaptor(operands, attributes, regions);
  SmallVector<int64_t> shape = llvm::to_vector(
      llvm::map_range(adaptor.static_sizes(), [&](const auto &size) {
        return size.template dyn_cast<mlir::IntegerAttr>()
            .getValue()
            .getSExtValue();
      }));

  auto resultTy = TileType::get(ctx, shape);
  inferredReturnTypes.push_back(resultTy);
  return success();
}

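// Verifies static operand/attribute consistency and, where all values are
// known, in-bounds access. For example, offset = 2, size = 4, stride = 3
// touches indices {2, 5, 8, 11}; the largest index 2 + 3 * (4 - 1) = 11 must
// stay below the superset dimension size.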
verify()1387 LogicalResult TileOp::verify() {
1388 auto supersetTy = superset().getType().cast<TileType>();
1389 auto rank = supersetTy.getShape().size();
1390 if (failed(mlir::verifyListOfOperandsOrIntegers(getOperation(), "size", rank,
1391 static_sizes(), sizes(),
1392 ShapedType::isDynamic))) {
1393 return failure();
1394 }
1395 if (failed(mlir::verifyListOfOperandsOrIntegers(
1396 getOperation(), "offset", rank, static_offsets(), offsets(),
1397 ShapedType::isDynamicStrideOrOffset))) {
1398 return failure();
1399 }
1400 if (failed(mlir::verifyListOfOperandsOrIntegers(
1401 getOperation(), "stride", rank, static_strides(), strides(),
1402 ShapedType::isDynamicStrideOrOffset))) {
1403 return failure();
1404 }
1405 for (auto it : llvm::zip(supersetTy.getShape(), static_offsets(),
1406 static_sizes(), static_strides())) {
1407 auto offset =
1408 std::get<1>(it).dyn_cast<mlir::IntegerAttr>().getValue().getSExtValue();
1409 if (offset < 0 && offset != ShapedType::kDynamicStrideOrOffset) {
1410 return emitOpError("expected offset = ")
1411 << offset << " to be non-negative";
1412 }
1413 auto size =
1414 std::get<2>(it).dyn_cast<mlir::IntegerAttr>().getValue().getSExtValue();
1415 if (size < 0 && size != ShapedType::kDynamicSize) {
1416 return emitOpError("expected size = ") << size << " to be non-negative";
1417 }
1418 auto stride =
1419 std::get<3>(it).dyn_cast<mlir::IntegerAttr>().getValue().getSExtValue();
1420 if (stride < 0 && stride != ShapedType::kDynamicStrideOrOffset) {
1421 return emitOpError("expected stride = ")
1422 << stride << " to be non-negative";
1423 }
1424 auto argSize = std::get<0>(it);
1425 // If the argument tile has a dynamic dimension, no additional verification
1426 // is possible.
1427 if (argSize == ShapedType::kDynamicSize) continue;
1428 if (offset >= 0) {
1429 if (stride >= 0 && size > 0) {
1430 int64_t largestIndex = offset + stride * (size - 1);
1431 if (largestIndex >= argSize) {
1432 return emitOpError("offset = ")
1433 << offset << " size = " << size << " stride = " << stride
1434 << " causes access out of bounds at " << largestIndex
1435 << " for argument dimension size = " << argSize;
1436 }
1437 } else if (offset >= argSize) {
1438 return emitOpError("offset = ")
1439 << offset
1440 << " is out of bounds for argument dimension size = " << argSize;
1441 }
1442 } else if (stride > 0 && size > 0 && stride * (size - 1) >= argSize) {
1443 return emitOpError("size = ")
1444 << size << " stride = " << stride
1445 << " causes access out of bounds for argument dimension size = "
1446 << argSize;
1447 }
1448 }
1449 return success();
1450 }
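
// Bounds example: offset = 4, size = 3, stride = 2 touches indices 4, 6, and
// 8, so largestIndex = 4 + 2 * (3 - 1) = 8; verification fails unless the
// superset dimension size is at least 9.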

namespace {

OpFoldResult multiplyOperandsOrIntegers(OpBuilder &builder, Location loc,
                                        OpFoldResult lhs, OpFoldResult rhs) {
  // Both operands are static.
  if (lhs.is<Attribute>() && rhs.is<Attribute>()) {
    return builder.getI64IntegerAttr(
        lhs.get<Attribute>().cast<IntegerAttr>().getInt() *
        rhs.get<Attribute>().cast<IntegerAttr>().getInt());
  }

  // Exploit commutativity and move static operand to the left (if any).
  if (rhs.is<Attribute>()) std::swap(lhs, rhs);

  // Create constant if needed.
  if (lhs.is<Attribute>()) {
    int64_t lhsInt = lhs.get<Attribute>().cast<IntegerAttr>().getInt();

    // Exploit static operand if possible.
    if (lhsInt == 0) return lhs;
    if (lhsInt == 1) return rhs;

    lhs = builder.create<arith::ConstantIndexOp>(loc, lhsInt).getResult();
  }

  // Multiply.
  return builder.create<arith::MulIOp>(loc, lhs.get<Value>(), rhs.get<Value>())
      .getResult();
}

OpFoldResult addOperandsOrIntegers(OpBuilder &builder, Location loc,
                                   OpFoldResult lhs, OpFoldResult rhs) {
  // Both operands are static.
  if (lhs.is<Attribute>() && rhs.is<Attribute>()) {
    return builder.getI64IntegerAttr(
        lhs.get<Attribute>().cast<IntegerAttr>().getInt() +
        rhs.get<Attribute>().cast<IntegerAttr>().getInt());
  }

  // Exploit commutativity and move static operand to the left (if any).
  if (rhs.is<Attribute>()) std::swap(lhs, rhs);

  // Create constant if needed.
  if (lhs.is<Attribute>()) {
    int64_t lhsInt = lhs.get<Attribute>().cast<IntegerAttr>().getInt();

    // Exploit static operand if possible.
    if (lhsInt == 0) return rhs;

    lhs = builder.create<arith::ConstantIndexOp>(loc, lhsInt).getResult();
  }

  // Add.
  return builder.create<arith::AddIOp>(loc, lhs.get<Value>(), rhs.get<Value>())
      .getResult();
}
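
// Both helpers fold eagerly: two attributes combine into a single
// I64IntegerAttr (e.g. 2 * 3 -> 6), 0 * x returns the zero attribute without
// emitting an op, and 1 * x as well as 0 + x return x unchanged. Only truly
// mixed static/dynamic combinations materialize arith ops.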

// Compose offsets with newOffset = supersetOffset + supersetStride * offset.
SmallVector<OpFoldResult> composeOffsets(
    const llvm::SmallVectorImpl<OpFoldResult> &supersetOffsets,
    const llvm::SmallVectorImpl<OpFoldResult> &supersetStrides,
    const llvm::SmallVectorImpl<OpFoldResult> &offsets, Location loc,
    OpBuilder &builder) {
  SmallVector<OpFoldResult> composedOffsets;
  for (auto it : llvm::zip(supersetOffsets, supersetStrides, offsets)) {
    composedOffsets.push_back(addOperandsOrIntegers(
        builder, loc, std::get<0>(it),
        multiplyOperandsOrIntegers(builder, loc, std::get<1>(it),
                                   std::get<2>(it))));
  }
  return composedOffsets;
}
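
// E.g. a superset offset of 5 with superset stride 2 and a relative offset of
// 3 composes to 5 + 2 * 3 = 11 in the coordinate system of the root set.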

// Compose strides with newStride = supersetStride * stride.
SmallVector<OpFoldResult> composeStrides(
    OpBuilder &builder, Location loc,
    const llvm::SmallVectorImpl<OpFoldResult> &supersetStrides,
    const llvm::SmallVectorImpl<OpFoldResult> &strides) {
  SmallVector<OpFoldResult> composedStrides;
  for (auto it : llvm::zip(supersetStrides, strides)) {
    composedStrides.push_back(multiplyOperandsOrIntegers(
        builder, loc, std::get<0>(it), std::get<1>(it)));
  }
  return composedStrides;
}
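
// E.g. a superset stride of 2 combined with a relative stride of 4 yields a
// composed stride of 8: one step in the subset advances 8 elements in the
// root set.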

} // namespace

Value TileOp::compose(OpBuilder &builder) {
  auto supersetOp = llvm::dyn_cast_or_null<TileOp>(superset().getDefiningOp());
  if (!supersetOp) return {};

  // Compose offsets with newOffset = supersetOffset + supersetStride *
  // offset.
  auto loc = getLoc();
  auto composedOffsets = composeOffsets(supersetOp.getMixedOffsets(),
                                        supersetOp.getMixedStrides(),
                                        getMixedOffsets(), loc, builder);

  // Compose strides with newStride = supersetStride * stride.
  auto composedStrides = composeStrides(
      builder, loc, supersetOp.getMixedStrides(), getMixedStrides());

  // Build the composed tile op.
  return builder.create<TileOp>(loc, supersetOp.superset(), composedOffsets,
                                getMixedSizes(), composedStrides);
}
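
// Sketch (tile syntax illustrative): composing
//   %t0 = gml_st.tile %space [4] [16] [2]
//   %t1 = gml_st.tile %t0 [3] [4] [4]
// yields a tile rooted directly at %space with offset 4 + 2 * 3 = 10, the
// original sizes [4], and stride 2 * 4 = 8.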

//===----------------------------------------------------------------------===//
// PointOp
//===----------------------------------------------------------------------===//

Value PointOp::compose(OpBuilder &builder) {
  auto supersetOp = llvm::dyn_cast_or_null<TileOp>(superset().getDefiningOp());
  if (!supersetOp) return {};

  // Compose offsets with newOffset = supersetOffset + supersetStride *
  // offset.
  auto loc = getLoc();
  auto composedOffsets = decomposeMixedStridesOrOffsets(
      builder,
      composeOffsets(
          supersetOp.getMixedOffsets(), supersetOp.getMixedStrides(),
          mlir::getMixedStridesOrOffsets(static_indices(), dynamic_indices()),
          loc, builder));

  // Build the composed point op.
  return builder.create<PointOp>(loc, supersetOp.superset(),
                                 composedOffsets.second, composedOffsets.first);
}
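
// A point at index 3 of a tile with offset 4 and stride 2 becomes the point
// at index 4 + 2 * 3 = 10 of the tile's superset. The decompose step splits
// the composed offsets back into static attributes and dynamic operands, as
// the PointOp builder expects.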

//===----------------------------------------------------------------------===//
// DropDimsOp
//===----------------------------------------------------------------------===//

LogicalResult DropDimsOp::inferReturnTypes(
    MLIRContext *ctx, Optional<Location> /*loc*/, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  DropDimsOp::Adaptor adaptor(operands, attributes, regions);
  Type argTy = adaptor.superset().getType();

  // If the argument is of point type, so is the result.
  if (auto pointTy = argTy.dyn_cast<PointType>()) {
    inferredReturnTypes.push_back(argTy);
    return success();
  }

  // If the argument is of tile type, we can skip the dropped dimensions to
  // derive the result type.
  if (auto tileTy = argTy.dyn_cast<TileType>()) {
    auto argShape = tileTy.getShape();
    SmallVector<int64_t> resultShape = llvm::to_vector(llvm::map_range(
        adaptor.remaining_dims(), [&](const auto &d) { return argShape[d]; }));
    auto resultTy = TileType::get(ctx, resultShape);
    inferredReturnTypes.push_back(resultTy);
    return success();
  }

  return failure();
}
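
// E.g. for a rank-3 tile of shape 8x4x2 and remaining_dims = [2, 0], the
// result is a rank-2 tile of shape 2x8; a point argument keeps its point
// type.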

namespace {

SmallVector<OpFoldResult> selectMixedValues(
    const SmallVectorImpl<OpFoldResult> &mixedValues,
    ArrayRef<int64_t> selection) {
  return llvm::to_vector(
      llvm::map_range(selection, [&](int64_t i) { return mixedValues[i]; }));
}

// Compose a set by selecting a subset of its dimensions. Both the dimensions
// to select, and the order in which they should be selected, are specified by
// `selection`.
Value selectDimsFromSet(OpBuilder &builder, Location loc, Type type, Value set,
                        ArrayRef<int64_t> selection) {
  // Case: space
  Operation *setDef = set.getDefiningOp();
  if (auto spaceOp = llvm::dyn_cast_or_null<SpaceOp>(setDef)) {
    auto spaceSizes =
        getMixedSizes(spaceOp.static_sizes(), spaceOp.dynamic_sizes());
    auto newSpaceSizes = selectMixedValues(spaceSizes, selection);
    auto newSpaceSizesDecomposed = decomposeMixedSizes(builder, newSpaceSizes);
    return builder.create<SpaceOp>(loc, newSpaceSizesDecomposed.second,
                                   newSpaceSizesDecomposed.first);
  }

  // Case: point(space)
  if (PointOp pointOp = llvm::dyn_cast_or_null<PointOp>(setDef)) {
    auto newSpace =
        selectDimsFromSet(builder, loc, type, pointOp.superset(), selection);
    auto pointOffsets = getMixedStridesOrOffsets(pointOp.static_indices(),
                                                 pointOp.dynamic_indices());
    auto newPointOffsets = selectMixedValues(pointOffsets, selection);
    auto newPointOffsetsDecomposed =
        decomposeMixedStridesOrOffsets(builder, newPointOffsets);
    return builder.create<PointOp>(loc, newSpace,
                                   newPointOffsetsDecomposed.second,
                                   newPointOffsetsDecomposed.first);
  }

  // Case: tile(space)
  if (TileOp tileOp = llvm::dyn_cast_or_null<TileOp>(setDef)) {
    auto newSpace =
        selectDimsFromSet(builder, loc, type, tileOp.superset(), selection);

    auto tileOffsets =
        getMixedStridesOrOffsets(tileOp.static_offsets(), tileOp.offsets());
    auto newTileOffsets = selectMixedValues(tileOffsets, selection);
    auto newTileOffsetsDecomposed =
        decomposeMixedStridesOrOffsets(builder, newTileOffsets);

    auto tileSizes = getMixedSizes(tileOp.static_sizes(), tileOp.sizes());
    auto newTileSizes = selectMixedValues(tileSizes, selection);
    auto newTileSizesDecomposed = decomposeMixedSizes(builder, newTileSizes);

    auto tileStrides =
        getMixedStridesOrOffsets(tileOp.static_strides(), tileOp.strides());
    auto newTileStrides = selectMixedValues(tileStrides, selection);
    auto newTileStridesDecomposed =
        decomposeMixedStridesOrOffsets(builder, newTileStrides);

    return builder.create<TileOp>(
        loc, newSpace, newTileOffsetsDecomposed.second,
        newTileSizesDecomposed.second, newTileStridesDecomposed.second,
        newTileOffsetsDecomposed.first, newTileSizesDecomposed.first,
        newTileStridesDecomposed.first);
  }

  return {};
}
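
// E.g. with selection = [1, 0] and a tile over a 2-D space, the recursion
// first rebuilds the reordered space and then re-emits the tile with its
// offsets, sizes, and strides permuted to [dim1, dim0]; with selection = [0]
// it instead drops the second dimension entirely.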

} // namespace

Value DropDimsOp::compose(OpBuilder &builder) {
  // We can compose with a TileOp operand which has a SpaceOp operand, or
  // compose with a SpaceOp operand.
  return selectDimsFromSet(builder, getLoc(), getType(), superset(),
                           remaining_dims());
}

//===----------------------------------------------------------------------===//
// TransposeDimsOp
//===----------------------------------------------------------------------===//

LogicalResult TransposeDimsOp::inferReturnTypes(
    MLIRContext *ctx, Optional<Location> /*loc*/, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  TransposeDimsOp::Adaptor adaptor(operands, attributes, regions);
  const Type argTy = adaptor.superset().getType();

  // If the argument is of point type, so is the result.
  if (auto pointTy = argTy.dyn_cast<PointType>()) {
    inferredReturnTypes.push_back(argTy);
    return success();
  }

  // If the argument is of tile type, we can transpose the type's dimensions.
  if (auto tileTy = argTy.dyn_cast<TileType>()) {
    auto argShape = tileTy.getShape();
    const SmallVector<int64_t> resultShape = llvm::to_vector(llvm::map_range(
        adaptor.permutation(), [&](const auto &d) { return argShape[d]; }));
    auto resultTy = TileType::get(ctx, resultShape);
    inferredReturnTypes.push_back(resultTy);
    return success();
  }

  return failure();
}
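
// E.g. permutation = [1, 0] maps a tile of shape 4x8 to one of shape 8x4;
// permutation = [2, 0, 1] maps 2x4x8 to 8x2x4.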

Value TransposeDimsOp::compose(OpBuilder &builder) {
  // We can compose with a TileOp operand which has a SpaceOp operand, or
  // compose with a SpaceOp operand. transpose_tile(tile(space, offsets, sizes,
  // strides)) is replaced by tile(transpose(space), transpose(offsets),
  // transpose(sizes), transpose(strides)). transpose_tile(space) is replaced
  // by transpose(space).

  return selectDimsFromSet(builder, getLoc(), getType(), superset(),
                           permutation());
}

LogicalResult TransposeDimsOp::verify() {
  // Verify that `permutation` is in fact a permutation.
  size_t rank = permutation().size();
  SmallVector<int64_t> position(rank, -1);
  for (const auto &it : llvm::enumerate(permutation())) {
    int64_t dim = it.value();
    if (dim < 0 || dim >= static_cast<int64_t>(rank)) {
      return emitOpError("permutation[")
             << it.index() << "] = " << dim << " is outside of range [0, "
             << rank - 1 << "]";
    }
    if (position[dim] >= 0) {
      return emitOpError(
                 "expected permutation attribute to contain no duplicate "
                 "values, but got ")
             << dim << " at positions " << position[dim] << " and "
             << it.index();
    }
    position[dim] = it.index();
  }

  // Verify tile-specific relationship between types and permutation. The
  // constraints between argument and result type are verified through the
  // implementation of `inferReturnTypes`.
  if (auto tileTy = getType().dyn_cast<TileType>()) {
    size_t tileRank = tileTy.getShape().size();
    if (tileRank != rank) {
      return emitOpError("expected result rank ")
             << tileRank << " to match the permutation size of " << rank
             << ".";
    }
  }

  return success();
}
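
// The position bookkeeping rejects duplicates in a single pass: e.g.
// permutation = [0, 0] fails because position[0] is already set when the
// second 0 is seen, while [0, 2] on a rank-2 op fails the range check.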

//===----------------------------------------------------------------------===//
// SetYieldOp
//===----------------------------------------------------------------------===//

using AccumulatorRegionBuilderFn =
    function_ref<void(OpBuilder &, Location, Value, Value)>;

void SetYieldOp::build(OpBuilder &builder, OperationState &result) {
  build(builder, result, llvm::None, llvm::None, llvm::None);
}

void SetYieldOp::build(OpBuilder &builder, OperationState &result,
                       ValueRange srcs, ValueRange dsts, ValueRange sets) {
  SmallVector<bool, 2> accumulatorFlags(srcs.size(), false);
  build(builder, result, srcs, dsts, sets,
        builder.getBoolArrayAttr(accumulatorFlags), llvm::None);
}

void SetYieldOp::build(
    OpBuilder &builder, OperationState &result, ValueRange srcs,
    ValueRange dsts, ValueRange sets, ArrayAttr accumulatorFlags,
    ArrayRef<AccumulatorRegionBuilderFn> accumulatorBuilderFns) {
  assert(dsts.size() == srcs.size() &&
         "`dsts` and `srcs` should have the same size");
  assert(sets.size() == srcs.size() &&
         "`sets` and `srcs` should have the same size");
  assert(accumulatorFlags.size() == srcs.size() &&
         "`accumulatorFlags` and `srcs` should have the same size");

  auto accumulatorCount = llvm::count_if(accumulatorFlags, [](Attribute attr) {
    return attr.cast<BoolAttr>().getValue();
  });
  (void)accumulatorCount;
  assert(accumulatorCount ==
             static_cast<int64_t>(accumulatorBuilderFns.size()) &&
         "the number of flags set in `accumulatorFlags` attribute should be "
         "equal to the number of `accumulatorBuilderFns`");

  result.addOperands(srcs);
  result.addOperands(dsts);
  result.addOperands(sets);
  result.addAttribute(SetYieldOp::accumulatorFlagsAttrName(result.name),
                      accumulatorFlags);

  const auto *builderFnIt = accumulatorBuilderFns.begin();
  for (auto item : llvm::zip(srcs, accumulatorFlags)) {
    Value src = std::get<0>(item);
    auto accumulatorFlag = std::get<1>(item).cast<BoolAttr>();

    if (!accumulatorFlag.getValue()) continue;
    Region *region = result.addRegion();
    OpBuilder::InsertionGuard g(builder);
    SmallVector<Type, 2> argTypes(2, src.getType());
    builder.createBlock(region);
    Block &bodyBlock = region->front();
    bodyBlock.addArguments(argTypes, {result.location, result.location});

    builder.setInsertionPointToStart(&bodyBlock);
    (*builderFnIt)(builder, result.location, bodyBlock.getArgument(0),
                   bodyBlock.getArgument(1));
    ++builderFnIt;
  }
}

LogicalResult SetYieldOp::verify() {
  auto accumulatorCount = llvm::count_if(
      accumulatorFlags(),
      [](Attribute attr) { return attr.cast<BoolAttr>().getValue(); });
  if (accumulatorCount != static_cast<int64_t>(accumulators().size()))
    return emitOpError("expected the number of accumulator regions ")
           << accumulators().size()
           << " to match the number of set accumulator flags "
           << accumulatorCount;

  auto *regionIt = accumulators().begin();
  for (auto item : llvm::zip(srcs(), accumulatorFlags())) {
    Type srcType = std::get<0>(item).getType();
    BoolAttr accumulatorFlag = std::get<1>(item).cast<BoolAttr>();
    if (!accumulatorFlag.getValue()) continue;

    Block &block = regionIt->front();
    if (block.getArgumentTypes() != SmallVector<Type>{srcType, srcType})
      return emitOpError()
             << "expected accumulator region to have 2 arguments of type "
             << srcType;
    ++regionIt;
  }
  return success();
}

void SetYieldOp::print(OpAsmPrinter &p) {
  p.printOptionalAttrDict(getOperation()->getAttrs(),
                          /*elidedAttrs=*/{accumulatorFlagsAttrName().str()});

  auto *regionIt = getOperation()->getRegions().begin();
  for (auto &en :
       llvm::enumerate(llvm::zip(srcs(), dsts(), sets(), accumulatorFlags()))) {
    if (en.index() > 0) {
      p << ',';
      p.printNewline();
    }
    Value src = std::get<0>(en.value());
    Value dst = std::get<1>(en.value());
    Value set = std::get<2>(en.value());
    auto accumulatorFlag = std::get<3>(en.value()).cast<BoolAttr>();

    p << ' ' << src << " into " << dst << '[' << set << ']';

    if (accumulatorFlag.getValue()) {
      auto &block = regionIt->getBlocks().front();
      Value newValue = block.getArgument(0);
      Value oldValue = block.getArgument(1);
      p << " acc (" << newValue << ", " << oldValue << ": "
        << oldValue.getType() << ") ";

      p.printRegion(*regionIt, /*printEntryBlockArgs=*/false);
      ++regionIt;
    }

    p << " : " << src.getType() << " into " << dst.getType() << '['
      << set.getType() << ']';
  }
}
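
// Putting printer and parser together, each comma-separated entry has the
// form (operand and type names illustrative):
//   %src into %dst[%set] acc (%new, %old: <src-type>) { ... }
//       : <src-type> into <dst-type>[<set-type>]
// where the `acc` section appears only for entries whose accumulator flag is
// set.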

ParseResult SetYieldOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();

  SmallVector<bool, 2> accumulatorFlags;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcs, dsts, sets;
  SmallVector<Type, 4> srcTypes, dstTypes, setTypes;

  auto parseElt = [&]() -> ParseResult {
    OpAsmParser::UnresolvedOperand src;
    auto parseResult =
        parser.parseOptionalOperand(src, /*allowResultNumber=*/false);

    if (!parseResult.hasValue()) return success();
    srcs.push_back(src);

    if (parser.parseKeyword("into") ||
        parser.parseOperand(dsts.emplace_back()) || parser.parseLSquare() ||
        parser.parseOperand(sets.emplace_back()) || parser.parseRSquare())
      return failure();

    OpBuilder b(parser.getBuilder().getContext());
    bool hasAccumulatorRegion = succeeded(parser.parseOptionalKeyword("acc"));
    accumulatorFlags.push_back(hasAccumulatorRegion);
    if (hasAccumulatorRegion) {
      auto region = std::make_unique<Region>();
      OpAsmParser::UnresolvedOperand newValue, oldValue;
      Type argType;
      if (parser.parseLParen() || parser.parseOperand(newValue) ||
          parser.parseComma() || parser.parseOperand(oldValue) ||
          parser.parseColonType(argType) || parser.parseRParen())
        return failure();

      SmallVector<OpAsmParser::Argument, 4> regionArgs;
      for (auto value : {newValue, oldValue}) {
        auto &arg = regionArgs.emplace_back();
        arg.ssaName = value;
        arg.type = argType;
      }

      if (parser.parseRegion(*region, regionArgs)) return failure();
      result.addRegion(std::move(region));
    }
    if (parser.parseColon() || parser.parseType(srcTypes.emplace_back()) ||
        parser.parseKeyword("into") ||
        parser.parseType(dstTypes.emplace_back()) || parser.parseLSquare() ||
        parser.parseType(setTypes.emplace_back()) || parser.parseRSquare())
      return failure();

    return success();
  };
  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::None, parseElt))
    return failure();

  if (parser.resolveOperands(srcs, srcTypes, parser.getCurrentLocation(),
                             result.operands) ||
      parser.resolveOperands(dsts, dstTypes, parser.getCurrentLocation(),
                             result.operands) ||
      parser.resolveOperands(sets, setTypes, parser.getCurrentLocation(),
                             result.operands))
    return failure();

  result.addAttribute(SetYieldOp::accumulatorFlagsAttrName(result.name),
                      parser.getBuilder().getBoolArrayAttr(accumulatorFlags));
  return success();
}

//===----------------------------------------------------------------------===//
// OffsetOp
//===----------------------------------------------------------------------===//

OpFoldResult OffsetOp::fold(ArrayRef<Attribute> operands) {
  auto idxAttr = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!idxAttr) return {};
  int64_t idx = idxAttr.getInt();

  // Case: offset(point(space))
  Operation *subsetDef = subset().getDefiningOp();
  if (auto pointOp = llvm::dyn_cast_or_null<PointOp>(subsetDef)) {
    Operation *supersetDef = pointOp.superset().getDefiningOp();

    // Can only fold locally if the superset is the root space. Otherwise, rely
    // on subset composition.
    if (!llvm::isa_and_nonnull<SpaceOp>(supersetDef)) return {};

    return ensureIndexTypeForAttribute(mlir::getMixedStridesOrOffsets(
        pointOp.static_indices(), pointOp.dynamic_indices())[idx]);
  }

  // Case: offset(tile(space))
  if (auto tileOp = llvm::dyn_cast_or_null<TileOp>(subsetDef)) {
    Operation *supersetDef = tileOp.superset().getDefiningOp();

    // Can only fold locally if the superset is the root space. Otherwise, rely
    // on subset composition.
    if (!llvm::isa_and_nonnull<SpaceOp>(supersetDef)) return {};

    return ensureIndexTypeForAttribute(mlir::getMixedStridesOrOffsets(
        tileOp.static_offsets(), tileOp.offsets())[idx]);
  }

  // Case: offset(space)
  if (llvm::isa_and_nonnull<SpaceOp>(subsetDef)) {
    Builder b(getContext());
    return b.getIndexAttr(0);
  }

  return {};
}
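
// E.g. querying dimension 1 of a tile with static offsets [4, 7] over a root
// space folds to the index attribute 7, and any offset of a space itself
// folds to 0.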

//===----------------------------------------------------------------------===//
// SizeOp
//===----------------------------------------------------------------------===//

OpFoldResult SizeOp::fold(ArrayRef<Attribute> operands) {
  auto idxAttr = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!idxAttr) return {};
  int64_t idx = idxAttr.getInt();

  // Case: size(tile(...))
  // Note that sizes can also be folded in the presence of nested tiling. There
  // is no need to check for an immediate root space here.
  Operation *tileDef = tile().getDefiningOp();
  if (auto tileOp = llvm::dyn_cast_or_null<TileOp>(tileDef)) {
    return ensureIndexTypeForAttribute(tileOp.getMixedSizes()[idx]);
  }

  // Case: size(space)
  if (auto spaceOp = llvm::dyn_cast_or_null<SpaceOp>(tileDef)) {
    return ensureIndexTypeForAttribute(mlir::getMixedSizes(
        spaceOp.static_sizes(), spaceOp.dynamic_sizes())[idx]);
  }

  return {};
}
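
// E.g. querying dimension 0 of a tile with mixed sizes [%n, 8] folds to %n.
// Unlike offsets and strides, a tile's sizes are independent of its superset,
// which is why nested tiles fold here without the root-space check.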

//===----------------------------------------------------------------------===//
// StrideOp
//===----------------------------------------------------------------------===//

OpFoldResult StrideOp::fold(ArrayRef<Attribute> operands) {
  auto idxAttr = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!idxAttr) return {};
  int64_t idx = idxAttr.getInt();

  // Case: stride(tile(space))
  Operation *subsetDef = tile().getDefiningOp();
  if (auto tileOp = llvm::dyn_cast_or_null<TileOp>(subsetDef)) {
    Operation *supersetDef = tileOp.superset().getDefiningOp();

    // Can only fold locally if the superset is the root space. Otherwise, rely
    // on subset composition.
    if (!llvm::isa_and_nonnull<SpaceOp>(supersetDef)) return {};

    return ensureIndexTypeForAttribute(mlir::getMixedStridesOrOffsets(
        tileOp.static_strides(), tileOp.strides())[idx]);
  }

  // Case: stride(space)
  if (llvm::isa_and_nonnull<SpaceOp>(subsetDef)) {
    Builder b(getContext());
    return b.getIndexAttr(1);
  }

  return {};
}
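
// E.g. any stride of a space folds to the index attribute 1, and the stride
// of a tile directly over a space folds to the corresponding entry of the
// tile's mixed strides.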

} // namespace gml_st
} // namespace mlir

// Generated op classes.
#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.cc.inc"