/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cctype>
#include <utility>

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

32 struct EinsumToDotGeneralPattern : public OpRewritePattern<EinsumOp> {
33 using OpRewritePattern<EinsumOp>::OpRewritePattern;
34
matchAndRewritemlir::mhlo::__anon91ffa24c0111::EinsumToDotGeneralPattern35 LogicalResult matchAndRewrite(EinsumOp einsum,
36 PatternRewriter &rewriter) const override {
37 StringRef equation = einsum.einsum_config();
38 SmallVector<char> lhsTokens, rhsTokens;
39 SmallVector<char> resultTokens;
40 size_t index = 0;
41 enum EquationVariable { kIsLhs, kIsRhs, kIsResult };
42 EquationVariable currentVariable = kIsLhs;
43 while (index < equation.size()) {
44 if (std::isalpha(equation[index])) {
45 if (currentVariable == kIsLhs) {
46 lhsTokens.push_back(equation[index]);
47 } else if (currentVariable == kIsRhs) {
48 rhsTokens.push_back(equation[index]);
49 } else {
50 resultTokens.push_back(equation[index]);
51 }
52 } else if (equation.substr(index, 1).contains(",")) {
53 currentVariable = kIsRhs;
54 } else if ((index < (equation.size() - 1)) &&
55 (equation.substr(index, 2).contains("->"))) {
56 currentVariable = kIsResult;
57 index++;
58 } else {
59 return einsum.emitError("unexpected character ")
60 << equation.substr(index, 1) << " encountered";
61 }
62 index++;
63 }
64
65 auto lhsType = einsum.lhs().getType().cast<RankedTensorType>();
66 auto rhsType = einsum.rhs().getType().cast<RankedTensorType>();
67 assert(static_cast<int64_t>(lhsTokens.size()) == lhsType.getRank());
68 assert(static_cast<int64_t>(rhsTokens.size()) == rhsType.getRank());
69
70 auto collectOperandDims =
71 [resultTokens](
72 RankedTensorType operandType, SmallVector<char> operandTokens,
73 SmallVector<char> others, SmallVectorImpl<int64_t> &contractingDims,
74 SmallVectorImpl<int64_t> &batchingDims,
75 SmallVector<char> &dotResultTokens,
76 SmallVector<int64_t> &dotResultShape) {
77 llvm::SmallDenseSet<char> othersSet(others.begin(), others.end());
78 llvm::SmallDenseSet<char> resultTokensSet(resultTokens.begin(),
79 resultTokens.end());
80 for (const auto &en : llvm::enumerate(operandTokens)) {
81 bool isResultToken = resultTokensSet.contains(en.value());
82 bool isOtherToken = othersSet.contains(en.value());
83
84 if (!isResultToken) {
85 contractingDims.push_back(en.index());
86 } else if (isOtherToken) {
87 batchingDims.push_back(en.index());
88 } else {
89 dotResultTokens.push_back(en.value());
90 dotResultShape.push_back(operandType.getShape()[en.index()]);
91 }
92 }
93 };
94 // Indices of batch and contracting dims, relative to each operand's
95 // dimensions.
96 SmallVector<int64_t> lhsContractingDims, lhsBatchingDims,
97 rhsContractingDims, rhsBatchingDims;
98 // Tokens representing the natural order of the dot_general op (i.e.
99 // the lhs non-contracting followed by rhs non-contracting tokens).
100 SmallVector<char> dotResultTokens;
101 SmallVector<int64_t> dotResultShape;
102
103 collectOperandDims(lhsType, lhsTokens, rhsTokens, lhsContractingDims,
104 lhsBatchingDims, dotResultTokens, dotResultShape);
105 collectOperandDims(rhsType, rhsTokens, lhsTokens, rhsContractingDims,
106 rhsBatchingDims, dotResultTokens, dotResultShape);
107
108 // Prepend batch tokens.
109 for (const auto &it : llvm::enumerate(lhsBatchingDims)) {
110 char batchingToken = lhsTokens[it.value()];
111 int64_t batchingShapeDim = lhsType.getShape()[it.value()];
112 dotResultTokens.insert(dotResultTokens.begin() + it.index(),
113 batchingToken);
114 dotResultShape.insert(dotResultShape.begin() + it.index(),
115 batchingShapeDim);
116 }
117
118 // Lowering to dot_general does not support a mismatch between the number
119 // of result dims and the number of non-contracting dims.
120 if (dotResultTokens.size() != resultTokens.size()) {
121 return rewriter.notifyMatchFailure(einsum,
122 "rank reducing einsum not supported");
123 }
124
125 // Generate a permutation sequence based on result tokens.
126 SmallVector<int64_t> resultPerms;
127 bool isNaturalOrder = true;
128 for (char resultToken : resultTokens) {
129 auto *foundIt = std::find(dotResultTokens.begin(), dotResultTokens.end(),
130 resultToken);
131 if (foundIt == dotResultTokens.end()) {
132 return rewriter.notifyMatchFailure(
133 einsum, "result token not found in operands");
134 }
135 auto resultIndex = std::distance(dotResultTokens.begin(), foundIt);
136 if (resultPerms.empty()) {
137 if (resultIndex != 0) {
138 isNaturalOrder = false;
139 }
140 } else if (resultIndex != (resultPerms.back() + 1)) {
141 isNaturalOrder = false;
142 }
143 resultPerms.push_back(resultIndex);
144 }
145
146 // Emit the dot_general, using its native result ordering.
147 auto dotGeneralResultType = RankedTensorType::get(
148 ArrayRef<int64_t>(dotResultShape), lhsType.getElementType());
149 auto dimNumbers = mhlo::DotDimensionNumbersAttr::get(
150 rewriter.getContext(), lhsBatchingDims, rhsBatchingDims,
151 lhsContractingDims, rhsContractingDims);
152 auto dotGeneralOp =
153 rewriter.create<DotGeneralOp>(einsum.getLoc(), dotGeneralResultType,
154 einsum.lhs(), einsum.rhs(), dimNumbers,
155 /*precision_config=*/ArrayAttr{});
156
157 if (isNaturalOrder) {
158 // The dot_general is already in an appropriate result order.
159 rewriter.replaceOp(einsum, ValueRange{dotGeneralOp});
160 } else {
161 // Generate a transpose.
162 rewriter.replaceOpWithNewOp<TransposeOp>(
163 einsum, dotGeneralOp, rewriter.getI64TensorAttr(resultPerms));
164 }
165 return success();
166 }
167 };
169 struct LegalizeEinsumToDotGeneralPass
170 : public LegalizeEinsumToDotGeneralPassBase<
171 LegalizeEinsumToDotGeneralPass> {
runOnOperationmlir::mhlo::__anon91ffa24c0111::LegalizeEinsumToDotGeneralPass172 void runOnOperation() override {
173 RewritePatternSet patterns(&getContext());
174 populateEinsumToDotGeneralPatterns(&getContext(), &patterns);
175 if (failed(applyPatternsAndFoldGreedily(getOperation(),
176 std::move(patterns)))) {
177 return signalPassFailure();
178 }
179 }
180 };
}  // namespace

// Registers the einsum -> dot_general lowering pattern into `patterns`.
void populateEinsumToDotGeneralPatterns(mlir::MLIRContext *context,
                                        RewritePatternSet *patterns) {
  patterns->add<EinsumToDotGeneralPattern>(context);
}
// Creates a FuncOp-level pass that lowers mhlo.einsum ops to
// mhlo.dot_general (optionally followed by an mhlo.transpose).
std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeEinsumToDotGeneralPass() {
  return std::make_unique<LegalizeEinsumToDotGeneralPass>();
}

}  // namespace mhlo
}  // namespace mlir
