1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 #include <memory>
21 #include <utility>
22 
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Casting.h"
29 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
30 #include "mlir/Dialect/MemRef/IR/MemRef.h"
31 #include "mlir/Dialect/SCF/IR/SCF.h"
32 #include "mlir/Dialect/Tensor/IR/Tensor.h"
33 #include "mlir/Dialect/Tensor/Utils/Utils.h"
34 #include "mlir/IR/BlockAndValueMapping.h"
35 #include "mlir/IR/BuiltinAttributes.h"
36 #include "mlir/IR/BuiltinTypes.h"
37 #include "mlir/IR/DialectImplementation.h"
38 #include "mlir/IR/OpDefinition.h"
39 #include "mlir/IR/OpImplementation.h"
40 #include "mlir/IR/Operation.h"
41 #include "mlir/IR/PatternMatch.h"
42 #include "mlir/Interfaces/ViewLikeInterface.h"
43 
44 namespace mlir {
45 namespace {
46 
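// Prints a list of dimension sizes separated by 'x', using '?' for dynamic
// sizes; e.g. a shape [8, kDynamicSize, 16] is printed as `8x?x16`
// (illustrative).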
47 void printShapeTypeDimensionsList(AsmPrinter &printer,
48                                   ArrayRef<int64_t> integers) {
49   llvm::interleave(
50       integers, printer,
51       [&](int64_t val) {
52         if (val == ShapedType::kDynamicSize)
53           printer << '?';
54         else
55           printer << val;
56       },
57       "x");
58 }
59 
60 ParseResult parseShapeTypeDimensionsList(
61     AsmParser &parser, FailureOr<SmallVector<int64_t>> &dims) {
62   SmallVector<int64_t> vals;
63   if (failed(parser.parseDimensionList(vals, /*allowDynamic=*/true,
64                                        /*withTrailingX=*/false))) {
65     return failure();
66   }
67   dims = vals;
68   return success();
69 }
70 
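// Parses a parenthesized assignment list with types, e.g. (illustrative):
//   (%in_ = %in: tensor<?xf32>, %init_ = %init: tensor<f32>)
// Any leading keyword such as `ins` is consumed by the caller.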
71 ParseResult parseAssignmentListWithTypes(
72     OpAsmParser &parser, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lhs,
73     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &rhs,
74     SmallVectorImpl<Type> &types) {
75   auto parseElt = [&]() -> ParseResult {
76     if (parser.parseOperand(lhs.emplace_back(), /*allowResultNumber=*/false) ||
77         parser.parseEqual() || parser.parseOperand(rhs.emplace_back()) ||
78         parser.parseColon() || parser.parseType(types.emplace_back())) {
79       return failure();
80     }
81     return success();
82   };
83   return parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, parseElt);
84 }
85 
86 }  // namespace
87 }  // namespace mlir
88 
89 // Generated dialect definitions.
90 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_dialect.cc.inc"
91 
92 // Generated type classes.
93 #define GET_TYPEDEF_CLASSES
94 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_types.cc.inc"
95 
96 namespace mlir {
97 namespace gml_st {
98 
99 //===----------------------------------------------------------------------===//
100 // GmlStDialect
101 //===----------------------------------------------------------------------===//
102 
103 void GmlStDialect::initialize() {
104   addOperations<
105 #define GET_OP_LIST
106 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.cc.inc"
107       >();
108   addTypes<
109 #define GET_TYPEDEF_LIST
110 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_types.cc.inc"
111       >();
112 }
113 
114 // Helper function to ensure index types for some attributes when folding.
115 static OpFoldResult ensureIndexTypeForAttribute(OpFoldResult foldResult) {
116   if (foldResult.is<Attribute>()) {
117     auto attr = foldResult.get<Attribute>().dyn_cast<IntegerAttr>();
118     if (attr && !attr.getType().isa<IndexType>()) {
119       Builder b(attr.getContext());
120       return b.getIndexAttr(attr.getInt());
121     }
122   }
123   return foldResult;
124 }
125 
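// Materializes constants produced by folding; e.g. an IntegerAttr used where
// an index value is expected becomes `arith.constant <value> : index`
// (illustrative).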
126 Operation *GmlStDialect::materializeConstant(OpBuilder &builder, Attribute attr,
127                                              Type type, Location loc) {
128   if (type.isa<IndexType>()) {
129     int64_t intValue = attr.cast<IntegerAttr>().getInt();
130     return builder.create<arith::ConstantIndexOp>(loc, intValue);
131   }
132   return {};
133 }
134 
135 //===----------------------------------------------------------------------===//
136 // MaterializeOp
137 //===----------------------------------------------------------------------===//
138 
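// Result type inference for `gml_st.materialize`: materializing a tile of a
// memref/tensor yields a memref/tensor with the tile's shape, while
// materializing a point yields just the element type.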
139 LogicalResult MaterializeOp::inferReturnTypes(
140     MLIRContext *, Optional<Location>, ValueRange operands,
141     DictionaryAttr attributes, RegionRange,
142     SmallVectorImpl<Type> &inferredReturnTypes) {
143   MaterializeOp::Adaptor adaptor(operands, attributes);
144 
145   ShapedType sourceType = adaptor.source().getType().cast<ShapedType>();
146   Type setType = adaptor.set().getType();
147 
148   if (auto tileType = setType.dyn_cast<TileType>()) {
149     if (auto memrefType = sourceType.dyn_cast<MemRefType>()) {
150       inferredReturnTypes.push_back(
151           MemRefType::get(tileType.getShape(), sourceType.getElementType()));
152     } else if (auto tensorType = sourceType.dyn_cast<RankedTensorType>()) {
153       inferredReturnTypes.push_back(RankedTensorType::get(
154           tileType.getShape(), sourceType.getElementType()));
155     } else {
156       return failure();
157     }
158   } else if (setType.isa<PointType>()) {
159     inferredReturnTypes.push_back(sourceType.getElementType());
160   } else {
161     return failure();
162   }
163   return success();
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // LoopOp
168 //===----------------------------------------------------------------------===//
169 
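// Example use (illustrative, not from this file):
//   builder.create<LoopOp>(loc, lbs, ubs, steps, inputs, outputs, iterTypes,
//       [&](OpBuilder &b, Location nestedLoc, ValueRange ivs,
//           ValueRange inputArgs, ValueRange outputArgs) {
//         // Emit the loop body here; the terminator is added automatically.
//       });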
170 void LoopOp::build(OpBuilder &builder, OperationState &result,
171                    ValueRange lowerBounds, ValueRange upperBounds,
172                    ValueRange steps, ValueRange inputs, ValueRange outputs,
173                    ArrayAttr iteratorTypes,
174                    function_ref<void(OpBuilder &, Location, ValueRange,
175                                      ValueRange, ValueRange)>
176                        bodyBuilderFn) {
177   build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs,
178         iteratorTypes, llvm::None, bodyBuilderFn);
179 }
180 
181 void LoopOp::build(OpBuilder &builder, OperationState &result,
182                    ValueRange lowerBounds, ValueRange upperBounds,
183                    ValueRange steps, ValueRange inputs, ValueRange outputs,
184                    ArrayAttr iteratorTypes,
185                    Optional<ArrayAttr> distributionTypes,
186                    function_ref<void(OpBuilder &, Location, ValueRange,
187                                      ValueRange, ValueRange)>
188                        bodyBuilderFn) {
189   result.addOperands(lowerBounds);
190   result.addOperands(upperBounds);
191   result.addOperands(steps);
192   result.addOperands(inputs);
193   result.addOperands(outputs);
194   result.addAttribute(
195       LoopOp::getOperandSegmentSizeAttr(),
196       builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
197                                     static_cast<int32_t>(upperBounds.size()),
198                                     static_cast<int32_t>(steps.size()),
199                                     static_cast<int32_t>(inputs.size()),
200                                     static_cast<int32_t>(outputs.size())}));
201   result.addAttribute(getIteratorTypesAttrStrName(), iteratorTypes);
202 
203   if (distributionTypes.has_value())
204     result.addAttribute(getDistributionTypesAttrStrName(),
205                         distributionTypes.getValue());
206 
207   // Add output types for `RankedTensorType` output arguments.
208   for (Value output : outputs) {
209     Type outputType = output.getType();
210     if (outputType.isa<RankedTensorType>()) result.addTypes(outputType);
211   }
212 
213   OpBuilder::InsertionGuard guard(builder);
214   unsigned numIVs = steps.size();
215   SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
216   SmallVector<Location, 8> argLocs(numIVs, result.location);
217   for (Value input : inputs) {
218     argTypes.push_back(input.getType());
219     argLocs.push_back(input.getLoc());
220   }
221   for (Value output : outputs) {
222     argTypes.push_back(output.getType());
223     argLocs.push_back(output.getLoc());
224   }
225   Region *bodyRegion = result.addRegion();
226   Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
227 
228   if (bodyBuilderFn) {
229     builder.setInsertionPointToStart(bodyBlock);
230     bodyBuilderFn(builder, result.location,
231                   bodyBlock->getArguments().take_front(numIVs),
232                   bodyBlock->getArguments().slice(numIVs, inputs.size()),
233                   bodyBlock->getArguments().take_back(outputs.size()));
234     LoopOp::ensureTerminator(*bodyRegion, builder, result.location);
235   }
236 }
237 
238 void LoopOp::print(OpAsmPrinter &p) {
239   p << " (" << getInductionVars() << ") = (" << lowerBound() << ") to ("
240     << upperBound() << ") step (" << step() << ")";
241 
242   if (!inputs().empty()) {
243     p << " ins (";
244     llvm::interleaveComma(llvm::zip(getRegionInputArgs(), inputs()), p,
245                           [&](auto it) {
246                             p << std::get<0>(it) << " = " << std::get<1>(it)
247                               << ": " << std::get<1>(it).getType();
248                           });
249     p << ")";
250   }
251   if (!outputs().empty()) {
252     p << " outs (";
253     llvm::interleaveComma(llvm::zip(getRegionOutputArgs(), outputs()), p,
254                           [&](auto it) {
255                             p << std::get<0>(it) << " = " << std::get<1>(it)
256                               << ": " << std::get<1>(it).getType();
257                           });
258     p << ")";
259   }
260 
261   if (llvm::any_of(iterator_types(), [](Attribute attr) {
262         return attr.cast<StringAttr>().getValue() !=
263                LoopOp::getParallelIteratorTypeName();
264       }))
265     p << " iterators" << iterator_types();
266 
267   if (distribution_types().has_value())
268     p << " distribution" << distribution_types().getValue();
269 
270   p << ' ';
271   p.printRegion(region(), /*printEntryBlockArgs=*/false);
272   p.printOptionalAttrDict(
273       getOperation()->getAttrs(),
274       /*elidedAttrs=*/{LoopOp::getOperandSegmentSizeAttr(),
275                        LoopOp::getIteratorTypesAttrName(),
276                        LoopOp::getDistributionTypesAttrName()});
277 }
278 
279 ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
280   auto &builder = parser.getBuilder();
281   // Parse an opening `(` followed by induction variables and a closing `)`.
282   SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
283   if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
284                               /*allowResultNumber=*/false))
285     return failure();
286 
287   // Parse loop bounds.
288   SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
289   if (parser.parseEqual() ||
290       parser.parseOperandList(lower, ivs.size(),
291                               OpAsmParser::Delimiter::Paren) ||
292       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
293     return failure();
294 
295   SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
296   if (parser.parseKeyword("to") ||
297       parser.parseOperandList(upper, ivs.size(),
298                               OpAsmParser::Delimiter::Paren) ||
299       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
300     return failure();
301 
302   // Parse step values.
303   SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
304   if (parser.parseKeyword("step") ||
305       parser.parseOperandList(steps, ivs.size(),
306                               OpAsmParser::Delimiter::Paren) ||
307       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
308     return failure();
309 
310   // Parse input tensors.
311   SmallVector<OpAsmParser::UnresolvedOperand, 4> inputs, inputRegionArgs;
312   SmallVector<Type, 4> inputTypes;
313   if (succeeded(parser.parseOptionalKeyword("ins"))) {
314     SMLoc inputsOperandsLoc = parser.getCurrentLocation();
315 
316     if (parseAssignmentListWithTypes(parser, inputRegionArgs, inputs,
317                                      inputTypes))
318       return failure();
319 
320     if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc,
321                                result.operands))
322       return failure();
323   }
324 
325   // Parse output tensors.
326   SmallVector<OpAsmParser::UnresolvedOperand, 4> outputs, outputRegionArgs;
327   SmallVector<Type, 4> outputTypes;
328   if (succeeded(parser.parseOptionalKeyword("outs"))) {
329     SMLoc outputsOperandsLoc = parser.getCurrentLocation();
330 
331     if (parseAssignmentListWithTypes(parser, outputRegionArgs, outputs,
332                                      outputTypes))
333       return failure();
334 
335     if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
336                                result.operands))
337       return failure();
338     for (Type outputType : outputTypes)
339       if (outputType.isa<RankedTensorType>()) result.addTypes(outputType);
340   }
341 
342   // Parse attributes.
343   SmallVector<Attribute, 4> iterTypes, distributionTypes;
344   auto parseAttr = [&](StringRef keyword, SmallVector<Attribute, 4> *attrs) {
345     if (succeeded(parser.parseOptionalKeyword(keyword))) {
346       StringAttr attr;
347 
348       if (parser.parseLSquare() || parser.parseAttribute(attr))
349         return failure();
350       attrs->push_back(attr);
351       for (int i = 1, e = ivs.size(); i < e; ++i) {
352         if (parser.parseComma() || parser.parseAttribute(attr))
353           return failure();
354         attrs->push_back(attr);
355       }
356       if (parser.parseRSquare()) return failure();
357     }
358     return success();
359   };
360   if (failed(parseAttr("iterators", &iterTypes)) ||
361       failed(parseAttr("distribution", &distributionTypes)))
362     return failure();
363 
364   // Set all loop iterator types to "parallel" if they are not printed in IR.
365   if (iterTypes.empty()) {
366     auto parallelIter =
367         builder.getStringAttr(LoopOp::getParallelIteratorTypeName());
368     iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter);
369   }
370   result.addAttribute(LoopOp::getIteratorTypesAttrStrName(),
371                       builder.getArrayAttr(iterTypes));
372   if (!distributionTypes.empty())
373     result.addAttribute(LoopOp::getDistributionTypesAttrStrName(),
374                         builder.getArrayAttr(distributionTypes));
375   result.addAttribute(
376       LoopOp::getOperandSegmentSizeAttr(),
377       builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
378                                     static_cast<int32_t>(upper.size()),
379                                     static_cast<int32_t>(steps.size()),
380                                     static_cast<int32_t>(inputs.size()),
381                                     static_cast<int32_t>(outputs.size())}));
382 
383   // Parse the body.
384   Region *body = result.addRegion();
385 
386   SmallVector<Type, 4> regionTypes(ivs.size(), builder.getIndexType());
387   regionTypes.append(inputTypes);
388   regionTypes.append(outputTypes);
389 
390   SmallVector<OpAsmParser::UnresolvedOperand, 4> regionOperands(ivs);
391   regionOperands.append(inputRegionArgs);
392   regionOperands.append(outputRegionArgs);
393 
394   SmallVector<OpAsmParser::Argument, 4> regionArgs;
395 
396   for (auto argAndType : llvm::zip(regionOperands, regionTypes)) {
397     auto &arg = regionArgs.emplace_back();
398     arg.ssaName = std::get<0>(argAndType);
399     arg.type = std::get<1>(argAndType);
400   }
401 
402   if (parser.parseRegion(*body, regionArgs)) return failure();
403 
404   // Parse optional attributes.
405   if (parser.parseOptionalAttrDict(result.attributes)) return failure();
406 
407   return success();
408 }
409 
410 Region &LoopOp::getLoopBody() { return region(); }
411 
412 LogicalResult LoopOp::verify() {
413   // Check if iterator types are provided for every loop dimension.
414   if (iterator_types().size() != getNumLoops())
415     return emitOpError("expected iterator types array attribute size = ")
416            << iterator_types().size()
417            << " to match the number of loops = " << getNumLoops();
418 
419   // Check if types of input arguments match region args types.
420   for (auto &item :
421        llvm::enumerate(llvm::zip(inputs(), getRegionInputArgs()))) {
422     Value input, inputRegionArg;
423     unsigned index = item.index();
424     std::tie(input, inputRegionArg) = item.value();
425     if (input.getType() != inputRegionArg.getType())
426       return emitOpError("expected input arg ")
427              << index << " with type = " << input.getType()
428              << " to match region arg " << index + getNumLoops()
429              << " type = " << inputRegionArg.getType();
430   }
431 
432   // Check if types of output arguments match region args types.
433   for (auto &item :
434        llvm::enumerate(llvm::zip(outputs(), getRegionOutputArgs()))) {
435     Value output, outputRegionArg;
436     unsigned index = item.index();
437     std::tie(output, outputRegionArg) = item.value();
438     if (output.getType() != outputRegionArg.getType())
439       return emitOpError("expected output arg ")
440              << index << " with type = " << output.getType()
441              << " to match region arg "
442              << index + getNumLoops() + inputs().size()
443              << " type = " << outputRegionArg.getType();
444   }
445   return success();
446 }
447 
448 //===----------------------------------------------------------------------===//
449 // LoopLikeOp
450 //===----------------------------------------------------------------------===//
451 
452 namespace {
453 
454 ParseResult parseForOpOutputArgs(
455     OpAsmParser &parser, OperationState &result,
456     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &regionOperands,
457     SmallVectorImpl<Type> &regionTypes, int32_t *outputCount) {
458   SmallVector<OpAsmParser::UnresolvedOperand, 4> outputs, outputRegionArgs;
459   SmallVector<Type, 4> outputTypes;
460 
461   auto parseElt = [&]() -> ParseResult {
462     if (parser.parseOperand(outputRegionArgs.emplace_back(),
463                             /*allowResultNumber=*/false) ||
464         parser.parseEqual()) {
465       return failure();
466     }
467     if (parser.parseOperand(outputs.emplace_back()) || parser.parseColon() ||
468         parser.parseType(outputTypes.emplace_back())) {
469       return failure();
470     }
471     *outputCount = outputs.size();
472     return success();
473   };
474   if (succeeded(parser.parseOptionalKeyword("outs"))) {
475     SMLoc loc = parser.getCurrentLocation();
476 
477     if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, parseElt))
478       return failure();
479     if (parser.resolveOperands(outputs, outputTypes, loc, result.operands))
480       return failure();
481   }
482   regionOperands.append(outputRegionArgs);
483   regionTypes.append(outputTypes);
484   return success();
485 }
486 
487 }  // namespace
488 
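// Shared parser for ParallelOp and ForOp; only ForOp additionally parses the
// `outs (...)` output arguments.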
489 template <typename LoopTy>
490 ParseResult parseLoopLikeOp(OpAsmParser &parser, OperationState &result) {
491   auto &builder = parser.getBuilder();
492   // Parse an opening `(` followed by induction variables and a closing `)`.
493   SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
494   if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
495                               /*allowResultNumber=*/false))
496     return failure();
497 
498   // Parse loop bounds.
499   SmallVector<OpAsmParser::UnresolvedOperand, 4> lower;
500   if (parser.parseEqual() ||
501       parser.parseOperandList(lower, ivs.size(),
502                               OpAsmParser::Delimiter::Paren) ||
503       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
504     return failure();
505 
506   SmallVector<OpAsmParser::UnresolvedOperand, 4> upper;
507   if (parser.parseKeyword("to") ||
508       parser.parseOperandList(upper, ivs.size(),
509                               OpAsmParser::Delimiter::Paren) ||
510       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
511     return failure();
512 
513   // Parse step values.
514   SmallVector<OpAsmParser::UnresolvedOperand, 4> steps;
515   if (parser.parseKeyword("step") ||
516       parser.parseOperandList(steps, ivs.size(),
517                               OpAsmParser::Delimiter::Paren) ||
518       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
519     return failure();
520 
521   SmallVector<int32_t> segmentSizes{static_cast<int32_t>(lower.size()),
522                                     static_cast<int32_t>(upper.size()),
523                                     static_cast<int32_t>(steps.size())};
524 
525   // Parse the output tensors (only for ForOp) and the body.
526   SmallVector<OpAsmParser::UnresolvedOperand, 4> regionOperands(ivs);
527   SmallVector<Type, 4> regionTypes(ivs.size(), builder.getIndexType());
528 
529   if (std::is_same<LoopTy, ForOp>::value) {
530     int32_t outputCount = 0;
531     if (parseForOpOutputArgs(parser, result, regionOperands, regionTypes,
532                              &outputCount))
533       return failure();
534     segmentSizes.push_back(outputCount);
535   }
536 
537   SmallVector<OpAsmParser::Argument, 4> regionArgs;
538   for (auto argAndType : llvm::zip(regionOperands, regionTypes)) {
539     auto &arg = regionArgs.emplace_back();
540     std::tie(arg.ssaName, arg.type) = argAndType;
541   }
542   Region *body = result.addRegion();
543   if (parser.parseRegion(*body, regionArgs)) return failure();
544 
545   // Parse attributes.
546   if (parser.parseOptionalAttrDict(result.attributes)) return failure();
547 
548   // Parse result types.
549   if (parser.parseOptionalColonTypeList(result.types)) return failure();
550 
551   // Add segment sizes.
552   result.addAttribute(LoopTy::getOperandSegmentSizeAttr(),
553                       builder.getDenseI32ArrayAttr(segmentSizes));
554 
555   return success();
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // ParallelOp
560 //===----------------------------------------------------------------------===//
561 
562 Region &ParallelOp::getLoopBody() { return region(); }
563 
564 SetYieldOp ParallelOp::getTerminator() {
565   return cast<SetYieldOp>(getBody()->getTerminator());
566 }
567 
568 LogicalResult ParallelOp::verify() { return success(); }
569 
570 void ParallelOp::build(
571     OpBuilder &builder, OperationState &result, TypeRange resultTypes,
572     ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps,
573     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
574   result.addOperands(lowerBounds);
575   result.addOperands(upperBounds);
576   result.addOperands(steps);
577   result.addTypes(resultTypes);
578   result.addAttribute(
579       LoopOp::getOperandSegmentSizeAttr(),
580       builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
581                                     static_cast<int32_t>(upperBounds.size()),
582                                     static_cast<int32_t>(steps.size())}));
583 
584   OpBuilder::InsertionGuard guard(builder);
585   unsigned numIvs = steps.size();
586   SmallVector<Type, 8> argTypes(numIvs, builder.getIndexType());
587   SmallVector<Location, 8> argLocs(numIvs, result.location);
588   Region *bodyRegion = result.addRegion();
589   Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
590 
591   if (bodyBuilderFn) {
592     builder.setInsertionPointToStart(bodyBlock);
593     bodyBuilderFn(builder, result.location,
594                   bodyBlock->getArguments().take_front(numIvs));
595     ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
596   }
597 }
598 
599 void ParallelOp::print(OpAsmPrinter &p) {
600   p << " (" << getInductionVars() << ") = (" << lowerBound() << ") to ("
601     << upperBound() << ") step (" << step() << ") ";
602 
603   p.printRegion(region(), /*printEntryBlockArgs=*/false);
604   p.printOptionalAttrDict(
605       getOperation()->getAttrs(),
606       /*elidedAttrs=*/{ParallelOp::getOperandSegmentSizeAttr()});
607 
608   if (!getResultTypes().empty()) {
609     p << " : ";
610     llvm::interleave(getResultTypes(), p, ", ");
611   }
612 }
613 
614 ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
615   return parseLoopLikeOp<ParallelOp>(parser, result);
616 }
617 
618 //===----------------------------------------------------------------------===//
619 // ForOp
620 //===----------------------------------------------------------------------===//
621 
622 Region &ForOp::getLoopBody() { return region(); }
623 
624 SetYieldOp ForOp::getTerminator() {
625   return cast<SetYieldOp>(getBody()->getTerminator());
626 }
627 
628 LogicalResult ForOp::verify() {
629   // Check if types of output arguments match region args types.
630   for (auto &item :
631        llvm::enumerate(llvm::zip(outputs(), getRegionOutputArgs()))) {
632     Value output, outputRegionArg;
633     unsigned index = item.index();
634     std::tie(output, outputRegionArg) = item.value();
635     if (output.getType() != outputRegionArg.getType()) {
636       return emitOpError("expected output arg ")
637              << index << " with type = " << output.getType()
638              << " to match region arg " << index + getNumLoops()
639              << " type = " << outputRegionArg.getType();
640     }
641     if (getTerminator().getDstOperand(index)->get() != outputRegionArg) {
642       return getTerminator().emitOpError("expected output block argument ")
643              << index << " to match set_yield destination";
644     }
645   }
646   return success();
647 }
648 
649 void ForOp::build(
650     OpBuilder &builder, OperationState &result, TypeRange resultTypes,
651     ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps,
652     ValueRange outputs,
653     function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
654         bodyBuilderFn) {
655   result.addOperands(lowerBounds);
656   result.addOperands(upperBounds);
657   result.addOperands(steps);
658   result.addOperands(outputs);
659   result.addTypes(resultTypes);
660   result.addAttribute(
661       LoopOp::getOperandSegmentSizeAttr(),
662       builder.getDenseI32ArrayAttr({static_cast<int32_t>(lowerBounds.size()),
663                                     static_cast<int32_t>(upperBounds.size()),
664                                     static_cast<int32_t>(steps.size()),
665                                     static_cast<int32_t>(outputs.size())}));
666 
667   OpBuilder::InsertionGuard guard(builder);
668   unsigned numIvs = steps.size();
669   SmallVector<Type, 8> argTypes(numIvs, builder.getIndexType());
670   SmallVector<Location, 8> argLocs(numIvs, result.location);
671   for (Value output : outputs) {
672     argTypes.push_back(output.getType());
673     argLocs.push_back(output.getLoc());
674   }
675   Region *bodyRegion = result.addRegion();
676   Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs);
677 
678   if (bodyBuilderFn) {
679     builder.setInsertionPointToStart(bodyBlock);
680     bodyBuilderFn(builder, result.location,
681                   bodyBlock->getArguments().take_front(numIvs),
682                   bodyBlock->getArguments().take_back(outputs.size()));
683     ForOp::ensureTerminator(*bodyRegion, builder, result.location);
684   }
685 }
686 
687 void ForOp::print(OpAsmPrinter &p) {
688   p << " (" << getInductionVars() << ") = (" << lowerBound() << ") to ("
689     << upperBound() << ") step (" << step() << ")";
690 
691   if (!outputs().empty()) {
692     p << " outs (";
693     llvm::interleaveComma(
694         llvm::zip(getRegionOutputArgs(), outputs()), p, [&](auto it) {
695           Value outputRegionArg, output;
696           std::tie(outputRegionArg, output) = it;
697           p << outputRegionArg << " = " << output << ": " << output.getType();
698         });
699     p << ")";
700   }
701 
702   p << ' ';
703   p.printRegion(region(), /*printEntryBlockArgs=*/false);
704   p.printOptionalAttrDict(getOperation()->getAttrs(),
705                           /*elidedAttrs=*/{ForOp::getOperandSegmentSizeAttr()});
706 
707   if (!getResultTypes().empty()) {
708     p << " : ";
709     llvm::interleave(getResultTypes(), p, ", ");
710   }
711 }
712 
713 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
714   return parseLoopLikeOp<ForOp>(parser, result);
715 }
716 
717 namespace {
718 
719 static constexpr int64_t kNoMatch = -1;
720 
721 // Folds away LoopOp inputs if they have no uses within the body.
722 //
723 // Example:
724 //
725 // %0 = gml_st.loop ...  ins (%in_ = %in: tensor<...>,
726 //                                  %in_buf_ = %in_buf: memref<...>) {...}
727 // Becomes
728 //
729 // gml_st.loop ...  ins (%in_buf_ = %in_buf: memref<...>) {...}
730 struct LoopInputsFolder : public OpRewritePattern<LoopOp> {
731   using OpRewritePattern<LoopOp>::OpRewritePattern;
732 
733   LogicalResult matchAndRewrite(LoopOp loop,
734                                 PatternRewriter &rewriter) const final {
735     SmallVector<Value, 2> newInputs, regionInputTensorArgs;
736     // Store ids of the corresponding old and new input operands.
737     SmallVector<int64_t, 2> oldInputIdToNew(loop.inputs().size(), kNoMatch);
738     for (const auto &en :
739          llvm::enumerate(llvm::zip(loop.inputs(), loop.getRegionInputArgs()))) {
740       Value in, bbArg;
741       size_t index = en.index();
742       std::tie(in, bbArg) = en.value();
743       if (!bbArg.use_empty()) {
744         oldInputIdToNew[index] = newInputs.size();
745         newInputs.push_back(in);
746       }
747     }
748     if (newInputs.size() == loop.inputs().size()) return failure();
749     Location loc = loop.getLoc();
750     auto newLoop = rewriter.create<LoopOp>(
751         loc, loop.lowerBound(), loop.upperBound(), loop.step(), newInputs,
752         loop.outputs(), loop.iterator_types(), loop.distribution_types());
753 
754     // Clone the region.
755     BlockAndValueMapping bvm;
756     bvm.map(loop.getInductionVars(), newLoop.getInductionVars());
757     bvm.map(loop.getRegionOutputArgs(), newLoop.getRegionOutputArgs());
758     for (const auto &en : llvm::enumerate(oldInputIdToNew))
759       if (en.value() != kNoMatch)
760         bvm.map(loop.getRegionInputArgs()[en.index()],
761                 newLoop.getRegionInputArgs()[en.value()]);
762     OpBuilder innerBuilder =
763         OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener());
764     for (auto &op : *loop.getBody()) innerBuilder.clone(op, bvm);
765     rewriter.replaceOp(loop, newLoop.getResults());
766 
767     return success();
768   }
769 };
770 
771 }  // namespace
772 
773 /// A simple, conservative analysis to determine if the loop is shape
774 /// preserving, i.e., the type of the `arg`-th yielded value is the same as
775 /// the type of the corresponding basic block argument of the loop.
776 /// Note: This function handles only simple cases. Expand as needed.
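/// Example (illustrative): the loop below is shape preserving in its single
/// output, because the yielded value chains back to the output block argument
/// through `tensor.insert_slice`, which keeps the destination type:
///   %r = gml_st.loop ... outs (%o_ = %out: tensor<8x8xf32>) {
///     %v = tensor.insert_slice %x into %o_ [...] : ... into tensor<8x8xf32>
///     gml_st.yield %v : tensor<8x8xf32>
///   }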
777 static bool isShapePreserving(LoopOp loopOp, int64_t arg) {
778   auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
779   if (yieldOp.values().empty())
780     // Loop either has no outputs or is a "memref-based version". In either
781     // case, the loop is shape preserving.
782     return true;
783   assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
784          "arg is out of bounds");
785   Value value = yieldOp.values()[arg];
786   while (value) {
787     if (value == loopOp.getRegionOutputArgs()[arg]) return true;
788     OpResult opResult = value.dyn_cast<OpResult>();
789     if (!opResult) return false;
790 
791     using tensor::InsertSliceOp;
792     value = llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
793                 .template Case<InsertSliceOp>(
794                     [&](InsertSliceOp op) { return op.getDest(); })
795                 .template Case<LoopOp>([&](LoopOp loopOp) {
796                   return isShapePreserving(loopOp, opResult.getResultNumber())
797                              ? loopOp.outputs()[opResult.getResultNumber()]
798                              : Value();
799                 })
800                 .Default([&](auto /*op*/) { return Value(); });
801   }
802   return false;
803 }
804 
805 namespace {
806 
807 /// Fold dim(x) where `x` is an input/output argument of a LoopOp block
808 /// to dim(y) where `y` is the initial input/output value of the argument.
809 ///
810 /// E.g.:
811 /// %y = ... : tensor<...>
812 /// gml_st.loop ... ins(%x = %y : tensor<...>) {
813 ///   tensor.dim %x, %c0 : tensor<...>
814 /// }
815 ///
816 /// is folded to:
817 /// %y = ... : tensor<...>
818 /// gml_st.loop ... ins(%x = %y : tensor<...>) {
819 ///   tensor.dim %y, %c0 : tensor<...>
820 /// }
821 ///
822 /// Note: Dim ops are folded only if it can be proven that the runtime type of
823 /// the yielded value (in case of outputs) does not change with loop iterations.
824 template <typename OpTy>
825 struct DimOfLoopInsOutsFolder : public OpRewritePattern<OpTy> {
826   using OpRewritePattern<OpTy>::OpRewritePattern;
827 
828   LogicalResult matchAndRewrite(OpTy dimOp,
829                                 PatternRewriter &rewriter) const final {
830     auto src = dimOp.getSource().template dyn_cast<BlockArgument>();
831     if (!src) return failure();
832     auto loopOp = dyn_cast<LoopOp>(src.getOwner()->getParent()->getParentOp());
833     if (!loopOp) return failure();
834     unsigned numLoops = loopOp.getNumLoops();
835     unsigned numInputArgs = loopOp.getRegionInputArgs().size();
836     if (src.getArgNumber() >= numInputArgs + numLoops &&
837         !isShapePreserving(loopOp,
838                            src.getArgNumber() - numInputArgs - numLoops))
839       return failure();
840 
841     auto inputArgs = loopOp.getRegionInputArgs();
842     auto it1 = llvm::find(inputArgs, src);
843     if (it1 != inputArgs.end()) {
844       rewriter.updateRootInPlace(dimOp, [&] {
845         dimOp.getSourceMutable().assign(
846             loopOp.inputs()[it1 - inputArgs.begin()]);
847       });
848       return success();
849     }
850 
851     auto outputArgs = loopOp.getRegionOutputArgs();
852     auto it2 = llvm::find(outputArgs, src);
853     if (it2 != outputArgs.end()) {
854       rewriter.updateRootInPlace(dimOp, [&] {
855         dimOp.getSourceMutable().assign(
856             loopOp.outputs()[it2 - outputArgs.begin()]);
857       });
858       return success();
859     }
860 
861     return failure();
862   }
863 };
864 
865 /// Fold dim(r) where `r` is the result of a LoopOp to dim(y) where `y`
866 /// is the initial output value of the loop.
867 ///
868 /// E.g.:
869 /// %y = ... : tensor<...>
870 /// %r = gml_st.loop ... outs(%i = %y : tensor<...>) {
871 ///   ...
872 /// }
873 /// %0 = tensor.dim %r, %c0 : tensor<...>
874 ///
875 /// is folded to:
876 /// %y = ... : tensor<...>
877 /// gml_st.loop ... outs(%i = %y : tensor<...>) {
878 ///   ...
879 /// }
880 /// %0 = tensor.dim %y, %c0 : tensor<...>
881 ///
882 /// Note: Dim ops are folded only if it can be proven that the runtime type of
883 /// the yielded value (in case of outputs) does not change with loop iterations.
884 template <typename OpTy>
885 struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
886   using OpRewritePattern<OpTy>::OpRewritePattern;
887 
888   LogicalResult matchAndRewrite(OpTy dimOp,
889                                 PatternRewriter &rewriter) const final {
890     auto loopOp = dimOp.getSource().template getDefiningOp<LoopOp>();
891     if (!loopOp) return failure();
892     auto opResult = dimOp.getSource().template cast<OpResult>();
893     unsigned resultNumber = opResult.getResultNumber();
894     if (!isShapePreserving(loopOp, resultNumber)) return failure();
895     rewriter.updateRootInPlace(dimOp, [&]() {
896       dimOp.getSourceMutable().assign(loopOp.outputs()[resultNumber]);
897     });
898     return success();
899   }
900 };
901 
902 // Folds away LoopOp output tensors when the following conditions are met:
903 // * result of `gml_st.loop` has no uses
904 // * output tensor is the argument of `gml_st.yield`
905 //
906 // Example:
907 //
908 // %0 = gml_st.loop ...  outs (%o_ = %out: tensor<...>,
909 //                                   %obuf_ = %out_buf: memref<...>) {
910 //   ...
911 //   gml_st.yield %o_ : tensor ...
912 // }
913 //
914 // Becomes
915 //
916 // gml_st.loop ...  outs (%obuf_ = %out_buf: memref<...>) {
917 //   ...
918 //   gml_st.yield
919 // }
920 struct LoopResultsFolder : public OpRewritePattern<LoopOp> {
921   using OpRewritePattern<LoopOp>::OpRewritePattern;
922 
923   LogicalResult matchAndRewrite(LoopOp loop,
924                                 PatternRewriter &rewriter) const final {
925     if (loop.getNumResults() == 0) return failure();
926 
927     Block *block = loop.getBody();
928     auto yieldOp = cast<YieldOp>(block->getTerminator());
929 
930     // Match the pattern and collect output buffers that will replace the output
931     // tensors and also the ops that will be ignored when cloning the body.
932     SmallVector<Value, 2> newOutputOperands, newYieldArgs;
933     int resultId = 0;
934     // Store ids of the corresponding old and new output operands.
935     SmallVector<int64_t, 2> oldOutputIdToNew(loop.outputs().size(), kNoMatch);
936     // Store ids of the corresponding old and new results.
937     SmallVector<int64_t, 2> oldResultIdToNew(loop.getNumResults(), kNoMatch);
938     SmallVector<Value, 2> resultReplacement(loop.getNumResults());
939     for (const auto &en : llvm::enumerate(
940              llvm::zip(loop.outputs(), loop.getRegionOutputArgs()))) {
941       size_t index = en.index();
942       Value out = std::get<0>(en.value());
943       Value outRegionArg = std::get<1>(en.value());
944 
945       if (!out.getType().isa<RankedTensorType>()) {
946         oldOutputIdToNew[index] = newOutputOperands.size();
947         newOutputOperands.push_back(out);
948         continue;
949       }
950       Value result = loop.getResult(resultId);
951       Value yieldArg = yieldOp.getOperand(resultId);
952       if (yieldArg != outRegionArg || !result.use_empty()) {
953         oldOutputIdToNew[index] = newOutputOperands.size();
954         oldResultIdToNew[resultId] = newYieldArgs.size();
955         resultReplacement[resultId] = out;
956         newOutputOperands.push_back(out);
957         newYieldArgs.push_back(yieldArg);
958       }
959       ++resultId;
960     }
961     if (newOutputOperands.size() == loop.outputs().size()) return failure();
962 
963     Location loc = loop.getLoc();
964     auto newLoop = rewriter.create<LoopOp>(
965         loc, loop.lowerBound(), loop.upperBound(), loop.step(), loop.inputs(),
966         newOutputOperands, loop.iterator_types(), loop.distribution_types());
967 
968     // Clone the region.
969     BlockAndValueMapping bvm;
970     bvm.map(loop.getInductionVars(), newLoop.getInductionVars());
971     bvm.map(loop.getRegionInputArgs(), newLoop.getRegionInputArgs());
972     for (const auto &en : llvm::enumerate(oldOutputIdToNew)) {
973       if (en.value() != kNoMatch)
974         bvm.map(loop.getRegionOutputArgs()[en.index()],
975                 newLoop.getRegionOutputArgs()[en.value()]);
976       else
977         bvm.map(loop.getRegionOutputArgs()[en.index()],
978                 loop.outputs()[en.index()]);
979     }
980     OpBuilder innerBuilder =
981         OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener());
982     for (auto &op : loop.getBody()->without_terminator())
983       innerBuilder.clone(op, bvm);
984     innerBuilder.create<YieldOp>(
985         loc, llvm::to_vector<2>(llvm::map_range(
986                  newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));
987 
988     for (const auto &en : llvm::enumerate(oldResultIdToNew))
989       if (en.value() != kNoMatch)
990         resultReplacement[en.index()] = newLoop.getResult(en.value());
991     rewriter.replaceOp(loop, resultReplacement);
992 
993     return success();
994   }
995 };
996 
997 /// Pull `gml_st.loop` input/output arguments that are produced by
998 /// `tensor.cast` ops inside `gml_st.loop`:
999 ///
1000 /// ```
1001 ///   %in = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
1002 ///   %out = tensor.cast %t1 : tensor<32x1024xf32> to tensor<?x?xf32>
1003 ///   %result = gml_st.loop %i = %c0 to %c1024 step %c32
1004 ///       ins (%in_ = %in: tensor<?x?xf32>)
1005 ///       outs (%out_ = %out: tensor<?x?xf32>) {
1006 ///     %0 = call @do(%in_, %out_)
1007 ///       : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
1008 ///     scf.yield %0 : tensor<?x?xf32>
1009 ///   }
1010 ///   %result_cast = tensor.cast %result
1011 ///     : tensor<?x?xf32> to tensor<32x1024xf32>
1012 ///   use_of(%result_cast)
1013 /// ```
1014 ///
1015 /// folds into:
1016 ///
1017 /// ```
1018 ///   %result = gml_st.loop %i = %c0 to %c1024 step %c32
1019 ///       ins (%in_ = %t0: tensor<32x1024xf32>)
1020 ///       outs (%out_ = %t1: tensor<32x1024xf32>) {
1021 ///     %in_cast = tensor.cast %in_ : tensor<32x1024xf32> to tensor<?x?xf32>
1022 ///     %out_cast = tensor.cast %out_ : tensor<32x1024xf32> to tensor<?x?xf32>
1023 ///     %0 = call @do(%in_, %out_)
1024 ///       : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
1025 ///     %0_cast = tensor.cast %0 : tensor<?x?xf32> to tensor<32x1024xf32>
1026 ///     scf.yield %0 : tensor<32x1024xf32>
1027 ///   }
1028 ///   use_of(%result)
1029 /// ```
1030 struct TensorCastOfLoopInsOutsFolder : public OpRewritePattern<LoopOp> {
1031   using OpRewritePattern<LoopOp>::OpRewritePattern;
1032 
1033   LogicalResult matchAndRewrite(LoopOp loop,
1034                                 PatternRewriter &rewriter) const override {
1035     CastOpsOfArgs inputCasts = findTensorCastOps(loop.inputs());
1036     CastOpsOfArgs outputCasts = findTensorCastOps(loop.outputs());
1037     if (!inputCasts.castFound && !outputCasts.castFound) return failure();
1038 
1039     auto newLoop = rewriter.create<LoopOp>(
1040         loop.getLoc(), loop.lowerBound(), loop.upperBound(), loop.step(),
1041         inputCasts.updatedArgs, outputCasts.updatedArgs, loop.iterator_types(),
1042         loop.distribution_types());
1043 
1044     rewriter.replaceOp(loop, insertCastsAndCloneBody(inputCasts, outputCasts,
1045                                                      loop, newLoop, rewriter));
1046     return success();
1047   }
1048 
1049  private:
1050   struct CastOpsOfArgs {
1051     SmallVector<tensor::CastOp, 4> ops;
1052     // Contains either old arguments or arguments of `tensor.cast`.
1053     SmallVector<Value, 4> updatedArgs;
1054     bool castFound = false;
1055   };
1056 
1057   // Scans through args to find which args are produced by `tensor.cast` ops.
1058   CastOpsOfArgs findTensorCastOps(ValueRange args) const {
1059     CastOpsOfArgs result;
1060     for (auto arg : args) {
1061       if (auto cast = arg.getDefiningOp<tensor::CastOp>()) {
1062         result.ops.push_back(cast);
1063         result.updatedArgs.push_back(cast.getSource());
1064         result.castFound = true;
1065         continue;
1066       }
1067       result.ops.push_back(nullptr);
1068       result.updatedArgs.push_back(arg);
1069     }
1070     return result;
1071   }
1072 
1073   SmallVector<Value, 4> insertCastsAndCloneBody(
1074       const CastOpsOfArgs &inputCasts, const CastOpsOfArgs &outputCasts,
1075       LoopOp loop, LoopOp newLoop, PatternRewriter &rewriter) const {
1076     auto loc = newLoop.getLoc();
1077     BlockAndValueMapping bvm;
1078     bvm.map(loop.getInductionVars(), newLoop.getInductionVars());
1079 
1080     auto innerBuilder =
1081         OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener());
1082 
1083     Value oldArg, newArg, yieldArg, result;
1084     tensor::CastOp argCast;
1085 
1086     // Map inputs, insert `tensor.cast` if necessary.
1087     for (auto item : llvm::zip(loop.getRegionInputArgs(),
1088                                newLoop.getRegionInputArgs(), inputCasts.ops)) {
1089       std::tie(oldArg, newArg, argCast) = item;
1090       if (!argCast) {
1091         bvm.map(oldArg, newArg);
1092         continue;
1093       }
1094       Value newCast =
1095           innerBuilder.create<tensor::CastOp>(loc, argCast.getType(), newArg);
1096       bvm.map(oldArg, newCast);
1097     }
1098 
1099     // Map outputs, insert `tensor.cast` and cast the loop results if necessary.
1100     SmallVector<Value, 4> newResults;
1101     rewriter.setInsertionPointAfter(newLoop);
1102     for (auto item :
1103          llvm::zip(loop.getRegionOutputArgs(), newLoop.getRegionOutputArgs(),
1104                    outputCasts.ops, newLoop.getResults())) {
1105       std::tie(oldArg, newArg, argCast, result) = item;
1106       if (!argCast) {
1107         bvm.map(oldArg, newArg);
1108         newResults.push_back(result);
1109         continue;
1110       }
1111       Value newCast =
1112           innerBuilder.create<tensor::CastOp>(loc, argCast.getType(), newArg);
1113       bvm.map(oldArg, newCast);
1114 
1115       newResults.push_back(
1116           rewriter.create<tensor::CastOp>(loc, argCast.getType(), result));
1117     }
1118 
1119     // Clone loop body.
1120     for (auto &op : loop.getBody()->without_terminator())
1121       innerBuilder.clone(op, bvm);
1122 
1123     // Cast yield arguments to the new type.
1124     SmallVector<Value, 4> yieldArgs =
1125         loop.getBody()->getTerminator()->getOperands();
1126     SmallVector<Value, 4> newYieldArgs;
1127     for (auto item : llvm::zip(yieldArgs, outputCasts.ops)) {
1128       std::tie(yieldArg, argCast) = item;
1129       if (!argCast) {
1130         newYieldArgs.push_back(bvm.lookup(yieldArg));
1131         continue;
1132       }
1133       newYieldArgs.push_back(innerBuilder.create<tensor::CastOp>(
1134           loc, argCast.getSource().getType(), bvm.lookup(yieldArg)));
1135     }
1136     innerBuilder.create<YieldOp>(loc, newYieldArgs);
1137     return newResults;
1138   }
1139 };
1140 
1141 /// Removes loops in which at least one lower/upper bound pair consists of
1142 /// the same value; such loops have an empty iteration domain.
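/// Example (illustrative): a loop `(%i) = (%c0) to (%c0) step (%c1)` runs
/// zero iterations, so its tensor results can be replaced by the
/// corresponding output operands.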
1143 struct FoldEmptyLoops : public OpRewritePattern<LoopOp> {
1144   using OpRewritePattern<LoopOp>::OpRewritePattern;
1145 
1146   LogicalResult matchAndRewrite(LoopOp op,
1147                                 PatternRewriter &rewriter) const override {
1148     for (auto dim : llvm::zip(op.lowerBound(), op.upperBound())) {
1149       if (std::get<0>(dim) != std::get<1>(dim)) continue;
1150       SmallVector<Value> tensorOutputs;
1151       for (Value out : op.outputs()) {
1152         if (out.getType().isa<RankedTensorType>()) tensorOutputs.push_back(out);
1153       }
1154       rewriter.replaceOp(op, tensorOutputs);
1155       return success();
1156     }
1157     return failure();
1158   }
1159 };
1160 
1161 }  // namespace
1162 
1163 void LoopOp::getCanonicalizationPatterns(RewritePatternSet &results,
1164                                          MLIRContext *context) {
1165   results
1166       .add<FoldEmptyLoops, LoopInputsFolder, LoopResultsFolder,
1167            DimOfLoopInsOutsFolder<tensor::DimOp>,
1168            DimOfLoopInsOutsFolder<memref::DimOp>,
1169            DimOfLoopResultFolder<tensor::DimOp>,
1170            DimOfLoopResultFolder<memref::DimOp>, TensorCastOfLoopInsOutsFolder>(
1171           context);
1172 }
1173 
1174 /// This is used for patterns of the form
1175 /// ```
1176 ///    gml_st.loop(memrefcast(%src)) -> gml_st.loop(%src)
1177 /// ```
1178 /// It folds the source of the memref.cast into the root operation directly.
1179 LogicalResult LoopOp::fold(ArrayRef<Attribute>,
1180                            SmallVectorImpl<OpFoldResult> &) {
1181   LoopOp op = *this;
1182   bool folded = false;
1183   Location loc = op->getLoc();
1184 
1185   Block *body = op.getBody();
1186   OpBuilder b = OpBuilder::atBlockBegin(body);
1187 
1188   // Update `input` and `output` operands and block arguments if necessary.
1189   // Operands list: [lbs, ubs, steps, inputs, outputs].
1190   // Block args list: [ivs, inputs, outputs].
1191   for (size_t operandIndex = op.getNumControlOperands(),
1192               bbArgIndex = op.getNumLoops(), e = op.getNumOperands();
1193        operandIndex < e; ++operandIndex, ++bbArgIndex) {
1194     OpOperand &operand = op->getOpOperand(operandIndex);
1195 
1196     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
1197     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
1198       operand.set(castOp.getOperand());
1199       BlockArgument newBbArg = body->insertArgument(
1200           bbArgIndex, castOp.getOperand().getType(), op.getLoc());
1201       BlockArgument oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1);
1202 
1203       // Insert memref.cast back to the original type.
1204       oldBbArg.replaceAllUsesWith(
1205           b.create<memref::CastOp>(loc, oldBbArg.getType(), newBbArg));
1206       body->eraseArgument(oldBbArg.getArgNumber());
1207 
1208       folded = true;
1209     }
1210   }
1211   return success(folded);
1212 }
1213 
1214 //===----------------------------------------------------------------------===//
1215 // YieldOp
1216 //===----------------------------------------------------------------------===//
1217 
1218 LogicalResult YieldOp::verify() {
1219   auto *parentOp = getOperation()->getParentOp();
1220 
1221   if (auto setYield = dyn_cast<SetYieldOp>(parentOp)) {
1222     if (values().size() != 1)
1223       return emitOpError(
1224           "expected a single argument for the terminator of accumulator "
1225           "region");
1226     return success();
1227   }
1228   auto loopOp = cast<LoopOp>(parentOp);
1229   // Check if output args with tensor types match results types.
1230   SmallVector<Value, 2> tensorOuts;
1231   llvm::copy_if(
1232       loopOp.outputs(), std::back_inserter(tensorOuts),
1233       [&](Value out) { return out.getType().isa<RankedTensorType>(); });
1234   if (tensorOuts.size() != values().size())
1235     return emitOpError("expected number of tensor output args = ")
1236            << tensorOuts.size()
1237            << " to match the number of yield operands = " << values().size();
1238 
1239   TypeRange tensorTypes{ValueRange{tensorOuts}};
1240   for (auto &item :
1241        llvm::enumerate(llvm::zip(tensorTypes, getOperandTypes()))) {
1242     Type outType, resultType;
1243     unsigned index = item.index();
1244     std::tie(outType, resultType) = item.value();
1245     if (outType != resultType)
1246       return emitOpError("expected yield operand ")
1247              << index << " with type = " << resultType
1248              << " to match output arg type = " << outType;
1249   }
1250   return success();
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // SpaceOp
1255 //===----------------------------------------------------------------------===//
1256 
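// Builds a SpaceOp from mixed static/dynamic sizes; e.g. sizes [8, %s] are
// split into static sizes [8, kDynamicSize] plus the dynamic operand %s
// (illustrative).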
1257 void SpaceOp::build(OpBuilder &builder, OperationState &result,
1258                     ArrayRef<OpFoldResult> sizes,
1259                     ArrayRef<NamedAttribute> attrs) {
1260   SmallVector<Value> dynamicSizes;
1261   SmallVector<int64_t> staticSizes;
1262   for (OpFoldResult size : sizes)
1263     dispatchIndexOpFoldResult(size, dynamicSizes, staticSizes,
1264                               ShapedType::kDynamicSize);
1265   build(builder, result, TileType::get(builder.getContext(), staticSizes),
1266         dynamicSizes, builder.getI64ArrayAttr(staticSizes));
1267   result.addAttributes(attrs);
1268 }
1269 
1270 LogicalResult SpaceOp::inferReturnTypes(
1271     MLIRContext *ctx, Optional<Location> /*loc*/, ValueRange operands,
1272     DictionaryAttr attributes, RegionRange regions,
1273     SmallVectorImpl<Type> &inferredReturnTypes) {
1274   SpaceOp::Adaptor adaptor(operands, attributes, regions);
1275   SmallVector<int64_t> shape = llvm::to_vector(
1276       llvm::map_range(adaptor.static_sizes(), [&](const Attribute &val) {
1277         return val.cast<IntegerAttr>().getValue().getSExtValue();
1278       }));
1279   auto resultTy = TileType::get(ctx, shape);
1280   inferredReturnTypes.push_back(resultTy);
1281   return success();
1282 }
1283 
verify()1284 LogicalResult SpaceOp::verify() {
1285   auto resultTy = getType().cast<TileType>();
1286   return mlir::verifyListOfOperandsOrIntegers(
1287       getOperation(), "size", resultTy.getShape().size(), static_sizes(),
1288       dynamic_sizes(), ShapedType::isDynamic);
1289 }
1290 
getNumDynamicEntriesUpToIdx(unsigned idx)1291 unsigned SpaceOp::getNumDynamicEntriesUpToIdx(unsigned idx) {
1292   return std::count_if(static_sizes().begin(), static_sizes().begin() + idx,
1293                        [&](const mlir::Attribute size) {
1294                          return mlir::ShapedType::isDynamic(
1295                              size.cast<mlir::IntegerAttr>().getInt());
1296                        });
1297 }
1298 
getDynamicSize(unsigned idx)1299 mlir::Value SpaceOp::getDynamicSize(unsigned idx) {
1300   auto numDynamic = getNumDynamicEntriesUpToIdx(idx);
1301   return dynamic_sizes()[numDynamic];
1302 }
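
// For illustration (hypothetical sizes): for a space with
// static_sizes = [8, ?, 4, ?], getNumDynamicEntriesUpToIdx(3) counts a single
// dynamic entry (dimension 1), so getDynamicSize(3) returns
// dynamic_sizes()[1], i.e. the SSA value bound to dimension 3.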

//===----------------------------------------------------------------------===//
// PointOp
//===----------------------------------------------------------------------===//

void PointOp::build(OpBuilder &builder, OperationState &result, Value superset,
                    ArrayRef<OpFoldResult> offsets,
                    ArrayRef<NamedAttribute> attrs) {
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  for (OpFoldResult offset : offsets)
    dispatchIndexOpFoldResult(offset, dynamicOffsets, staticOffsets,
                              ShapedType::kDynamicStrideOrOffset);
  build(builder, result, PointType::get(builder.getContext()), superset,
        dynamicOffsets, builder.getI64ArrayAttr(staticOffsets));
  result.addAttributes(attrs);
}

LogicalResult PointOp::verify() {
  auto tileShape = superset().getType().cast<TileType>().getShape();
  if (failed(mlir::verifyListOfOperandsOrIntegers(
          getOperation(), "index", tileShape.size(), static_indices(),
          dynamic_indices(), ShapedType::isDynamicStrideOrOffset))) {
    return failure();
  }
  // Check that the known indices are within the known dimension sizes.
  for (auto dimAndIndex : llvm::zip(tileShape, static_indices())) {
    auto dimSize = std::get<0>(dimAndIndex);
    auto index =
        std::get<1>(dimAndIndex).dyn_cast<mlir::IntegerAttr>().getInt();
    if (index == ShapedType::kDynamicStrideOrOffset) continue;
    if (index < 0) {
      return emitOpError("expected index = ") << index << " to be non-negative";
    }
    if (dimSize != ShapedType::kDynamicSize && index >= dimSize) {
      return emitOpError("expected index = ")
             << index << " to be between 0 and " << (dimSize - 1);
    }
  }
  return success();
}
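
// For illustration (hypothetical type syntax): for a superset of type
// !gml_st.tile<8x?>, a static index of 9 in dimension 0 is rejected (the
// valid range is [0, 7]), whereas indices into the dynamic dimension 1 cannot
// be checked statically and are accepted here.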

//===----------------------------------------------------------------------===//
// TileOp
//===----------------------------------------------------------------------===//

void TileOp::build(OpBuilder &b, OperationState &result, Value superset,
                   ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
                   ArrayRef<OpFoldResult> strides,
                   ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto tileType = TileType::get(b.getContext(), staticSizes);
  build(b, result, tileType, superset, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

LogicalResult TileOp::inferReturnTypes(
    MLIRContext *ctx, Optional<Location> /*loc*/, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Derive result shape.
  TileOp::Adaptor adaptor(operands, attributes, regions);
  SmallVector<int64_t> shape = llvm::to_vector(
      llvm::map_range(adaptor.static_sizes(), [&](const auto &size) {
        return size.template dyn_cast<mlir::IntegerAttr>()
            .getValue()
            .getSExtValue();
      }));

  auto resultTy = TileType::get(ctx, shape);
  inferredReturnTypes.push_back(resultTy);
  return success();
}

LogicalResult TileOp::verify() {
  auto supersetTy = superset().getType().cast<TileType>();
  auto rank = supersetTy.getShape().size();
  if (failed(mlir::verifyListOfOperandsOrIntegers(getOperation(), "size", rank,
                                                  static_sizes(), sizes(),
                                                  ShapedType::isDynamic))) {
    return failure();
  }
  if (failed(mlir::verifyListOfOperandsOrIntegers(
          getOperation(), "offset", rank, static_offsets(), offsets(),
          ShapedType::isDynamicStrideOrOffset))) {
    return failure();
  }
  if (failed(mlir::verifyListOfOperandsOrIntegers(
          getOperation(), "stride", rank, static_strides(), strides(),
          ShapedType::isDynamicStrideOrOffset))) {
    return failure();
  }
  for (auto it : llvm::zip(supersetTy.getShape(), static_offsets(),
                           static_sizes(), static_strides())) {
    auto offset =
        std::get<1>(it).dyn_cast<mlir::IntegerAttr>().getValue().getSExtValue();
    if (offset < 0 && offset != ShapedType::kDynamicStrideOrOffset) {
      return emitOpError("expected offset = ")
             << offset << " to be non-negative";
    }
    auto size =
        std::get<2>(it).dyn_cast<mlir::IntegerAttr>().getValue().getSExtValue();
    if (size < 0 && size != ShapedType::kDynamicSize) {
      return emitOpError("expected size = ") << size << " to be non-negative";
    }
    auto stride =
        std::get<3>(it).dyn_cast<mlir::IntegerAttr>().getValue().getSExtValue();
    if (stride < 0 && stride != ShapedType::kDynamicStrideOrOffset) {
      return emitOpError("expected stride = ")
             << stride << " to be non-negative";
    }
    auto argSize = std::get<0>(it);
    // If the argument tile has a dynamic dimension, no additional verification
    // is possible.
    if (argSize == ShapedType::kDynamicSize) continue;
    if (offset >= 0) {
      if (stride >= 0 && size > 0) {
        int64_t largestIndex = offset + stride * (size - 1);
        if (largestIndex >= argSize) {
          return emitOpError("offset = ")
                 << offset << " size = " << size << " stride = " << stride
                 << " causes access out of bounds at " << largestIndex
                 << " for argument dimension size = " << argSize;
        }
      } else if (offset >= argSize) {
        return emitOpError("offset = ")
               << offset
               << " is out of bounds for argument dimension size = " << argSize;
      }
    } else if (stride > 0 && size > 0 && stride * (size - 1) >= argSize) {
      return emitOpError("size = ")
             << size << " stride = " << stride
             << " causes access out of bounds for argument dimension size = "
             << argSize;
    }
  }
  return success();
}
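
// Worked example for the bounds check above (hypothetical values): for a
// superset dimension of size 10, a tile with offset = 2, size = 4, and
// stride = 3 would touch index 2 + 3 * (4 - 1) = 11 >= 10 and is rejected,
// while size = 3 touches at most index 2 + 3 * 2 = 8 < 10 and verifies.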

namespace {

OpFoldResult multiplyOperandsOrIntegers(OpBuilder &builder, Location loc,
                                        OpFoldResult lhs, OpFoldResult rhs) {
  // Both operands are static.
  if (lhs.is<Attribute>() && rhs.is<Attribute>()) {
    return builder.getI64IntegerAttr(
        lhs.get<Attribute>().cast<IntegerAttr>().getInt() *
        rhs.get<Attribute>().cast<IntegerAttr>().getInt());
  }

  // Exploit commutativity and move static operand to the left (if any).
  if (rhs.is<Attribute>()) std::swap(lhs, rhs);

  // Create constant if needed.
  if (lhs.is<Attribute>()) {
    int64_t lhsInt = lhs.get<Attribute>().cast<IntegerAttr>().getInt();

    // Exploit static operand if possible.
    if (lhsInt == 0) return lhs;
    if (lhsInt == 1) return rhs;

    lhs = builder.create<arith::ConstantIndexOp>(loc, lhsInt).getResult();
  }

  // Multiply.
  return builder.create<arith::MulIOp>(loc, lhs.get<Value>(), rhs.get<Value>())
      .getResult();
}

OpFoldResult addOperandsOrIntegers(OpBuilder &builder, Location loc,
                                   OpFoldResult lhs, OpFoldResult rhs) {
  // Both operands are static.
  if (lhs.is<Attribute>() && rhs.is<Attribute>()) {
    return builder.getI64IntegerAttr(
        lhs.get<Attribute>().cast<IntegerAttr>().getInt() +
        rhs.get<Attribute>().cast<IntegerAttr>().getInt());
  }

  // Exploit commutativity and move static operand to the left (if any).
  if (rhs.is<Attribute>()) std::swap(lhs, rhs);

  // Create constant if needed.
  if (lhs.is<Attribute>()) {
    int64_t lhsInt = lhs.get<Attribute>().cast<IntegerAttr>().getInt();

    // Exploit static operand if possible.
    if (lhsInt == 0) return rhs;

    lhs = builder.create<arith::ConstantIndexOp>(loc, lhsInt).getResult();
  }

  // Add.
  return builder.create<arith::AddIOp>(loc, lhs.get<Value>(), rhs.get<Value>())
      .getResult();
}
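
// For illustration: with two static operands, multiplyOperandsOrIntegers
// folds 2 * 3 to the attribute 6; a static 1 (or 0) yields the other operand
// (or the zero attribute) without materializing any arith ops, and only the
// fully dynamic case emits arith.constant/arith.muli. addOperandsOrIntegers
// behaves analogously, folding away a static 0.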

// Compose offsets with newOffset = supersetOffset + supersetStride * offset.
SmallVector<OpFoldResult> composeOffsets(
    const llvm::SmallVectorImpl<OpFoldResult> &supersetOffsets,
    const llvm::SmallVectorImpl<OpFoldResult> &supersetStrides,
    const llvm::SmallVectorImpl<OpFoldResult> &offsets, Location loc,
    OpBuilder &builder) {
  SmallVector<OpFoldResult> composedOffsets;
  for (auto it : llvm::zip(supersetOffsets, supersetStrides, offsets)) {
    composedOffsets.push_back(addOperandsOrIntegers(
        builder, loc, std::get<0>(it),
        multiplyOperandsOrIntegers(builder, loc, std::get<1>(it),
                                   std::get<2>(it))));
  }
  return composedOffsets;
}

// Compose strides with newStride = supersetStride * stride.
SmallVector<OpFoldResult> composeStrides(
    OpBuilder &builder, Location loc,
    const llvm::SmallVectorImpl<OpFoldResult> &supersetStrides,
    const llvm::SmallVectorImpl<OpFoldResult> &strides) {
  SmallVector<OpFoldResult> composedStrides;
  for (auto it : llvm::zip(supersetStrides, strides)) {
    composedStrides.push_back(multiplyOperandsOrIntegers(
        builder, loc, std::get<0>(it), std::get<1>(it)));
  }
  return composedStrides;
}
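
// Worked example (static case): supersetOffset = 4, supersetStride = 2, and
// offset = 3 compose to newOffset = 4 + 2 * 3 = 10; supersetStride = 2 and
// stride = 5 compose to newStride = 2 * 5 = 10.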

}  // namespace

Value TileOp::compose(OpBuilder &builder) {
  auto supersetOp = llvm::dyn_cast_or_null<TileOp>(superset().getDefiningOp());
  if (!supersetOp) return {};

  // Compose offsets with newOffset = supersetOffset + supersetStride *
  // offset.
  auto loc = getLoc();
  auto composedOffsets =
      composeOffsets(supersetOp.getMixedOffsets(), supersetOp.getMixedStrides(),
                     getMixedOffsets(), loc, builder);

  // Compose strides with newStride = supersetStride * stride.
  auto composedStrides = composeStrides(
      builder, loc, supersetOp.getMixedStrides(), getMixedStrides());

  // Build the composed tile op.
  return builder.create<TileOp>(loc, supersetOp.superset(), composedOffsets,
                                getMixedSizes(), composedStrides);
}
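
// For illustration (hypothetical 1-d values): composing a tile with
// offset = 3 and stride = 5 into a superset tile with offset = 4 and
// stride = 2 yields a tile rooted directly at the space with
// offset = 4 + 2 * 3 = 10 and stride = 2 * 5 = 10; the inner tile's sizes are
// kept as-is.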

//===----------------------------------------------------------------------===//
// PointOp
//===----------------------------------------------------------------------===//

Value PointOp::compose(OpBuilder &builder) {
  auto supersetOp = llvm::dyn_cast_or_null<TileOp>(superset().getDefiningOp());
  if (!supersetOp) return {};

  // Compose offsets with newOffset = supersetOffset + supersetStride *
  // offset.
  auto loc = getLoc();
  auto composedOffsets = decomposeMixedStridesOrOffsets(
      builder,
      composeOffsets(
          supersetOp.getMixedOffsets(), supersetOp.getMixedStrides(),
          mlir::getMixedStridesOrOffsets(static_indices(), dynamic_indices()),
          loc, builder));

  // Build the composed point op.
  return builder.create<PointOp>(loc, supersetOp.superset(),
                                 composedOffsets.second, composedOffsets.first);
}

//===----------------------------------------------------------------------===//
// DropDimsOp
//===----------------------------------------------------------------------===//

LogicalResult DropDimsOp::inferReturnTypes(
    MLIRContext *ctx, Optional<Location> /*loc*/, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  DropDimsOp::Adaptor adaptor(operands, attributes, regions);
  Type argTy = adaptor.superset().getType();

  // If the argument is of point type, so is the result.
  if (auto pointTy = argTy.dyn_cast<PointType>()) {
    inferredReturnTypes.push_back(argTy);
    return success();
  }

  // If the argument is of tile type, we can skip the dropped dimensions to
  // derive the result type.
  if (auto tileTy = argTy.dyn_cast<TileType>()) {
    auto argShape = tileTy.getShape();
    SmallVector<int64_t> resultShape = llvm::to_vector(llvm::map_range(
        adaptor.remaining_dims(), [&](const auto &d) { return argShape[d]; }));
    auto resultTy = TileType::get(ctx, resultShape);
    inferredReturnTypes.push_back(resultTy);
    return success();
  }

  return failure();
}
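
// For illustration (hypothetical shapes): dropping dimensions of a
// !gml_st.tile<4x8x16> with remaining_dims = [0, 2] infers the result type
// !gml_st.tile<4x16>; remaining_dims also fixes the order in which the kept
// dimensions appear.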

namespace {

SmallVector<OpFoldResult> selectMixedValues(
    const SmallVectorImpl<OpFoldResult> &mixedValues,
    ArrayRef<int64_t> selection) {
  return llvm::to_vector(
      llvm::map_range(selection, [&](int64_t i) { return mixedValues[i]; }));
}

// Compose a set by selecting a subset of its dimensions. Both the dimensions
// to select, and the order in which they should be selected, are specified by
// `selection`.
Value selectDimsFromSet(OpBuilder &builder, Location loc, Type type, Value set,
                        ArrayRef<int64_t> selection) {
  // Case: space
  Operation *setDef = set.getDefiningOp();
  if (auto spaceOp = llvm::dyn_cast_or_null<SpaceOp>(setDef)) {
    auto spaceSizes =
        getMixedSizes(spaceOp.static_sizes(), spaceOp.dynamic_sizes());
    auto newSpaceSizes = selectMixedValues(spaceSizes, selection);
    auto newSpaceSizesDecomposed = decomposeMixedSizes(builder, newSpaceSizes);
    return builder.create<SpaceOp>(loc, newSpaceSizesDecomposed.second,
                                   newSpaceSizesDecomposed.first);
  }

  // Case: point(space)
  if (PointOp pointOp = llvm::dyn_cast_or_null<PointOp>(setDef)) {
    auto newSpace =
        selectDimsFromSet(builder, loc, type, pointOp.superset(), selection);
    auto pointOffsets = getMixedStridesOrOffsets(pointOp.static_indices(),
                                                 pointOp.dynamic_indices());
    auto newPointOffsets = selectMixedValues(pointOffsets, selection);
    auto newPointOffsetsDecomposed =
        decomposeMixedStridesOrOffsets(builder, newPointOffsets);
    return builder.create<PointOp>(loc, newSpace,
                                   newPointOffsetsDecomposed.second,
                                   newPointOffsetsDecomposed.first);
  }

  // Case: tile(space)
  if (TileOp tileOp = llvm::dyn_cast_or_null<TileOp>(setDef)) {
    auto newSpace =
        selectDimsFromSet(builder, loc, type, tileOp.superset(), selection);

    auto tileOffsets =
        getMixedStridesOrOffsets(tileOp.static_offsets(), tileOp.offsets());
    auto newTileOffsets = selectMixedValues(tileOffsets, selection);
    auto newTileOffsetsDecomposed =
        decomposeMixedStridesOrOffsets(builder, newTileOffsets);

    auto tileSizes = getMixedSizes(tileOp.static_sizes(), tileOp.sizes());
    auto newTileSizes = selectMixedValues(tileSizes, selection);
    auto newTileSizesDecomposed = decomposeMixedSizes(builder, newTileSizes);

    auto tileStrides =
        getMixedStridesOrOffsets(tileOp.static_strides(), tileOp.strides());
    auto newTileStrides = selectMixedValues(tileStrides, selection);
    auto newTileStridesDecomposed =
        decomposeMixedStridesOrOffsets(builder, newTileStrides);

    return builder.create<TileOp>(
        loc, newSpace, newTileOffsetsDecomposed.second,
        newTileSizesDecomposed.second, newTileStridesDecomposed.second,
        newTileOffsetsDecomposed.first, newTileSizesDecomposed.first,
        newTileStridesDecomposed.first);
  }

  return {};
}

}  // namespace

Value DropDimsOp::compose(OpBuilder &builder) {
  // We can compose with a TileOp operand which has a SpaceOp operand, or
  // compose with a SpaceOp operand.
  return selectDimsFromSet(builder, getLoc(), getType(), superset(),
                           remaining_dims());
}

//===----------------------------------------------------------------------===//
// TransposeDimsOp
//===----------------------------------------------------------------------===//

LogicalResult TransposeDimsOp::inferReturnTypes(
    MLIRContext *ctx, Optional<Location> /*loc*/, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  TransposeDimsOp::Adaptor adaptor(operands, attributes, regions);
  const Type argTy = adaptor.superset().getType();

  // If the argument is of point type, so is the result.
  if (auto pointTy = argTy.dyn_cast<PointType>()) {
    inferredReturnTypes.push_back(argTy);
    return success();
  }

  // If the argument is of tile type, we can transpose the type's dimensions.
  if (auto tileTy = argTy.dyn_cast<TileType>()) {
    auto argShape = tileTy.getShape();
    const SmallVector<int64_t> resultShape = llvm::to_vector(llvm::map_range(
        adaptor.permutation(), [&](const auto &d) { return argShape[d]; }));
    auto resultTy = TileType::get(ctx, resultShape);
    inferredReturnTypes.push_back(resultTy);
    return success();
  }

  return failure();
}

Value TransposeDimsOp::compose(OpBuilder &builder) {
  // We can compose with a TileOp operand which has a SpaceOp operand, or with
  // a SpaceOp operand directly: transpose_tile(tile(space, offsets, sizes,
  // strides)) is replaced by tile(transpose(space), transpose(offsets),
  // transpose(sizes), transpose(strides)), and transpose_tile(space) is
  // replaced by transpose(space).

  return selectDimsFromSet(builder, getLoc(), getType(), superset(),
                           permutation());
}

LogicalResult TransposeDimsOp::verify() {
  // Verify that `permutation` is in fact a permutation.
  size_t rank = permutation().size();
  SmallVector<int64_t> position(rank, -1);
  for (const auto &it : llvm::enumerate(permutation())) {
    int64_t dim = it.value();
    if (dim < 0 || dim >= static_cast<int64_t>(rank)) {
      return emitOpError("permutation[")
             << it.index() << "] = " << dim << " is outside of range [0, "
             << rank - 1 << "]";
    }
    if (position[dim] >= 0) {
      return emitOpError(
                 "expected permutation attribute to contain no duplicate "
                 "values, but got ")
             << dim << " at positions " << position[dim] << " and "
             << it.index();
    }
    position[dim] = it.index();
  }

  // Verify the tile-specific relationship between types and permutation. The
  // constraints between argument and result type are verified through the
  // implementation of `inferReturnTypes`.
  if (auto tileTy = getType().dyn_cast<TileType>()) {
    size_t tileRank = tileTy.getShape().size();
    if (tileRank != rank) {
      return emitOpError("expected result rank ")
             << tileRank << " to match the permutation size of " << rank << ".";
    }
  }

  return success();
}
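
// For illustration: permutation = [1, 0, 2] is accepted for a rank-3 tile,
// whereas [0, 0, 2] is rejected because 0 appears at positions 0 and 1, and
// [0, 1, 3] is rejected because 3 is outside of range [0, 2].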

//===----------------------------------------------------------------------===//
// SetYieldOp
//===----------------------------------------------------------------------===//

using AccumulatorRegionBuilderFn =
    function_ref<void(OpBuilder &, Location, Value, Value)>;

void SetYieldOp::build(OpBuilder &builder, OperationState &result) {
  build(builder, result, llvm::None, llvm::None, llvm::None);
}

void SetYieldOp::build(OpBuilder &builder, OperationState &result,
                       ValueRange srcs, ValueRange dsts, ValueRange sets) {
  SmallVector<bool, 2> accumulatorFlags(srcs.size(), false);
  build(builder, result, srcs, dsts, sets,
        builder.getBoolArrayAttr(accumulatorFlags), llvm::None);
}

void SetYieldOp::build(
    OpBuilder &builder, OperationState &result, ValueRange srcs,
    ValueRange dsts, ValueRange sets, ArrayAttr accumulatorFlags,
    ArrayRef<AccumulatorRegionBuilderFn> accumulatorBuilderFns) {
  assert(dsts.size() == srcs.size() &&
         "`dsts` and `srcs` should have the same size");
  assert(sets.size() == srcs.size() &&
         "`sets` and `srcs` should have the same size");
  assert(accumulatorFlags.size() == srcs.size() &&
         "`accumulatorFlags` and `srcs` should have the same size");

  auto accumulatorCount = llvm::count_if(accumulatorFlags, [](Attribute attr) {
    return attr.cast<BoolAttr>().getValue();
  });
  (void)accumulatorCount;
  assert(accumulatorCount ==
             static_cast<int64_t>(accumulatorBuilderFns.size()) &&
         "the number of flags set in the `accumulatorFlags` attribute should "
         "be equal to the number of `accumulatorBuilderFns`");

  result.addOperands(srcs);
  result.addOperands(dsts);
  result.addOperands(sets);
  result.addAttribute(SetYieldOp::accumulatorFlagsAttrName(result.name),
                      accumulatorFlags);

  const auto *builderFnIt = accumulatorBuilderFns.begin();
  for (auto item : llvm::zip(srcs, accumulatorFlags)) {
    Value src = std::get<0>(item);
    auto accumulatorFlag = std::get<1>(item).cast<BoolAttr>();

    if (!accumulatorFlag.getValue()) continue;
    Region *region = result.addRegion();
    OpBuilder::InsertionGuard g(builder);
    SmallVector<Type, 2> argTypes(2, src.getType());
    builder.createBlock(region);
    Block &bodyBlock = region->front();
    bodyBlock.addArguments(argTypes, {result.location, result.location});

    builder.setInsertionPointToStart(&bodyBlock);
    (*builderFnIt)(builder, result.location, bodyBlock.getArgument(0),
                   bodyBlock.getArgument(1));
    ++builderFnIt;
  }
}
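
// For illustration (hypothetical values): building a set_yield with
// srcs = {%a, %b} and accumulatorFlags = [false, true] requires exactly one
// accumulator builder fn; it populates the region for %b, whose block takes
// two arguments of %b's type (the new and the old value).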

LogicalResult SetYieldOp::verify() {
  auto accumulatorCount = llvm::count_if(
      accumulatorFlags(),
      [](Attribute attr) { return attr.cast<BoolAttr>().getValue(); });
  if (accumulatorCount != static_cast<int64_t>(accumulators().size()))
    return emitOpError("expected the number of accumulator regions ")
           << accumulators().size()
           << " to match the number of set accumulator flags "
           << accumulatorCount;

  auto *regionIt = accumulators().begin();
  for (auto item : llvm::zip(srcs(), accumulatorFlags())) {
    Type srcType = std::get<0>(item).getType();
    BoolAttr accumulatorFlag = std::get<1>(item).cast<BoolAttr>();
    if (!accumulatorFlag.getValue()) continue;

    Block &block = regionIt->front();
    if (block.getArgumentTypes() != SmallVector<Type>{srcType, srcType})
      return emitOpError()
             << "expected accumulator region to have 2 arguments of type "
             << srcType;
    ++regionIt;
  }
  return success();
}

void SetYieldOp::print(OpAsmPrinter &p) {
  p.printOptionalAttrDict(getOperation()->getAttrs(),
                          /*elidedAttrs=*/{accumulatorFlagsAttrName().str()});

  auto *regionIt = getOperation()->getRegions().begin();
  for (auto &en :
       llvm::enumerate(llvm::zip(srcs(), dsts(), sets(), accumulatorFlags()))) {
    if (en.index() > 0) {
      p << ',';
      p.printNewline();
    }
    Value src = std::get<0>(en.value());
    Value dst = std::get<1>(en.value());
    Value set = std::get<2>(en.value());
    auto accumulatorFlag = std::get<3>(en.value()).cast<BoolAttr>();

    p << ' ' << src << " into " << dst << '[' << set << ']';

    if (accumulatorFlag.getValue()) {
      auto &block = regionIt->getBlocks().front();
      Value newValue = block.getArgument(0);
      Value oldValue = block.getArgument(1);
      p << " acc (" << newValue << ", " << oldValue << ": "
        << oldValue.getType() << ") ";

      p.printRegion(*regionIt, /*printEntryBlockArgs=*/false);
      ++regionIt;
    }

    p << " : " << src.getType() << " into " << dst.getType() << '['
      << set.getType() << ']';
  }
}
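
// Schematically, the printer above emits each entry as
//   %src into %dst[%set] acc (%new, %old: T) { ... } : S into D[Set]
// where the `acc` region is printed only when the corresponding accumulator
// flag is set, and entries are separated by a comma and a newline.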

ParseResult SetYieldOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();

  SmallVector<bool, 2> accumulatorFlags;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcs, dsts, sets;
  SmallVector<Type, 4> srcTypes, dstTypes, setTypes;

  auto parseElt = [&]() -> ParseResult {
    OpAsmParser::UnresolvedOperand src;
    auto parseResult =
        parser.parseOptionalOperand(src, /*allowResultNumber=*/false);

    if (!parseResult.hasValue()) return success();
    srcs.push_back(src);

    if (parser.parseKeyword("into") ||
        parser.parseOperand(dsts.emplace_back()) || parser.parseLSquare() ||
        parser.parseOperand(sets.emplace_back()) || parser.parseRSquare())
      return failure();

    bool hasAccumulatorRegion = succeeded(parser.parseOptionalKeyword("acc"));
    accumulatorFlags.push_back(hasAccumulatorRegion);
    if (hasAccumulatorRegion) {
      auto region = std::make_unique<Region>();
      OpAsmParser::UnresolvedOperand newValue, oldValue;
      Type argType;
      if (parser.parseLParen() || parser.parseOperand(newValue) ||
          parser.parseComma() || parser.parseOperand(oldValue) ||
          parser.parseColonType(argType) || parser.parseRParen())
        return failure();

      SmallVector<OpAsmParser::Argument, 4> regionArgs;
      for (auto value : {newValue, oldValue}) {
        auto &arg = regionArgs.emplace_back();
        arg.ssaName = value;
        arg.type = argType;
      }

      if (parser.parseRegion(*region, regionArgs)) return failure();
      result.addRegion(std::move(region));
    }
    if (parser.parseColon() || parser.parseType(srcTypes.emplace_back()) ||
        parser.parseKeyword("into") ||
        parser.parseType(dstTypes.emplace_back()) || parser.parseLSquare() ||
        parser.parseType(setTypes.emplace_back()) || parser.parseRSquare())
      return failure();

    return success();
  };
  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::None, parseElt))
    return failure();

  if (parser.resolveOperands(srcs, srcTypes, parser.getCurrentLocation(),
                             result.operands) ||
      parser.resolveOperands(dsts, dstTypes, parser.getCurrentLocation(),
                             result.operands) ||
      parser.resolveOperands(sets, setTypes, parser.getCurrentLocation(),
                             result.operands))
    return failure();

  result.addAttribute(SetYieldOp::accumulatorFlagsAttrName(result.name),
                      parser.getBuilder().getBoolArrayAttr(accumulatorFlags));
  return success();
}

//===----------------------------------------------------------------------===//
// OffsetOp
//===----------------------------------------------------------------------===//

OpFoldResult OffsetOp::fold(ArrayRef<Attribute> operands) {
  auto idxAttr = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!idxAttr) return {};
  int64_t idx = idxAttr.getInt();

  // Case: offset(point(space))
  Operation *subsetDef = subset().getDefiningOp();
  if (auto pointOp = llvm::dyn_cast_or_null<PointOp>(subsetDef)) {
    Operation *supersetDef = pointOp.superset().getDefiningOp();

    // Can only fold locally if the superset is the root space. Otherwise, rely
    // on subset composition.
    if (!llvm::isa_and_nonnull<SpaceOp>(supersetDef)) return {};

    return ensureIndexTypeForAttribute(mlir::getMixedStridesOrOffsets(
        pointOp.static_indices(), pointOp.dynamic_indices())[idx]);
  }

  // Case: offset(tile(space))
  if (auto tileOp = llvm::dyn_cast_or_null<TileOp>(subsetDef)) {
    Operation *supersetDef = tileOp.superset().getDefiningOp();

    // Can only fold locally if the superset is the root space. Otherwise, rely
    // on subset composition.
    if (!llvm::isa_and_nonnull<SpaceOp>(supersetDef)) return {};

    return ensureIndexTypeForAttribute(mlir::getMixedStridesOrOffsets(
        tileOp.static_offsets(), tileOp.offsets())[idx]);
  }

  // Case: offset(space)
  if (llvm::isa_and_nonnull<SpaceOp>(subsetDef)) {
    Builder b(getContext());
    return b.getIndexAttr(0);
  }

  return {};
}
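
// For illustration: offset(%point, 1), where %point is a point into a root
// space with static_indices = [0, 5], folds to the index attribute 5, and
// offset(%space, i) folds to 0 for any constant i; offsets of nested subsets
// are left to subset composition instead.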

//===----------------------------------------------------------------------===//
// SizeOp
//===----------------------------------------------------------------------===//

OpFoldResult SizeOp::fold(ArrayRef<Attribute> operands) {
  auto idxAttr = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!idxAttr) return {};
  int64_t idx = idxAttr.getInt();

  // Case: size(tile(...))
  // Note that sizes can also be folded in the presence of nested tiling. There
  // is no need to check for an immediate root space here.
  Operation *tileDef = tile().getDefiningOp();
  if (auto tileOp = llvm::dyn_cast_or_null<TileOp>(tileDef)) {
    return ensureIndexTypeForAttribute(tileOp.getMixedSizes()[idx]);
  }

  // Case: size(space)
  if (auto spaceOp = llvm::dyn_cast_or_null<SpaceOp>(tileDef)) {
    return ensureIndexTypeForAttribute(mlir::getMixedSizes(
        spaceOp.static_sizes(), spaceOp.dynamic_sizes())[idx]);
  }

  return {};
}

//===----------------------------------------------------------------------===//
// StrideOp
//===----------------------------------------------------------------------===//

OpFoldResult StrideOp::fold(ArrayRef<Attribute> operands) {
  auto idxAttr = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!idxAttr) return {};
  int64_t idx = idxAttr.getInt();

  // Case: stride(tile(space))
  Operation *subsetDef = tile().getDefiningOp();
  if (auto tileOp = llvm::dyn_cast_or_null<TileOp>(subsetDef)) {
    Operation *supersetDef = tileOp.superset().getDefiningOp();

    // Can only fold locally if the superset is the root space. Otherwise, rely
    // on subset composition.
    if (!llvm::isa_and_nonnull<SpaceOp>(supersetDef)) return {};

    return ensureIndexTypeForAttribute(mlir::getMixedStridesOrOffsets(
        tileOp.static_strides(), tileOp.strides())[idx]);
  }

  // Case: stride(space)
  if (llvm::isa_and_nonnull<SpaceOp>(subsetDef)) {
    Builder b(getContext());
    return b.getIndexAttr(1);
  }

  return {};
}
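
// For illustration: stride(%tile, 0), where %tile is a tile of a root space
// with static_strides = [2, 1], folds to the index attribute 2, and strides
// of a space fold to 1; anything else is left to subset composition.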

}  // namespace gml_st
}  // namespace mlir

// Generated op classes.
#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.cc.inc"