/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

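// Lowers mhlo.einsum to mhlo.dot_general, adding an mhlo.transpose when the
// requested einsum result ordering differs from dot_general's natural result
// ordering (batch dims first, then the remaining lhs dims, then the remaining
// rhs dims). For example, "ab,bc->ac" maps to a dot_general that contracts
// lhs dim 1 with rhs dim 0 and needs no transpose.
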
#include <algorithm>
#include <cctype>
#include <utility>

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace mhlo {
namespace {

struct EinsumToDotGeneralPattern : public OpRewritePattern<EinsumOp> {
  using OpRewritePattern<EinsumOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EinsumOp einsum,
                                PatternRewriter &rewriter) const override {
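    // Parse the einsum equation (e.g. "ab,bc->ac") into lhs, rhs, and result
    // token lists, one token per dimension.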
    StringRef equation = einsum.einsum_config();
    SmallVector<char> lhsTokens, rhsTokens;
    SmallVector<char> resultTokens;
    size_t index = 0;
    enum EquationVariable { kIsLhs, kIsRhs, kIsResult };
    EquationVariable currentVariable = kIsLhs;
    while (index < equation.size()) {
      if (std::isalpha(equation[index])) {
        if (currentVariable == kIsLhs) {
          lhsTokens.push_back(equation[index]);
        } else if (currentVariable == kIsRhs) {
          rhsTokens.push_back(equation[index]);
        } else {
          resultTokens.push_back(equation[index]);
        }
      } else if (equation.substr(index, 1).contains(",")) {
        currentVariable = kIsRhs;
      } else if ((index < (equation.size() - 1)) &&
                 (equation.substr(index, 2).contains("->"))) {
        currentVariable = kIsResult;
        index++;
      } else {
        return einsum.emitError("unexpected character ")
               << equation.substr(index, 1) << " encountered";
      }
      index++;
    }

    auto lhsType = einsum.lhs().getType().cast<RankedTensorType>();
    auto rhsType = einsum.rhs().getType().cast<RankedTensorType>();
    assert(static_cast<int64_t>(lhsTokens.size()) == lhsType.getRank());
    assert(static_cast<int64_t>(rhsTokens.size()) == rhsType.getRank());

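    // Classify each dimension of an operand: tokens absent from the result are
    // contracting dims, tokens kept in the result and shared with the other
    // operand are batching dims, and the remaining result tokens become this
    // operand's contribution to the dot_general result.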
    auto collectOperandDims =
        [resultTokens](
            RankedTensorType operandType, SmallVector<char> operandTokens,
            SmallVector<char> others, SmallVectorImpl<int64_t> &contractingDims,
            SmallVectorImpl<int64_t> &batchingDims,
            SmallVector<char> &dotResultTokens,
            SmallVector<int64_t> &dotResultShape) {
          llvm::SmallDenseSet<char> othersSet(others.begin(), others.end());
          llvm::SmallDenseSet<char> resultTokensSet(resultTokens.begin(),
                                                    resultTokens.end());
          for (const auto &en : llvm::enumerate(operandTokens)) {
            bool isResultToken = resultTokensSet.contains(en.value());
            bool isOtherToken = othersSet.contains(en.value());

            if (!isResultToken) {
              contractingDims.push_back(en.index());
            } else if (isOtherToken) {
              batchingDims.push_back(en.index());
            } else {
              dotResultTokens.push_back(en.value());
              dotResultShape.push_back(operandType.getShape()[en.index()]);
            }
          }
        };
    // Indices of batch and contracting dims, relative to each operand's
    // dimensions.
    SmallVector<int64_t> lhsContractingDims, lhsBatchingDims,
        rhsContractingDims, rhsBatchingDims;
    // Tokens representing the natural order of the dot_general op (i.e.
    // the lhs non-contracting followed by rhs non-contracting tokens).
    SmallVector<char> dotResultTokens;
    SmallVector<int64_t> dotResultShape;

    collectOperandDims(lhsType, lhsTokens, rhsTokens, lhsContractingDims,
                       lhsBatchingDims, dotResultTokens, dotResultShape);
    collectOperandDims(rhsType, rhsTokens, lhsTokens, rhsContractingDims,
                       rhsBatchingDims, dotResultTokens, dotResultShape);

    // Prepend batch tokens.
    for (const auto &it : llvm::enumerate(lhsBatchingDims)) {
      char batchingToken = lhsTokens[it.value()];
      int64_t batchingShapeDim = lhsType.getShape()[it.value()];
      dotResultTokens.insert(dotResultTokens.begin() + it.index(),
                             batchingToken);
      dotResultShape.insert(dotResultShape.begin() + it.index(),
                            batchingShapeDim);
    }

    // Lowering to dot_general does not support a mismatch between the number
    // of result dims and the number of non-contracting dims.
    if (dotResultTokens.size() != resultTokens.size()) {
      return rewriter.notifyMatchFailure(einsum,
                                         "rank reducing einsum not supported");
    }

    // Generate a permutation sequence based on result tokens.
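    // Track whether that permutation is the identity; if it is, the transpose
    // below can be elided.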
    SmallVector<int64_t> resultPerms;
    bool isNaturalOrder = true;
    for (char resultToken : resultTokens) {
      auto *foundIt = std::find(dotResultTokens.begin(), dotResultTokens.end(),
                                resultToken);
      if (foundIt == dotResultTokens.end()) {
        return rewriter.notifyMatchFailure(
            einsum, "result token not found in operands");
      }
      auto resultIndex = std::distance(dotResultTokens.begin(), foundIt);
      if (resultPerms.empty()) {
        if (resultIndex != 0) {
          isNaturalOrder = false;
        }
      } else if (resultIndex != (resultPerms.back() + 1)) {
        isNaturalOrder = false;
      }
      resultPerms.push_back(resultIndex);
    }

    // Emit the dot_general, using its native result ordering.
    auto dotGeneralResultType = RankedTensorType::get(
        ArrayRef<int64_t>(dotResultShape), lhsType.getElementType());
    auto dimNumbers = mhlo::DotDimensionNumbersAttr::get(
        rewriter.getContext(), lhsBatchingDims, rhsBatchingDims,
        lhsContractingDims, rhsContractingDims);
    auto dotGeneralOp =
        rewriter.create<DotGeneralOp>(einsum.getLoc(), dotGeneralResultType,
                                      einsum.lhs(), einsum.rhs(), dimNumbers,
                                      /*precision_config=*/ArrayAttr{});

    if (isNaturalOrder) {
      // The dot_general is already in an appropriate result order.
      rewriter.replaceOp(einsum, ValueRange{dotGeneralOp});
    } else {
      // Generate a transpose.
      rewriter.replaceOpWithNewOp<TransposeOp>(
          einsum, dotGeneralOp, rewriter.getI64TensorAttr(resultPerms));
    }
    return success();
  }
};

struct LegalizeEinsumToDotGeneralPass
    : public LegalizeEinsumToDotGeneralPassBase<
          LegalizeEinsumToDotGeneralPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateEinsumToDotGeneralPatterns(&getContext(), &patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};
}  // namespace

void populateEinsumToDotGeneralPatterns(mlir::MLIRContext *context,
                                        RewritePatternSet *patterns) {
  patterns->add<EinsumToDotGeneralPattern>(context);
}

std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeEinsumToDotGeneralPass() {
  return std::make_unique<LegalizeEinsumToDotGeneralPass>();
}

}  // namespace mhlo
}  // namespace mlir