xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transforms.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
17 
18 #include <tuple>
19 #include <utility>
20 
21 #include "mlir/Dialect/Affine/IR/AffineOps.h"
22 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
23 #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
24 #include "mlir/Dialect/Tensor/Utils/Utils.h"
25 #include "mlir/Dialect/Utils/StaticValueUtils.h"
26 #include "mlir/IR/BlockAndValueMapping.h"
27 
28 namespace mlir {
29 namespace gml_st {
30 namespace {
31 
32 /// Rewrite a LoopOp with bounds/step that potentially do not divide evenly
33 /// into two LoopOps: One where the step divides the iteration space
34 /// evenly, followed another one for the last (partial) iteration (if any). This
35 /// function only rewrites the `idx`-th loop of the loop nest represented by
36 /// the LoopOp. To peel the entire loop nest, this function must be called
37 /// multiple times.
38 ///
39 /// This function rewrites the given LoopOp in-place and creates a new
40 /// LoopOp for the last iteration. It replaces all uses of the original
41 /// LoopOp with the results of the newly generated one.
42 ///
43 /// The newly generated LoopOp is returned via `result`. The boundary
44 /// at which the loop is split (new upper bound) is returned via `splitBound`.
45 /// The return value indicates whether the LoopOp was rewritten or not.
static LogicalResult peelLoop(RewriterBase &b, LoopOp loopOp, int64_t idx,
                              LoopOp &result, Value &splitBound) {
  Value lb = loopOp.lowerBound()[idx], ub = loopOp.upperBound()[idx],
        step = loopOp.step()[idx];
  // A constant upper bound (if available) lets us detect statically that the
  // step already divides the iteration space evenly.
  auto ubInt = getConstantIntValue(ub);

  auto loc = loopOp.getLoc();
  AffineExpr exprLb, exprUb, exprStep;
  bindSymbols(b.getContext(), exprLb, exprUb, exprStep);
  // New upper bound: %ub - (%ub - %lb) mod %step
  auto modMap = AffineMap::get(0, 3, {exprUb - ((exprUb - exprLb) % exprStep)});
  SmallVector<Value> operands{lb, ub, step};
  canonicalizeMapAndOperands(&modMap, &operands);
  modMap = simplifyAffineMap(modMap);
  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(loopOp);
  splitBound = b.createOrFold<AffineApplyOp>(loc, modMap, operands);
  // No specialization necessary if step already divides upper bound evenly.
  if (splitBound == ub || (ubInt && ubInt == getConstantIntValue(splitBound)))
    return failure();

  // Create remainder loop: clone the whole loop right after the original, and
  // redirect all uses of the original's results to the clone (the clone runs
  // last, so it produces the final values).
  b.setInsertionPointAfter(loopOp);
  auto remainderLoop = cast<LoopOp>(b.clone(*loopOp.getOperation()));
  loopOp.replaceAllUsesWith(remainderLoop->getResults());
  // Outputs: Take tensors from main loop's results. Take memrefs from main
  // loop's outputs.
  SmallVector<Value> remainderOutputs;
  for (unsigned o = 0, t = 0; o < loopOp.getNumOutputs(); ++o) {
    // Tensor outputs are threaded through the loop chain as SSA results;
    // memref outputs are updated in place and therefore stay the same value.
    // `t` counts only the tensor results.
    remainderOutputs.push_back(loopOp.outputs()[o].getType().isa<MemRefType>()
                                   ? loopOp.outputs()[o]
                                   : loopOp->getResult(t++));
  }
  remainderLoop.outputsMutable().assign(remainderOutputs);

  // Set new loop bounds: the main loop now ends at `splitBound` and the
  // remainder loop starts there.
  b.updateRootInPlace(loopOp, [&]() {
    SmallVector<Value> ubs = loopOp.upperBound();
    ubs[idx] = splitBound;
    loopOp.upperBoundMutable().assign(ubs);
  });
  SmallVector<Value> lbs = remainderLoop.lowerBound();
  lbs[idx] = splitBound;
  remainderLoop.lowerBoundMutable().assign(lbs);

  result = remainderLoop;
  return success();
}
94 
95 template <typename OpTy, bool IsMin>
rewriteAffineOpAfterPeeling(RewriterBase & rewriter,LoopOp mainLoop,LoopOp remainderLoop,Value mainIv,Value remainderIv,Value ub,Value step)96 static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, LoopOp mainLoop,
97                                         LoopOp remainderLoop, Value mainIv,
98                                         Value remainderIv, Value ub,
99                                         Value step) {
100   mainLoop.walk([&](OpTy affineOp) {
101     AffineMap map = affineOp.getAffineMap();
102     (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
103                                      affineOp.operands(), IsMin, mainIv, ub,
104                                      step, /*insideLoop=*/true);
105   });
106   remainderLoop.walk([&](OpTy affineOp) {
107     AffineMap map = affineOp.getAffineMap();
108     (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
109                                      affineOp.operands(), IsMin, remainderIv,
110                                      ub, step, /*insideLoop=*/false);
111   });
112 }
113 
isZero(Value v)114 bool isZero(Value v) {
115   if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
116     return cst.value() == 0;
117   return false;
118 }
119 using ::mlir::linalg::LinalgOp;
120 
generateLoopNest(OpBuilder & b,Location loc,ArrayRef<Range> loopRanges,LinalgOp linalgOp,ArrayRef<Attribute> iteratorTypes,function_ref<scf::ValueVector (OpBuilder &,Location,ValueRange,ValueRange)> bodyBuilderFn,ArrayRef<StringRef> distributionTypes)121 void generateLoopNest(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,
122                       LinalgOp linalgOp, ArrayRef<Attribute> iteratorTypes,
123                       function_ref<scf::ValueVector(OpBuilder &, Location,
124                                                     ValueRange, ValueRange)>
125                           bodyBuilderFn,
126                       ArrayRef<StringRef> distributionTypes) {
127   SmallVector<OpFoldResult, 4> lbs, ubs, steps;
128   for (Range range : loopRanges) {
129     lbs.emplace_back(range.offset);
130     ubs.emplace_back(range.size);
131     steps.emplace_back(range.stride);
132   }
133 
134   auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
135                               ValueRange ivs, ValueRange inputs,
136                               ValueRange outputs) {
137     SmallVector<Value> operandValuesToUse = inputs;
138     operandValuesToUse.append(outputs.begin(), outputs.end());
139     scf::ValueVector results =
140         bodyBuilderFn(nestedBuilder, nestedLoc, ivs, operandValuesToUse);
141     nestedBuilder.create<gml_st::YieldOp>(nestedLoc, results);
142   };
143 
144   SmallVector<Value> inputOperands = linalgOp.getInputOperands();
145   SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
146 
147   SmallVector<Value> lbsValue =
148       mlir::getValueOrCreateConstantIndexOp(b, loc, lbs);
149   SmallVector<Value> ubsValue =
150       mlir::getValueOrCreateConstantIndexOp(b, loc, ubs);
151   SmallVector<Value> stepsValue =
152       mlir::getValueOrCreateConstantIndexOp(b, loc, steps);
153   auto tiledLoop = b.create<LoopOp>(
154       loc, lbsValue, ubsValue, stepsValue, inputOperands, outputOperands,
155       b.getArrayAttr(iteratorTypes), wrappedBuilderFn);
156   if (!distributionTypes.empty())
157     tiledLoop.setDistributionTypes(b, distributionTypes);
158 }
159 
// Insert a tile `source` into the destination tensor `dest`. The position at
// which the tile is inserted (as well as size of tile) is taken from a given
// ExtractSliceOp `sliceOp`.
Value insertSliceIntoTensor(RewriterBase &b, Location loc,
                            tensor::ExtractSliceOp sliceOp, Value source,
                            Value dest) {
  // Mirror the extract exactly: reuse the slice's dynamic and static
  // offsets/sizes/strides. The result type is the slice's source tensor type,
  // which `dest` is expected to match (callers pass sliceOp.getSource()).
  return b.create<tensor::InsertSliceOp>(
      loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
      sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
      sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
}
171 
// Tiles `op` with the given per-loop `tileSizes` using gml_st.loop: builds the
// tiled loop nest, materializes tiled operand slices inside the body, clones
// the linalg op on the slices, and stitches tensor results back together with
// insert_slice. Returns failure() if the op's shapes-to-loops map is missing.
FailureOr<linalg::TiledLinalgOp> tileLinalgOpImpl(
    RewriterBase &b, LinalgOp op, ValueRange tileSizes,
    const linalg::LinalgTilingOptions &options) {
  auto nLoops = op.getNumLoops();
  // Initial tile sizes may be too big, only take the first nLoops.
  tileSizes = tileSizes.take_front(nLoops);

  // Tile size zero in every dimension means "do not tile": just clone the op
  // and report its results directly.
  if (llvm::all_of(tileSizes, isZero)) {
    linalg::TiledLinalgOp tiledOp;
    tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
    tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
                                 tiledOp.op->result_end());
    return tiledOp;
  }

  SmallVector<OpFoldResult> tileSizesFold;
  for (Value tileSize : tileSizes) tileSizesFold.push_back(tileSize);

  // 1. Build the tiled loop ranges.
  auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc());
  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
  if (!shapeSizesToLoopsMap) return failure();

  SmallVector<Range, 4> loopRanges;
  mlir::linalg::LoopIndexToRangeIndexMap loopIndexToRangeIndex;
  std::tie(loopRanges, loopIndexToRangeIndex) =
      mlir::linalg::makeTiledLoopRanges(b, op.getLoc(), shapeSizesToLoopsMap,
                                        allShapeSizes, tileSizesFold);

  // Keep only the iterator types of dimensions that are actually tiled
  // (zero-tile-size dimensions were dropped by makeTiledLoopRanges).
  SmallVector<Attribute, 4> iteratorTypes;
  for (const auto &attr :
       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
    if (loopIndexToRangeIndex.count(attr.index()))
      iteratorTypes.push_back(attr.value());
  }

  // 2. Create the tiled loops.
  // `res`, `ivs` and `tensorResults` are captured by reference and filled in
  // by the body builder when generateLoopNest materializes the loop.
  LinalgOp res = op;
  SmallVector<Value, 4> ivs, tensorResults;
  auto tiledLoopBodyBuilder =
      [&](OpBuilder & /*builder*/, Location loc, ValueRange localIvs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
    ivs.assign(localIvs.begin(), localIvs.end());

    // Tile the `operandValuesToUse` that either match the `op` operands
    // themselves or the tile loop arguments forwarding them.
    assert(operandValuesToUse.size() ==
               static_cast<size_t>(op.getNumInputsAndOutputs()) &&
           "expect the number of operands and inputs and outputs to match");
    SmallVector<Value> valuesToTile = operandValuesToUse;
    auto sizeBounds = makeComposedFoldedMultiResultAffineApply(
        b, loc, shapeSizesToLoopsMap, allShapeSizes);
    SmallVector<OpFoldResult> ivsFold(ivs.begin(), ivs.end());
    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        b, loc, op, valuesToTile, ivsFold, tileSizesFold, sizeBounds,
        /*omitPartialTileCheck=*/false);

    // The tiled op's result types are the types of its tiled output operands.
    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand *opOperand : op.getOutputTensorOperands())
      resultTensorTypes.push_back(
          tiledOperands[opOperand->getOperandNumber()].getType());

    res = op.clone(b, loc, resultTensorTypes, tiledOperands);

    // Insert a insert_slice for each output tensor.
    unsigned resultIdx = 0;
    for (OpOperand *opOperand : op.getOutputTensorOperands()) {
      Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
      IRRewriter rewriter(b);
      if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
        // The output was sliced out of a bigger tensor: write the tiled
        // result back into the full tensor at the same position.
        tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp,
                                                      res->getResult(resultIdx),
                                                      sliceOp.getSource()));
      } else {
        tensorResults.push_back(res->getResult(resultIdx));
      }
      ++resultIdx;
    }
    return scf::ValueVector(tensorResults.begin(), tensorResults.end());
  };
  generateLoopNest(b, op.getLoc(), loopRanges, op, iteratorTypes,
                   tiledLoopBodyBuilder, options.distributionTypes);

  // 3. Transform IndexOp results w.r.t. the tiling.
  mlir::linalg::transformIndexOps(b, res, ivs, loopIndexToRangeIndex);

  // 4. Gather the newly created loops and return them with the new op.
  SmallVector<Operation *, 8> loops;
  loops.reserve(ivs.size());
  for (auto iv : ivs) {
    if (iv.isa<BlockArgument>()) {
      // The induction variable is a block argument of the generated loop.
      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
      assert(loops.back() && "no owner found for induction variable!");
    } else {
      // The induction variable was folded away; no loop op to report.
      loops.push_back(nullptr);
    }
  }

  // 5. Get the tensor results from the outermost loop if available. Otherwise
  // use the previously captured `tensorResults`.
  Operation *outermostLoop = nullptr;
  for (Operation *loop : loops)
    if ((outermostLoop = loop)) break;

  return linalg::TiledLinalgOp{
      res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
}
279 
280 }  // namespace
281 
peelAndCanonicalizeGmlStLoop(RewriterBase & rewriter,LoopOp loopOp,int64_t idx,LoopOp & result)282 LogicalResult peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter,
283                                            LoopOp loopOp, int64_t idx,
284                                            LoopOp &result) {
285   int64_t numLoops = loopOp.iterator_types().size();
286   if (idx < 0 || numLoops <= idx) return failure();
287 
288   Value ub = loopOp.upperBound()[idx];
289   LoopOp remainderLoop;
290   Value splitBound;
291   if (failed(peelLoop(rewriter, loopOp, idx, remainderLoop, splitBound)))
292     return failure();
293 
294   // Rewrite affine.min and affine.max ops.
295   Value mainIv = loopOp.getInductionVars()[idx], step = loopOp.step()[idx],
296         remainderIv = remainderLoop.getInductionVars()[idx];
297 
298   rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
299       rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);
300   rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
301       rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step);
302 
303   result = remainderLoop;
304   return success();
305 }
306 
tileLinalgOp(RewriterBase & b,linalg::LinalgOp op,const linalg::LinalgTilingOptions & options)307 FailureOr<linalg::TiledLinalgOp> tileLinalgOp(
308     RewriterBase &b, linalg::LinalgOp op,
309     const linalg::LinalgTilingOptions &options) {
310   OpBuilder::InsertionGuard g(b);
311   b.setInsertionPoint(op);
312 
313   if (!options.tileSizeComputationFunction) return failure();
314 
315   // Enforce the convention that "tiling by zero" skips tiling a particular
316   // dimension. This convention is significantly simpler to handle instead of
317   // adjusting affine maps to account for missing dimensions.
318   auto nLoops = op.getNumLoops();
319   SmallVector<Value, 4> tileSizeVector =
320       options.tileSizeComputationFunction(b, op);
321   if (tileSizeVector.size() < nLoops) {
322     auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
323     tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
324   }
325 
326   return tileLinalgOpImpl(b, op, tileSizeVector, options);
327 }
328 
// Attribute name used to mark ops that have already been transformed so that
// transformation drivers do not revisit them.
constexpr llvm::StringLiteral kTransformMarker =
    "__internal_transformation_marker__";
331 
// Marks `op` as transformed by attaching the marker attribute with value true.
void setTransformationAttr(mlir::OpBuilder &b, Operation *op) {
  op->setAttr(kTransformMarker, b.getBoolAttr(true));
}
335 
// Removes the transformation marker from `op`. No-op if the marker is absent.
void removeTransformationAttr(Operation *op) {
  op->removeAttr(kTransformMarker);
}
339 
hasTransformationAttr(Operation * op)340 bool hasTransformationAttr(Operation *op) {
341   auto marker = op->getAttr(kTransformMarker);
342   if (!marker) return false;
343   return marker && marker.cast<BoolAttr>().getValue();
344 }
345 
346 }  // namespace gml_st
347 }  // namespace mlir
348