/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"

#include <tuple>
#include <utility>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"

namespace mlir {
namespace gml_st {
namespace {
/// Rewrite a LoopOp with bounds/step that potentially do not divide evenly
/// into two LoopOps: one where the step divides the iteration space evenly,
/// followed by another one for the last (partial) iteration (if any). This
/// function only rewrites the `idx`-th loop of the loop nest represented by
/// the LoopOp. To peel the entire loop nest, this function must be called
/// multiple times.
///
/// This function rewrites the given LoopOp in-place and creates a new
/// LoopOp for the last iteration. It replaces all uses of the original
/// LoopOp with the results of the newly generated one.
///
/// The newly generated LoopOp is returned via `result`. The boundary
/// at which the loop is split (new upper bound) is returned via `splitBound`.
/// The return value indicates whether the LoopOp was rewritten or not.
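///
/// A schematic sketch (hypothetical IR, peeling at idx = 0):
///
///   gml_st.loop (%i) = (%lb) to (%ub) step (%step) ...
///
/// becomes
///
///   %split = affine.apply affine_map<()[s0, s1, s2]
///                -> (s1 - (s1 - s0) mod s2)>()[%lb, %ub, %step]
///   gml_st.loop (%i) = (%lb) to (%split) step (%step) ...  // full tiles only
///   gml_st.loop (%i) = (%split) to (%ub) step (%step) ...  // partial tile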
static LogicalResult peelLoop(RewriterBase &b, LoopOp loopOp, int64_t idx,
                              LoopOp &result, Value &splitBound) {
  Value lb = loopOp.lowerBound()[idx], ub = loopOp.upperBound()[idx],
        step = loopOp.step()[idx];
  auto ubInt = getConstantIntValue(ub);

  auto loc = loopOp.getLoc();
  AffineExpr exprLb, exprUb, exprStep;
  bindSymbols(b.getContext(), exprLb, exprUb, exprStep);
  // New upper bound: %ub - (%ub - %lb) mod %step
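  // E.g. %lb = 0, %ub = 10, %step = 4 gives 10 - (10 - 0) mod 4 = 8: the main
  // loop then covers [0, 8) and the remainder loop covers [8, 10).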
  auto modMap = AffineMap::get(0, 3, {exprUb - ((exprUb - exprLb) % exprStep)});
  SmallVector<Value> operands{lb, ub, step};
  canonicalizeMapAndOperands(&modMap, &operands);
  modMap = simplifyAffineMap(modMap);
  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(loopOp);
  splitBound = b.createOrFold<AffineApplyOp>(loc, modMap, operands);
  // No specialization necessary if step already divides upper bound evenly.
  if (splitBound == ub || (ubInt && ubInt == getConstantIntValue(splitBound)))
    return failure();

  // Create remainder loop.
  b.setInsertionPointAfter(loopOp);
  auto remainderLoop = cast<LoopOp>(b.clone(*loopOp.getOperation()));
  loopOp.replaceAllUsesWith(remainderLoop->getResults());
  // Outputs: Take tensors from main loop's results. Take memrefs from main
  // loop's outputs.
  SmallVector<Value> remainderOutputs;
  for (unsigned o = 0, t = 0; o < loopOp.getNumOutputs(); ++o) {
    remainderOutputs.push_back(loopOp.outputs()[o].getType().isa<MemRefType>()
                                   ? loopOp.outputs()[o]
                                   : loopOp->getResult(t++));
  }
  remainderLoop.outputsMutable().assign(remainderOutputs);

  // Set new loop bounds.
  b.updateRootInPlace(loopOp, [&]() {
    SmallVector<Value> ubs = loopOp.upperBound();
    ubs[idx] = splitBound;
    loopOp.upperBoundMutable().assign(ubs);
  });
  SmallVector<Value> lbs = remainderLoop.lowerBound();
  lbs[idx] = splitBound;
  remainderLoop.lowerBoundMutable().assign(lbs);

  result = remainderLoop;
  return success();
}

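// Rewrites `OpTy` (affine.min or affine.max) ops that depend on the peeled
// induction variable, in both the main and the remainder loop. Inside the
// main loop the iteration space is known to divide evenly, so bounds such as
// min(%step, %ub - %iv) can be simplified to %step.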
template <typename OpTy, bool IsMin>
static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, LoopOp mainLoop,
                                        LoopOp remainderLoop, Value mainIv,
                                        Value remainderIv, Value ub,
                                        Value step) {
  mainLoop.walk([&](OpTy affineOp) {
    AffineMap map = affineOp.getAffineMap();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
                                     affineOp.operands(), IsMin, mainIv, ub,
                                     step, /*insideLoop=*/true);
  });
  remainderLoop.walk([&](OpTy affineOp) {
    AffineMap map = affineOp.getAffineMap();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
                                     affineOp.operands(), IsMin, remainderIv,
                                     ub, step, /*insideLoop=*/false);
  });
}

bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
    return cst.value() == 0;
  return false;
}
using ::mlir::linalg::LinalgOp;

void generateLoopNest(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,
                      LinalgOp linalgOp, ArrayRef<Attribute> iteratorTypes,
                      function_ref<scf::ValueVector(OpBuilder &, Location,
                                                    ValueRange, ValueRange)>
                          bodyBuilderFn,
                      ArrayRef<StringRef> distributionTypes) {
  SmallVector<OpFoldResult, 4> lbs, ubs, steps;
  for (Range range : loopRanges) {
    lbs.emplace_back(range.offset);
    ubs.emplace_back(range.size);
    steps.emplace_back(range.stride);
  }

  auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                              ValueRange ivs, ValueRange inputs,
                              ValueRange outputs) {
    SmallVector<Value> operandValuesToUse = inputs;
    operandValuesToUse.append(outputs.begin(), outputs.end());
    scf::ValueVector results =
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs, operandValuesToUse);
    nestedBuilder.create<gml_st::YieldOp>(nestedLoc, results);
  };

  SmallVector<Value> inputOperands = linalgOp.getInputOperands();
  SmallVector<Value> outputOperands = linalgOp.getOutputOperands();

  SmallVector<Value> lbsValue =
      mlir::getValueOrCreateConstantIndexOp(b, loc, lbs);
  SmallVector<Value> ubsValue =
      mlir::getValueOrCreateConstantIndexOp(b, loc, ubs);
  SmallVector<Value> stepsValue =
      mlir::getValueOrCreateConstantIndexOp(b, loc, steps);
  auto tiledLoop = b.create<LoopOp>(
      loc, lbsValue, ubsValue, stepsValue, inputOperands, outputOperands,
      b.getArrayAttr(iteratorTypes), wrappedBuilderFn);
  if (!distributionTypes.empty())
    tiledLoop.setDistributionTypes(b, distributionTypes);
}

// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as the size of the tile) is taken from a
// given ExtractSliceOp `sliceOp`.
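// Schematically (hypothetical IR): if `sliceOp` is
//   %tile = tensor.extract_slice %src[%off][%size][%stride]
// this emits
//   tensor.insert_slice %source into %dest[%off][%size][%stride]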
Value insertSliceIntoTensor(RewriterBase &b, Location loc,
                            tensor::ExtractSliceOp sliceOp, Value source,
                            Value dest) {
  return b.create<tensor::InsertSliceOp>(
      loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
      sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
      sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
}

FailureOr<linalg::TiledLinalgOp> tileLinalgOpImpl(
    RewriterBase &b, LinalgOp op, ValueRange tileSizes,
    const linalg::LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big, only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  if (llvm::all_of(tileSizes, isZero)) {
    linalg::TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  SmallVector<OpFoldResult> tileSizesFold;
  for (Value tileSize : tileSizes) tileSizesFold.push_back(tileSize);

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap) return failure();

  SmallVector<Range, 4> loopRanges;
  mlir::linalg::LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) =
      mlir::linalg::makeTiledLoopRanges(b, op.getLoc(), shapeSizesToLoopsMap,
                                        allShapeSizes, tileSizesFold);

  SmallVector<Attribute, 4> iteratorTypes;
  for (const auto &attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }

  // 2. Create the tiled loops.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder & /*builder*/, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op.getNumInputsAndOutputs()) &&
           "expect the number of operands and inputs and outputs to match");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    auto sizeBounds = makeComposedFoldedMultiResultAffineApply(
        b, loc, shapeSizesToLoopsMap, allShapeSizes);
    SmallVector<OpFoldResult> ivsFold(ivs.begin(), ivs.end());
    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        b, loc, op, valuesToTile, ivsFold, tileSizesFold, sizeBounds,
        /*omitPartialTileCheck=*/false);

    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand *opOperand : op.getOutputTensorOperands())
      resultTensorTypes.push_back(
          tiledOperands[opOperand->getOperandNumber()].getType());

    res = op.clone(b, loc, resultTensorTypes, tiledOperands);

    // Insert an insert_slice for each output tensor.
    unsigned resultIdx = 0;
    for (OpOperand *opOperand : op.getOutputTensorOperands()) {
      Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
      IRRewriter rewriter(b);
      if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
        tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp,
                                                      res->getResult(resultIdx),
                                                      sliceOp.getSource()));
      } else {
        tensorResults.push_back(res->getResult(resultIdx));
      }
      ++resultIdx;
    }
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  generateLoopNest(b, op.getLoc(), loopRanges, op, iteratorTypes,
                   tiledLoopBodyBuilder, options.distributionTypes);

  // 3. Transform IndexOp results w.r.t. the tiling.
  mlir::linalg::transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop)) break;

  return linalg::TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}

}  // namespace

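/// Peels the `idx`-th loop of `loopOp` into a main loop with full steps and a
/// remainder LoopOp for the final partial iteration, then rewrites the
/// affine.min/affine.max ops that depend on the peeled induction variable in
/// both loops. Returns the remainder loop via `result`.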
LogicalResult peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
                                           LoopOp loopOp, int64_t idx,
                                           LoopOp &result) {
  int64_t numLoops = loopOp.iterator_types().size();
  if (idx < 0 || numLoops <= idx) return failure();

  Value ub = loopOp.upperBound()[idx];
  LoopOp remainderLoop;
  Value splitBound;
  if (failed(peelLoop(rewriter, loopOp, idx, remainderLoop, splitBound)))
    return failure();

  // Rewrite affine.min and affine.max ops.
  Value mainIv = loopOp.getInductionVars()[idx], step = loopOp.step()[idx],
        remainderIv = remainderLoop.getInductionVars()[idx];

  rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
      rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);
  rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
      rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);

  result = remainderLoop;
  return success();
}

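/// Tiles `op` to a gml_st::LoopOp nest using the tile sizes produced by
/// `options.tileSizeComputationFunction`. A tile size of zero means the
/// corresponding loop dimension is not tiled.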
FailureOr<linalg::TiledLinalgOp> tileLinalgOp(
    RewriterBase &b, linalg::LinalgOp op,
    const linalg::LinalgTilingOptions &options) {
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(op);

  if (!options.tileSizeComputationFunction) return failure();

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
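  // E.g. for an op with three loops, computed tile sizes {4, 8} are padded
  // below to {4, 8, 0}, so the third loop is left untiled.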
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      options.tileSizeComputationFunction(b, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  return tileLinalgOpImpl(b, op, tileSizeVector, options);
}

constexpr llvm::StringLiteral kTransformMarker =
    "__internal_transformation_marker__";

void setTransformationAttr(mlir::OpBuilder &b, Operation *op) {
  op->setAttr(kTransformMarker, b.getBoolAttr(true));
}

void removeTransformationAttr(Operation *op) {
  op->removeAttr(kTransformMarker);
}

bool hasTransformationAttr(Operation *op) {
  auto marker = op->getAttr(kTransformMarker);
  if (!marker) return false;
  return marker.cast<BoolAttr>().getValue();
}
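
// A typical (hypothetical) use of the marker inside a rewrite pattern, to keep
// the driver from revisiting ops that were already transformed:
//
//   if (hasTransformationAttr(op)) return failure();
//   FailureOr<linalg::TiledLinalgOp> tiled = tileLinalgOp(rewriter, op, opts);
//   if (failed(tiled)) return failure();
//   setTransformationAttr(rewriter, tiled->op);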

}  // namespace gml_st
}  // namespace mlir