/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering mhlo.sort ops to the SCF dialect.
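//
// As an illustrative sketch (shapes, value names, and attribute syntax are
// abbreviated for the example), a 1-D ascending sort such as
//
//   %sorted = "mhlo.sort"(%input) ({
//   ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
//     %lt = "mhlo.compare"(%lhs, %rhs) {comparison_direction = ...}
//         : (tensor<f32>, tensor<f32>) -> tensor<i1>
//     "mhlo.return"(%lt) : (tensor<i1>) -> ()
//   }) {dimension = 0 : i64} : (tensor<16xf32>) -> tensor<16xf32>
//
// becomes a nest of scf.for loops that bubble-sort the tensor with
// tensor.extract / tensor.insert.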
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace mhlo {
namespace {

struct SortOpPattern : public OpConversionPattern<mhlo::SortOp> {
  using OpConversionPattern<SortOp>::OpConversionPattern;

  // Create a loop for each dimension of the input. Finally, create the inner
  // sorting loop and the inner scalar code. Track the induction variables to
  // be used by the scalar loop and return the result of the outermost loop
  // created by this (potentially recursive) call.
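  //
  // For a rank-2 operand, the emitted structure is roughly (an illustrative
  // sketch, not the exact IR):
  //
  //   scf.for %i = %c0 to %dim0 step %c1 iter_args(%a0 = %operand0, ...) {
  //     scf.for %j = %c0 to %dim1 step %c1 iter_args(%b0 = %a0, ...) {
  //       ... scalar bubble-sort sweep (see generateScalarImplementation) ...
  //     }
  //   }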
  static scf::ForOp lowerToLoopsImpl(OpBuilder& builder, mhlo::SortOp op,
                                     OpAdaptor adaptor, unsigned loopDepth,
                                     SmallVectorImpl<Value>& ivs,
                                     ValueRange args) {
    Location loc = op.getLoc();
    if (loopDepth ==
        op->getResultTypes().front().cast<TensorType>().getRank()) {
      return generateScalarImplementation(op, adaptor, builder, ivs, args);
    }

    auto lower = builder.create<arith::ConstantIndexOp>(loc, 0);
    auto upper = builder.create<tensor::DimOp>(
        loc, adaptor.operands().front(),
        builder.create<arith::ConstantIndexOp>(loc, loopDepth));
    auto step = builder.create<arith::ConstantIndexOp>(loc, 1);

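    // The outermost loop is seeded with the adapted operands as iter args;
    // every nested loop threads through the region arguments of its parent.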
    auto iterArgs = loopDepth ? args : adaptor.operands();
    return builder.create<scf::ForOp>(
        loc, lower, upper, step, iterArgs,
        [&](OpBuilder& b, Location loc, Value iv, ValueRange argsPrime) {
          ivs.push_back(iv);
          auto result =
              lowerToLoopsImpl(b, op, adaptor, loopDepth + 1, ivs, argsPrime);
          b.create<scf::YieldOp>(loc, result.getResults());
        });
  }

  static scf::ForOp generateScalarImplementation(mhlo::SortOp op,
                                                 OpAdaptor adaptor,
                                                 OpBuilder& b, ValueRange ivs,
                                                 ValueRange args) {
    auto loc = op.getLoc();
    auto sortDim = adaptor.dimension();
    SmallVector<Value> indices, sortArgs;
    indices.append(ivs.begin(), ivs.end());
    // Bubble sort innermost loop.
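    // Note that lowerToLoopsImpl also emitted a loop over the sort dimension
    // itself; its induction variable is overwritten below, so that enclosing
    // loop effectively acts as the pass counter, yielding the N passes of
    // N - 1 adjacent comparisons that bubble sort requires.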
    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
    Value ub;

    auto firstOperandType =
        adaptor.getOperands().front().getType().cast<TensorType>();
    // Create the innermost loop with one iteration fewer, so that iv + 1
    // below stays in bounds.
    if (firstOperandType.isDynamicDim(sortDim)) {
      ub = b.create<tensor::DimOp>(loc, adaptor.getOperands().front(), sortDim);
    } else {
      ub = b.create<arith::ConstantIndexOp>(
          loc, firstOperandType.getDimSize(sortDim));
    }
    ub = b.create<arith::SubIOp>(loc, ub, one);
    auto& srcBlock = op.comparator().front();
    auto scfFor = b.create<scf::ForOp>(
        loc, zero, ub, one, args,
        [&](OpBuilder& b, Location loc, Value iv, ValueRange args) {
          // Extract the elements at iv and iv + 1 along the sort dimension
          // and wrap them in rank-0 tensors, matching the argument types
          // expected by the comparator region of the mhlo.sort op.
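          // After this loop, sortArgs holds those rank-0 tensors in the
          // comparator's argument order: [lhs0, rhs0, lhs1, rhs1, ...].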
          SmallVector<Value> indices(ivs);
          Value ivPlusOne = b.create<arith::AddIOp>(loc, iv, one);
          for (const auto& idxAndOutput : llvm::enumerate(args)) {
            indices[sortDim] = iv;
            sortArgs.push_back(b.create<tensor::FromElementsOp>(
                loc, srcBlock.getArgumentTypes()[2 * idxAndOutput.index()],
                b.create<tensor::ExtractOp>(loc, idxAndOutput.value(), indices)
                    .getResult()));
            indices[sortDim] = ivPlusOne;
            sortArgs.push_back(b.create<tensor::FromElementsOp>(
                loc, srcBlock.getArgumentTypes()[2 * idxAndOutput.index() + 1],
                b.create<tensor::ExtractOp>(loc, idxAndOutput.value(), indices)
                    .getResult()));
          }
        });

    // Clone the comparator region twice, to compare (A,B) and (B,A).
    Region& region = scfFor.getRegion();
    BlockAndValueMapping bvm, bvm2;
    {
      OpBuilder::InsertionGuard guard(b);
      auto& block = region.front();
      b.setInsertionPointToEnd(&block);
      for (int64_t i = 0; i < srcBlock.getNumArguments(); i += 2) {
        bvm.map(srcBlock.getArgument(i), sortArgs[i]);
        bvm.map(srcBlock.getArgument(i + 1), sortArgs[i + 1]);

        bvm2.map(srcBlock.getArgument(i), sortArgs[i + 1]);
        bvm2.map(srcBlock.getArgument(i + 1), sortArgs[i]);
      }
      for (auto& blockOp : srcBlock.without_terminator()) {
        b.clone(blockOp, bvm2);
      }
      for (auto& blockOp : srcBlock.without_terminator()) {
        b.clone(blockOp, bvm);
      }
    }

    // Determine whether a swap should occur, which happens only when
    // !CMP(A,B) && CMP(B,A).
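    // For a less-than comparator this means:
    //   CMP(A,B) | CMP(B,A) | action
    //   ---------+----------+------------------------------
    //   true     |   any    | keep (A already precedes B)
    //   false    |   true   | swap (B strictly precedes A)
    //   false    |   false  | keep (A and B compare equal)
    // Elements that compare equal are never swapped.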
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToEnd(&region.front());
    Value cond = b.create<tensor::ExtractOp>(
        loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)));
    Value cond2 = b.create<tensor::ExtractOp>(
        loc, bvm2.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)));
    Value negCond = b.create<arith::XOrIOp>(
        loc, cond, b.create<arith::ConstantIntOp>(loc, 1, cond.getType()));
    Value combined = b.create<arith::AndIOp>(loc, negCond, cond2);

    auto swapResult = b.create<scf::IfOp>(
        loc, op->getResultTypes(), combined,
        [&](OpBuilder& b, Location loc) {
          SmallVector<Value> indices(ivs.begin(), ivs.end());
          Value ivPlusOne =
              b.create<arith::AddIOp>(loc, scfFor.getInductionVar(), one);
          SmallVector<Value> swappedResults;
          for (const auto& idxAndOutput :
               llvm::enumerate(scfFor.getRegionIterArgs())) {
            Value v1 = sortArgs[idxAndOutput.index() * 2];
            Value v2 = sortArgs[idxAndOutput.index() * 2 + 1];
            indices[sortDim] = scfFor.getInductionVar();
            Value afterFirstInsert = b.create<tensor::InsertOp>(
                loc, b.create<tensor::ExtractOp>(loc, v2), idxAndOutput.value(),
                indices);
            indices[sortDim] = ivPlusOne;
            swappedResults.push_back(b.create<tensor::InsertOp>(
                loc, b.create<tensor::ExtractOp>(loc, v1), afterFirstInsert,
                indices));
          }
          b.create<scf::YieldOp>(loc, swappedResults);
        },
        [&](OpBuilder& b, Location loc) {
          b.create<scf::YieldOp>(loc, scfFor.getRegionIterArgs());
        });
    b.create<scf::YieldOp>(loc, swapResult.getResults());
    return scfFor;
  }

  LogicalResult matchAndRewrite(
      mhlo::SortOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    SmallVector<Value> ivs;
    auto scfFor = lowerToLoopsImpl(rewriter, op, adaptor, 0, ivs, {});
    rewriter.replaceOp(op, scfFor.getResults());
    return success();
  }
};

struct LegalizeSortPass : public HloLegalizeSortPassBase<LegalizeSortPass> {
  // Perform the lowering to MLIR control flow.
  void runOnOperation() override {
    func::FuncOp f = getOperation();
    MLIRContext* ctx = f.getContext();

    RewritePatternSet patterns(ctx);
    patterns.add<SortOpPattern>(ctx);

    mlir::ConversionTarget target(*ctx);
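    // Only mhlo.sort is marked illegal; every other op is legal, so the
    // partial conversion rewrites sort ops and leaves the rest untouched.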
    target.markUnknownOpDynamicallyLegal([](Operation*) { return true; });
    target.addIllegalOp<mhlo::SortOp>();

    if (failed(applyPartialConversion(f, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace
}  // namespace mhlo
}  // namespace mlir

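// A minimal usage sketch (assuming a standard mlir::PassManager setup over a
// module containing functions with mhlo.sort ops):
//
//   mlir::PassManager pm(ctx);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::mhlo::createLegalizeSortPass());
//   if (failed(pm.run(module))) { /* handle the error */ }
//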
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
mlir::mhlo::createLegalizeSortPass() {
  return std::make_unique<LegalizeSortPass>();
}