1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering MHLO dialect to SCF dialect.
17 #include <utility>
18
19 #include "llvm/ADT/STLExtras.h"
20 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
21 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
23 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
27 #include "mlir/IR/Block.h"
28 #include "mlir/IR/BlockAndValueMapping.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeRange.h"
33 #include "mlir/Pass/Pass.h"
34 #include "mlir/Support/LLVM.h"
35 #include "mlir/Support/LogicalResult.h"
36 #include "mlir/Transforms/DialectConversion.h"
37
38 namespace mlir {
39 namespace mhlo {
40 namespace {
41
42 struct SortOpPattern : public OpConversionPattern<mhlo::SortOp> {
43 using OpConversionPattern<SortOp>::OpConversionPattern;
44
45 // Create a loop for each dimension of the input. Finally, create the inner
46 // sorting loop and the inner scalar code. Track the indcution variables to be
47 // used by the scalar loop and return the result of the outermost loop being
48 // created by this (potentially recursive) call.
lowerToLoopsImplmlir::mhlo::__anon8a63941e0111::SortOpPattern49 static scf::ForOp lowerToLoopsImpl(OpBuilder& builder, mhlo::SortOp op,
50 OpAdaptor adaptor, unsigned loopDepth,
51 SmallVectorImpl<Value>& ivs,
52 ValueRange args) {
53 Location loc = op.getLoc();
54 if (loopDepth ==
55 op->getResultTypes().front().cast<TensorType>().getRank()) {
56 return generateScalarImplementation(op, adaptor, builder, ivs, args);
57 }
58
59 auto lower = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
60 auto upper = builder.create<tensor::DimOp>(
61 op.getLoc(), adaptor.operands().front(),
62 builder.create<arith::ConstantIndexOp>(op.getLoc(), loopDepth));
63 auto step = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
64
65 auto iterArgs = loopDepth ? args : adaptor.operands();
66 return builder.create<scf::ForOp>(
67 loc, lower, upper, step, iterArgs,
68 [&](OpBuilder& b, Location loc, Value iv, ValueRange argsPrime) {
69 ivs.push_back(iv);
70 auto result =
71 lowerToLoopsImpl(b, op, adaptor, loopDepth + 1, ivs, argsPrime);
72 b.create<scf::YieldOp>(loc, result.getResults());
73 });
74 }
75
generateScalarImplementationmlir::mhlo::__anon8a63941e0111::SortOpPattern76 static scf::ForOp generateScalarImplementation(mhlo::SortOp op,
77 OpAdaptor adaptor,
78 OpBuilder& b, ValueRange ivs,
79 ValueRange args) {
80 auto loc = op.getLoc();
81 auto sortDim = adaptor.dimension();
82 SmallVector<Value> indices, sortArgs;
83 indices.append(ivs.begin(), ivs.end());
84 // Bubble sort innermost loop.
85 Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
86 Value one = b.create<arith::ConstantIndexOp>(loc, 1);
87 Value ub;
88
89 auto firstOperandType =
90 adaptor.getOperands().front().getType().cast<TensorType>();
91 SmallVector<Value> results(args);
92 // Create inner most loop with one less iterations, so 1 can be added later.
93 if (firstOperandType.isDynamicDim(sortDim)) {
94 ub = b.create<tensor::DimOp>(loc, adaptor.getOperands().front(), sortDim);
95 } else {
96 ub = b.create<arith::ConstantIndexOp>(
97 loc, firstOperandType.getDimSize(sortDim));
98 }
99 ub = b.create<arith::SubIOp>(loc, ub, one);
100 auto& srcBlock = op.comparator().front();
101 auto scfFor = b.create<scf::ForOp>(
102 loc, zero, ub, one, args,
103 [&](OpBuilder& b, Location loc, Value iv, ValueRange args) {
104 // Extract and create tensors with relevant values to merge with the
105 // expected inputs to the original compare region of the mhlo.sort op.
106 SmallVector<Value> indices(ivs);
107 Value ivPlusOne = b.create<arith::AddIOp>(loc, iv, one);
108 for (const auto& idxAndOutput : llvm::enumerate(args)) {
109 indices[sortDim] = iv;
110 sortArgs.push_back(b.create<tensor::FromElementsOp>(
111 loc, srcBlock.getArgumentTypes()[2 * idxAndOutput.index()],
112 b.create<tensor::ExtractOp>(loc, idxAndOutput.value(), indices)
113 .getResult()));
114 indices[sortDim] = ivPlusOne;
115 sortArgs.push_back(b.create<tensor::FromElementsOp>(
116 loc, srcBlock.getArgumentTypes()[2 * idxAndOutput.index() + 1],
117 b.create<tensor::ExtractOp>(loc, idxAndOutput.value(), indices)
118 .getResult()));
119 }
120 });
121
122 // Clone the region twice. to compare A,B and B,A
123 Region& region = scfFor.getRegion();
124 BlockAndValueMapping bvm, bvm2;
125 {
126 OpBuilder::InsertionGuard guard(b);
127 auto& block = region.front();
128 b.setInsertionPointToEnd(&block);
129 for (int64_t i = 0; i < srcBlock.getNumArguments(); i += 2) {
130 bvm.map(srcBlock.getArgument(i), sortArgs[i]);
131 bvm.map(srcBlock.getArgument(i + 1), sortArgs[i + 1]);
132
133 bvm2.map(srcBlock.getArgument(i), sortArgs[i + 1]);
134 bvm2.map(srcBlock.getArgument(i + 1), sortArgs[i]);
135 }
136 for (auto& blockOp : srcBlock.without_terminator()) {
137 b.clone(blockOp, bvm2);
138 }
139 for (auto& blockOp : srcBlock.without_terminator()) {
140 b.clone(blockOp, bvm);
141 }
142 }
143
144 // Determine if swapping should occur which happens only if NOT(CMP(A,B)) &&
145 // CMP(B,A).
146 OpBuilder::InsertionGuard g(b);
147 b.setInsertionPointToEnd(®ion.front());
148 Value cond = b.create<tensor::ExtractOp>(
149 loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)));
150 Value cond2 = b.create<tensor::ExtractOp>(
151 loc, bvm2.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)));
152 Value negCond = b.create<arith::XOrIOp>(
153 loc, cond, b.create<arith::ConstantIntOp>(loc, 1, cond.getType()));
154 Value combined = b.create<arith::AndIOp>(loc, negCond, cond2);
155
156 auto swapResult = b.create<scf::IfOp>(
157 loc, op->getResultTypes(), combined,
158 [&](OpBuilder& b, Location loc) {
159 SmallVector<Value> indices(ivs.begin(), ivs.end());
160 Value ivPlusOne =
161 b.create<arith::AddIOp>(loc, scfFor.getInductionVar(), one);
162 SmallVector<Value> swappedResults;
163 for (const auto& idxAndOutput :
164 llvm::enumerate(scfFor.getRegionIterArgs())) {
165 Value v1 = sortArgs[idxAndOutput.index() * 2];
166 Value v2 = sortArgs[idxAndOutput.index() * 2 + 1];
167 indices[sortDim] = scfFor.getInductionVar();
168 Value afterFirstInsert = b.create<tensor::InsertOp>(
169 loc, b.create<tensor::ExtractOp>(loc, v2), idxAndOutput.value(),
170 indices);
171 indices[sortDim] = ivPlusOne;
172 swappedResults.push_back(b.create<tensor::InsertOp>(
173 loc, b.create<tensor::ExtractOp>(loc, v1), afterFirstInsert,
174 indices));
175 }
176 b.create<scf::YieldOp>(loc, swappedResults);
177 },
178 [&](OpBuilder& b, Location loc) {
179 b.create<scf::YieldOp>(loc, scfFor.getRegionIterArgs());
180 });
181 b.create<scf::YieldOp>(loc, swapResult.getResults());
182 return scfFor;
183 }
184
matchAndRewritemlir::mhlo::__anon8a63941e0111::SortOpPattern185 LogicalResult matchAndRewrite(
186 mhlo::SortOp op, OpAdaptor adaptor,
187 ConversionPatternRewriter& rewriter) const override {
188 SmallVector<Value> ivs;
189 auto scfFor = lowerToLoopsImpl(rewriter, op, adaptor, 0, ivs, {});
190 rewriter.replaceOp(op, scfFor.getResults());
191 return success();
192 }
193 };
194
195 struct LegalizeSortPass : public HloLegalizeSortPassBase<LegalizeSortPass> {
196 // Perform the lowering to MLIR control flow.
runOnOperationmlir::mhlo::__anon8a63941e0111::LegalizeSortPass197 void runOnOperation() override {
198 func::FuncOp f = getOperation();
199 MLIRContext* ctx = f.getContext();
200
201 RewritePatternSet patterns(&getContext());
202 patterns.add<SortOpPattern>(&getContext());
203
204 mlir::ConversionTarget target(*ctx);
205 target.markUnknownOpDynamicallyLegal([](Operation*) { return true; });
206 target.addIllegalOp<mhlo::SortOp>();
207
208 if (failed(applyPartialConversion(f, target, std::move(patterns)))) {
209 signalPassFailure();
210 }
211 }
212 };
213
214 } // namespace
215 } // namespace mhlo
216 } // namespace mlir
217
// Factory for the mhlo.sort -> SCF lowering pass; runs on func.func ops.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
mlir::mhlo::createLegalizeSortPass() {
  return std::make_unique<LegalizeSortPass>();
}
222