1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <numeric>
21 #include <string>
22 #include <utility>
23
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/BitVector.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/iterator_range.h"
30 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
31 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
32 #include "mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h"
33 #include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h"
34 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
35 #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
36 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
37 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
38 #include "mlir/Dialect/Complex/IR/Complex.h"
39 #include "mlir/Dialect/Func/IR/FuncOps.h"
40 #include "mlir/Dialect/Linalg/IR/Linalg.h"
41 #include "mlir/Dialect/Math/IR/Math.h"
42 #include "mlir/Dialect/MemRef/IR/MemRef.h"
43 #include "mlir/Dialect/SCF/IR/SCF.h"
44 #include "mlir/Dialect/Shape/IR/Shape.h"
45 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
46 #include "mlir/Dialect/Tensor/IR/Tensor.h"
47 #include "mlir/Dialect/Tensor/Utils/Utils.h"
48 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
49 #include "mlir/IR/AffineExpr.h"
50 #include "mlir/IR/Attributes.h"
51 #include "mlir/IR/BlockAndValueMapping.h"
52 #include "mlir/IR/Builders.h"
53 #include "mlir/IR/BuiltinAttributes.h"
54 #include "mlir/IR/BuiltinOps.h"
55 #include "mlir/IR/BuiltinTypes.h"
56 #include "mlir/IR/Location.h"
57 #include "mlir/IR/MLIRContext.h"
58 #include "mlir/IR/Operation.h"
59 #include "mlir/IR/OperationSupport.h"
60 #include "mlir/IR/PatternMatch.h"
61 #include "mlir/IR/TypeUtilities.h"
62 #include "mlir/Pass/Pass.h"
63 #include "mlir/Support/LLVM.h"
64 #include "mlir/Support/LogicalResult.h"
65 #include "mlir/Transforms/DialectConversion.h"
66
67 namespace mlir {
68 namespace mhlo {
69 namespace {
70
getResultValue(Operation * op)71 Value getResultValue(Operation* op) { return op->getResult(0); }
72
getHloOpResultType(Operation * op)73 ShapedType getHloOpResultType(Operation* op) {
74 return getResultValue(op).getType().cast<ShapedType>();
75 }
76
verifyHloOpBufferOrTensorSemantics(Operation * op)77 bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
78 auto verifyType = [&](Value val) -> bool {
79 return val.getType().isa<RankedTensorType>();
80 };
81 if (!llvm::all_of(op->getOperands(), verifyType)) return false;
82 return llvm::all_of(op->getResults(), verifyType);
83 }
84
fillTensorWithZeros(OpBuilder & builder,Location loc,Value tensor)85 Value fillTensorWithZeros(OpBuilder& builder, Location loc, Value tensor) {
86 auto type = tensor.getType().cast<ShapedType>();
87 Value zero;
88 // Complex numbers are a special case.
89 if (auto complexType = type.getElementType().dyn_cast<ComplexType>()) {
90 auto zeroElement = builder.getZeroAttr(complexType.getElementType());
91 auto zeroAttr = builder.getArrayAttr({zeroElement, zeroElement});
92 zero = builder.create<complex::ConstantOp>(loc, complexType, zeroAttr);
93 } else {
94 auto zeroAttr = builder.getZeroAttr(type.getElementType());
95 zero = builder.create<arith::ConstantOp>(loc, zeroAttr);
96 }
97 return builder.create<linalg::FillOp>(loc, zero, tensor).result();
98 }
99
extract1DVector(DenseIntElementsAttr elements)100 SmallVector<int64_t, 4> extract1DVector(DenseIntElementsAttr elements) {
101 SmallVector<int64_t, 4> ret;
102 for (const APInt& element : elements) {
103 ret.push_back(element.getLimitedValue());
104 }
105 return ret;
106 }
107
108 /// Returns a permutation AffineMap that puts all reduction dimensions to the
109 /// last. The order of parallel loops and reduction loops are all sorted. E.g.,
110 /// if `rank` is 4 and `reductionDims` is {1, 3}, then
111 /// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
112 /// the AffineMap is returned.
getTransposeMapForReduction(MLIRContext * context,int rank,ArrayRef<int64_t> reductionDims)113 AffineMap getTransposeMapForReduction(MLIRContext* context, int rank,
114 ArrayRef<int64_t> reductionDims) {
115 llvm::SmallSetVector<int, 4> s;
116 for (auto dim : reductionDims) s.insert(dim);
117
118 SmallVector<unsigned, 4> permutation;
119 for (int i = 0; i < rank; ++i)
120 if (!s.count(i)) permutation.push_back(i);
121 for (auto dim : reductionDims) permutation.push_back(dim);
122
123 auto map = AffineMap::getPermutationMap(permutation, context);
124 return inversePermutation(map);
125 }
126
127 /// Returns true if the given `attr` is a splat of the given `value`.
isSplatValue(DenseIntElementsAttr attr,uint64_t value)128 bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
129 return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
130 }
131
132 /// Returns true if the given `dimensionNumbers` from a mhlo.convolution op
133 /// follows a canonical form:
134 ///
135 /// * Input dimensions have order: (batch_count, spatial_dims,
136 /// input_channel_count).
137 /// * Filter dimensions have order: (spatial_dims, input_channel_count,
138 /// output_channel_count).
139 /// * Output dimensions have order: (batch_count, spatial_dims,
140 /// output_channel_count).
hasCanonicalDimensionNumbers(mhlo::ConvDimensionNumbersAttr dimensionNumbers)141 static bool hasCanonicalDimensionNumbers(
142 mhlo::ConvDimensionNumbersAttr dimensionNumbers) {
143 const int inputSpatialRank =
144 llvm::size(dimensionNumbers.getInputSpatialDimensions());
145 // The dimensions for input should follow the order of
146 // batch_count, spatial_dims..., input_feature_count.
147 if (dimensionNumbers.getInputBatchDimension() != 0 ||
148 dimensionNumbers.getInputFeatureDimension() != (inputSpatialRank + 1)) {
149 return false;
150 }
151
152 const int kernelSpatialRank =
153 llvm::size(dimensionNumbers.getKernelSpatialDimensions());
154 // The dimensions for filter should follow the order of
155 // spatial_dims..., input_feature_count, num_output_feature_count.
156 if (dimensionNumbers.getKernelInputFeatureDimension() != kernelSpatialRank ||
157 dimensionNumbers.getKernelOutputFeatureDimension() !=
158 (kernelSpatialRank + 1)) {
159 return false;
160 }
161
162 const int outputSpatialRank =
163 llvm::size(dimensionNumbers.getOutputSpatialDimensions());
164 // The dimensions for output should follow the order of
165 // batch_count, spatial_dims.., output_feature_count.
166 if (dimensionNumbers.getOutputBatchDimension() != 0 ||
167 dimensionNumbers.getOutputFeatureDimension() != (outputSpatialRank + 1)) {
168 return false;
169 }
170
171 if (inputSpatialRank != outputSpatialRank ||
172 inputSpatialRank != kernelSpatialRank) {
173 return false;
174 }
175
176 const auto* inputSpatialDim =
177 dimensionNumbers.getInputSpatialDimensions().begin();
178 const auto* kernelSpatialDim =
179 dimensionNumbers.getKernelSpatialDimensions().begin();
180 const auto* outputSpatialDim =
181 dimensionNumbers.getOutputSpatialDimensions().begin();
182 // Check spatial dims are ordered correctly.
183 for (int i = 0; i < inputSpatialRank; ++i) {
184 const int dim = i + 1;
185 if ((*inputSpatialDim++) != dim || (*outputSpatialDim++) != dim ||
186 (*kernelSpatialDim++) != i) {
187 return false;
188 }
189 }
190
191 return true;
192 }
193
194 //===----------------------------------------------------------------------===//
195 // mhlo.RngOp conversion patterns.
196 //===----------------------------------------------------------------------===//
197
198 // Pass to lower from rng to stateless pseudo RNG with LCG
199 // algorithm
200 struct RngUniformConversion : public OpConversionPattern<mhlo::RngOp> {
201 using OpConversionPattern<mhlo::RngOp>::OpConversionPattern;
202
matchAndRewritemlir::mhlo::__anon071fe3ef0111::RngUniformConversion203 LogicalResult matchAndRewrite(
204 mhlo::RngOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter& rewriter) const final {
206 // We only handle uniform distributions
207 if (op.rng_distribution() != ::mlir::mhlo::RngDistribution::UNIFORM) {
208 return failure();
209 }
210 // TODO(raikonenfnu): Handle other element types as well.
211 auto minTy = adaptor.getOperands()[0].getType().dyn_cast<ShapedType>();
212 auto maxTy = adaptor.getOperands()[0].getType().dyn_cast<ShapedType>();
213 if (!minTy.getElementType().dyn_cast<FloatType>() ||
214 !maxTy.getElementType().dyn_cast<FloatType>()) {
215 return rewriter.notifyMatchFailure(
216 op, "expected min/max for rng op to be FloatType");
217 }
218 auto targetTy = this->typeConverter->convertType(op.getResult().getType())
219 .cast<ShapedType>();
220 if (!targetTy) {
221 return rewriter.notifyMatchFailure(
222 op, "expected target shape of rng op to be ShapedType");
223 }
224 auto loc = op.getLoc();
225 Value initTensor =
226 getInitTensorFor(rewriter, loc, targetTy, op, adaptor.getOperands());
227 // Creates index map using target matrix's rank.
228 auto targetRank = targetTy.getRank();
229 SmallVector<AffineMap, 3> indexingMaps(
230 2, AffineMap::get(targetRank, /*symbolCount=*/0,
231 SmallVector<AffineExpr>({}), rewriter.getContext()));
232 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(targetRank));
233 const int kInitialSeed = 0;
234 // Generic region with LCG Algorithm that make use of element index from:
235 // https://reviews.llvm.org/D101364
236 auto linalgOp = rewriter.create<linalg::GenericOp>(
237 loc, /*resultTensors=*/targetTy,
238 /*inputs=*/
239 ValueRange{adaptor.getOperands()[0], adaptor.getOperands()[1]},
240 /*outputs=*/initTensor, indexingMaps,
241 getParallelAndReductionIterators(/*nLoops=*/targetRank,
242 /*nReduction=*/0),
243 [&](OpBuilder& b, Location loc, ValueRange args) {
244 llvm::SmallVector<Value> updateVec = {b.create<arith::ConstantOp>(
245 loc, b.getI32IntegerAttr(kInitialSeed))};
246 Value multiplier =
247 b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1103515245));
248 Value incrementStep =
249 b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(12345));
250 // For output matrix with rank N:
251 // temp1 = (cast(I32, index(D.0)) + seed) * mult + incr
252 // ...
253 // tempN = (cast(I32, index(D.(N))) + tempN_1) * mult + incr
254 for (int i = 0; i < targetRank; i++) {
255 Value update = updateVec.back();
256 Value ind = b.create<linalg::IndexOp>(loc, i);
257 Value castInd =
258 b.create<arith::IndexCastOp>(loc, b.getI32Type(), ind);
259 Value addRes = b.create<arith::AddIOp>(loc, castInd, update);
260 Value multRes = b.create<arith::MulIOp>(loc, addRes, multiplier);
261 Value incRes = b.create<arith::AddIOp>(loc, multRes, incrementStep);
262 updateVec.push_back(incRes);
263 }
264 // Scaling = (max - min) * const(F64, 2.3283064E-10)
265 // which is derived from rand(min,max) = rand()/(RAND_MAX/(max-min)).
266 Value epsilon = b.create<arith::ConstantOp>(
267 loc, b.getFloatAttr(args[0].getType(), 2.3283064E-10));
268 Value range = b.create<arith::SubFOp>(loc, args[1], args[0]);
269 Value scale = b.create<arith::MulFOp>(loc, range, epsilon);
270 // Res = cast(T, cast(F64, tempN) * scaling + min)
271 Value updateCast = b.create<arith::UIToFPOp>(
272 loc, targetTy.getElementType(), updateVec.back());
273 Value scaleUpdate = b.create<arith::MulFOp>(loc, updateCast, scale);
274 Value res = b.create<arith::AddFOp>(loc, scaleUpdate, args[0]);
275 b.create<linalg::YieldOp>(loc, res);
276 },
277 pruneAttributeList(op));
278 rewriter.replaceOp(op, linalgOp.getResults());
279 return success();
280 }
281 };
282
283 //===----------------------------------------------------------------------===//
284 // mhlo.Einsum conversion patterns.
285 //===----------------------------------------------------------------------===//
286
287 // Looks through a set of dimension that has been marked as reduction axes,
288 // if it is found within the set, then we set it as "reduction", otherwise
289 // we can label it as "parallel".
getEinsumLoopsAttrs(const llvm::SmallSetVector<StringRef,4> & inputInd,const llvm::SmallSetVector<StringRef,4> & reductionDims)290 SmallVector<StringRef, 3> getEinsumLoopsAttrs(
291 const llvm::SmallSetVector<StringRef, 4>& inputInd,
292 const llvm::SmallSetVector<StringRef, 4>& reductionDims) {
293 SmallVector<StringRef, 3> res;
294 for (StringRef dim : inputInd) {
295 if (!reductionDims.contains(dim)) {
296 res.push_back(getParallelIteratorTypeName());
297 } else {
298 res.push_back(getReductionIteratorTypeName());
299 }
300 }
301 return res;
302 }
303
extractDynamicEinsumSizes(OpBuilder & b,Location loc,Value lhs,Value rhs,const SmallVector<std::string> & lhsLoopVec,const SmallVector<std::string> & rhsLoopVec,const SmallVector<std::string> & outputLoopVec)304 SmallVector<Value, 2> extractDynamicEinsumSizes(
305 OpBuilder& b, Location loc, Value lhs, Value rhs,
306 const SmallVector<std::string>& lhsLoopVec,
307 const SmallVector<std::string>& rhsLoopVec,
308 const SmallVector<std::string>& outputLoopVec) {
309 SmallVector<Value, 2> dynSizes;
310 for (const std::string& dimInd : outputLoopVec) {
311 Value dimSize;
312 const auto* dimIndIt =
313 std::find(lhsLoopVec.begin(), lhsLoopVec.end(), dimInd);
314 if (dimIndIt != lhsLoopVec.end()) {
315 // Query from lhs vars.
316 auto dimIndPos = dimIndIt - lhsLoopVec.begin();
317 auto lhsShape = lhs.getType().dyn_cast<RankedTensorType>().getShape();
318 if (lhsShape[dimIndPos] != ShapedType::kDynamicSize) continue;
319 dimSize = b.create<tensor::DimOp>(loc, lhs, dimIndPos);
320 } else {
321 // query from rhs vars.
322 dimIndIt = std::find(rhsLoopVec.begin(), rhsLoopVec.end(), dimInd);
323 auto dimIndPos = dimIndIt - rhsLoopVec.begin();
324 auto rhsShape = rhs.getType().dyn_cast<RankedTensorType>().getShape();
325 if (rhsShape[dimIndPos] != ShapedType::kDynamicSize) continue;
326 dimSize = b.create<tensor::DimOp>(loc, rhs, dimIndPos);
327 }
328 dynSizes.push_back(dimSize);
329 }
330 return dynSizes;
331 }
332
333 // Adds indices/axes that are missing from output set.
findSummationAxes(const llvm::SmallSetVector<StringRef,4> & inputSet,const llvm::SmallSetVector<StringRef,4> & outputSet)334 llvm::SmallSetVector<StringRef, 4> findSummationAxes(
335 const llvm::SmallSetVector<StringRef, 4>& inputSet,
336 const llvm::SmallSetVector<StringRef, 4>& outputSet) {
337 llvm::SmallSetVector<StringRef, 4> summationAxes;
338 for (StringRef ind : inputSet) {
339 if (!outputSet.contains(ind)) summationAxes.insert(ind);
340 }
341 return summationAxes;
342 }
343
344 // Given a 1:1 map from std::string -> affine dimension expression
345 // we can get the affine expression of dimensions that an
346 // operand will access based on the input_str of einsum_config.
347 // For example:
348 // let string_dim_umap = {'a' : d0, 'b' : d1, 'c' : d2}
349 // for einsum_config "abc,cb->acb"
350 // first_input_operand will get umap[{"a","b","c"}] -> (d0, d1, d2).
351 // second_input_operand will get umap[{"c","b"}] -> (d2, d1).
352 // output_operand will get umap[{"a","c","b"}] -> (d0, d2, d1).
getExprFromConfig(const SmallVector<std::string> & loopDims,const DenseMap<StringRef,AffineExpr> & strAffineDimUmap)353 SmallVector<AffineExpr> getExprFromConfig(
354 const SmallVector<std::string>& loopDims,
355 const DenseMap<StringRef, AffineExpr>& strAffineDimUmap) {
356 SmallVector<AffineExpr> exprs;
357 for (const auto& dim : loopDims) {
358 exprs.push_back(strAffineDimUmap.lookup(dim));
359 }
360 return exprs;
361 }
362
363 // Convert mhlo.einsum op into linalg.generic.
364 // Algorithm in general 3 steps:
365
366 // Step1) Dissect entire einsum_config to different operands
367 // e.g f("abc,cd->abd") = {lhs:["abc"], rhs:["cd"], out:["abd"]}.
368
369 // Step2) Split up the string into vector of the elements
370 // e.g {lhs:["abc"], rhs:["cd"], out:["abd"]} = {lhs:["a","b","c"],
371 // rhs:["c","d"], out:["a","b","d"]}.
372
373 // Step3) Convert the vector into data access
374 // patern represented by affineMaps with affineDimensions e.g
375 // {lhs:["a","b","c"], rhs:["c","d"], out:["a","b","d"]} = {lhs:[d0,d1,d2],
376 // rhs:[d2,d3], out:[d0,d1,d3]}.
377 class EinsumToLinalgConverter : public OpConversionPattern<mhlo::EinsumOp> {
378 public:
379 using OpConversionPattern<mhlo::EinsumOp>::OpConversionPattern;
380
matchAndRewrite(mhlo::EinsumOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const381 LogicalResult matchAndRewrite(
382 mhlo::EinsumOp op, OpAdaptor adaptor,
383 ConversionPatternRewriter& rewriter) const final {
384 auto getRank = [](Value v) {
385 return v.getType().cast<ShapedType>().getRank();
386 };
387 auto einsumConfig = op.einsum_config();
388
389 // With the assumption of binary input operand and single output
390 // get the inputs and output operands' indices.
391 // einsum_config = "lhs_loop,rhs_loop->out_loop"
392 std::size_t posArrow = einsumConfig.find(kArrow);
393 std::size_t posComma = einsumConfig.find(kComma);
394
395 StringRef lhsLoop = einsumConfig.substr(0, posComma);
396 StringRef rhsLoop = einsumConfig.substr(
397 posComma + kComma.size(), posArrow - (posComma + kComma.size()));
398 StringRef outLoop = einsumConfig.substr(posArrow + kArrow.size());
399
400 // Check for Invalid Configs.
401 // 1.Check that there is only maximum 2 inputs
402 // 2.Check that there is only maximum 1 output
403 // 3.Check that there is 1 kArrow
404 if (rhsLoop.find(kComma) != std::string::npos ||
405 outLoop.find(kComma) != std::string::npos ||
406 outLoop.find(kArrow) != std::string::npos) {
407 return rewriter.notifyMatchFailure(op, "Invalid einsum config!");
408 }
409
410 // Find result type, if on tensors.
411 auto resultTy = this->typeConverter->convertType(getHloOpResultType(op))
412 .dyn_cast<RankedTensorType>();
413
414 // Check result type compatibility.
415 if (!resultTy || !(resultTy.getElementType().isSignlessIntOrFloat())) {
416 return rewriter.notifyMatchFailure(op, "Invalid result type");
417 }
418
419 // Convert the representation to vector<string>.
420 SmallVector<std::string> lhsEin =
421 getEinsumConfigAsVector(lhsLoop, getRank(adaptor.lhs()));
422 SmallVector<std::string> rhsEin =
423 getEinsumConfigAsVector(rhsLoop, getRank(adaptor.rhs()));
424 SmallVector<std::string> outEin =
425 getEinsumConfigAsVector(outLoop, resultTy.getRank());
426
427 if (!checkBatchHasEqualRank(lhsEin.size(), lhsLoop, rhsEin.size(), rhsLoop,
428 outEin.size(), outLoop)) {
429 return rewriter.notifyMatchFailure(
430 op, "Invalid elipsis('...') within einsum config!");
431 }
432
433 // Find all unique indices in the input and output.
434 llvm::SmallSetVector<StringRef, 4> inputInd;
435 llvm::SmallSetVector<StringRef, 4> outputInd;
436
437 inputInd.insert(lhsEin.begin(), lhsEin.end());
438 inputInd.insert(rhsEin.begin(), rhsEin.end());
439 outputInd.insert(outEin.begin(), outEin.end());
440
441 llvm::SmallSetVector<StringRef, 4> reductionAxe =
442 findSummationAxes(inputInd, outputInd);
443
444 // Find input/output values and types.
445 auto loc = op.getLoc();
446
447 // Prepare init tensor for linalg.generic op.
448 auto dynSizes = extractDynamicEinsumSizes(
449 rewriter, loc, adaptor.lhs(), adaptor.rhs(), lhsEin, rhsEin, outEin);
450 Value output = getInitTensor(rewriter, loc, resultTy, dynSizes);
451 if (!reductionAxe.empty()) {
452 output = fillTensorWithZeros(rewriter, loc, output);
453 }
454
455 // Create indexing maps.
456 // Create a 1:1 map from f:strDimension -> affineDimension.
457 int64_t nloops = inputInd.size();
458 DenseMap<StringRef, AffineExpr> strAffineDimUmap;
459 for (auto& it : llvm::enumerate(inputInd)) {
460 strAffineDimUmap[it.value()] = rewriter.getAffineDimExpr(it.index());
461 }
462
463 // From einsum_config of each operand in vector<string>, generate
464 // the equivalent vector<AffineExpr>.
465 SmallVector<AffineMap, 4> maps;
466 for (const SmallVector<std::string>& loopOperand :
467 {lhsEin, rhsEin, outEin}) {
468 auto exprs = getExprFromConfig(loopOperand, strAffineDimUmap);
469 maps.push_back(AffineMap::get(nloops, 0, exprs, rewriter.getContext()));
470 }
471
472 auto linalgOp = rewriter.create<linalg::GenericOp>(
473 loc, resultTy ? resultTy : TypeRange{}, adaptor.getOperands(), output,
474 maps, getEinsumLoopsAttrs(inputInd, reductionAxe),
475 [reductionAxe](OpBuilder& b, Location nestedLoc, ValueRange args) {
476 Value resultVal =
477 b.create<mlir::arith::MulFOp>(nestedLoc, args[0], args[1]);
478 if (!reductionAxe.empty()) {
479 resultVal =
480 b.create<mlir::arith::AddFOp>(nestedLoc, args[2], resultVal);
481 }
482 b.create<linalg::YieldOp>(nestedLoc, resultVal);
483 },
484 pruneAttributeList(op));
485 rewriter.replaceOp(op, linalgOp.getResults());
486 return success();
487 }
488
489 private:
490 static constexpr StringRef kArrow = "->";
491 static constexpr StringRef kComma = ",";
492 static constexpr StringRef kEllipsis = "...";
493
494 static bool checkBatchHasEqualRank(size_t lhsRank, StringRef lhsLoop,
495 size_t rhsRank, StringRef rhsLoop,
496 size_t outRank, StringRef outLoop);
497 static SmallVector<std::string> getEinsumConfigAsVector(StringRef loop,
498 size_t operandRank);
499 };
500
501 // Definition of util const member variables.
502 constexpr StringRef EinsumToLinalgConverter::kArrow;
503 constexpr StringRef EinsumToLinalgConverter::kComma;
504 constexpr StringRef EinsumToLinalgConverter::kEllipsis;
505
506 // Convert the representation from string/vector<char> to vector<string>.
507 // i.e ("abc") -> {"a", "b", "c"}. For cases with ellipsis with batch rank 3:
508 // get loop_dim = f("ab...cde") = {"a","b","0","1","2","c","d","e"}
getEinsumConfigAsVector(StringRef loop,size_t operandRank)509 SmallVector<std::string> EinsumToLinalgConverter::getEinsumConfigAsVector(
510 StringRef loop, size_t operandRank) {
511 SmallVector<std::string> loopDim;
512 size_t preElip = loop.find(kEllipsis);
513 bool hasElip = preElip != std::string::npos;
514 if (!hasElip) preElip = loop.size();
515 // Add the dimension until the end or up to ellipsis if it exist.
516 for (int64_t preElipInd = 0; preElipInd < static_cast<int64_t>(preElip);
517 preElipInd++) {
518 loopDim.push_back(loop.substr(preElipInd, 1).str());
519 }
520 if (!hasElip) return loopDim;
521 // Case where Ellipsis presence:
522 size_t nonBatchRank = loop.size() - kEllipsis.size();
523 size_t batchRank = operandRank - nonBatchRank;
524 // Add the batch dimension ("0",...,"N") where N is rank of batch into the
525 // loop.
526 for (int64_t batchInd = 0; batchInd < static_cast<int64_t>(batchRank);
527 batchInd++) {
528 loopDim.push_back(std::to_string(batchInd));
529 }
530 // Add the dimension after ellipsis into the loop.
531 int postElip = preElip + kEllipsis.size();
532 for (int64_t postElipInd = postElip;
533 postElipInd < static_cast<int64_t>(loop.size()); postElipInd++) {
534 loopDim.push_back(loop.substr(postElipInd, 1).str());
535 }
536 return loopDim;
537 }
538
539 // Returns true if all operand's batch has same rank.
checkBatchHasEqualRank(size_t lhsRank,StringRef lhsLoop,size_t rhsRank,StringRef rhsLoop,size_t outRank,StringRef outLoop)540 bool EinsumToLinalgConverter::checkBatchHasEqualRank(
541 size_t lhsRank, StringRef lhsLoop, size_t rhsRank, StringRef rhsLoop,
542 size_t outRank, StringRef outLoop) {
543 SmallVector<int, 3> batchRankVec;
544 if (lhsRank != lhsLoop.size()) {
545 size_t lhsBatchRank = lhsRank - (lhsLoop.size() - kEllipsis.size());
546 batchRankVec.push_back(lhsBatchRank);
547 }
548 if (rhsRank != rhsLoop.size()) {
549 size_t rhsBatchRank = rhsRank - (rhsLoop.size() - kEllipsis.size());
550 batchRankVec.push_back(rhsBatchRank);
551 }
552 if (outRank != outLoop.size()) {
553 size_t outBatchRank = outRank - (outLoop.size() - kEllipsis.size());
554 batchRankVec.push_back(outBatchRank);
555 }
556 bool batchHasEqualRank = true;
557
558 // Condition is valid if only 1 operand or less have batches.
559 if (batchRankVec.size() < 2) return batchHasEqualRank;
560 if (!std::equal(batchRankVec.begin() + 1, batchRankVec.end(),
561 batchRankVec.begin()) &&
562 batchRankVec.size() > 1)
563 batchHasEqualRank = false;
564 return batchHasEqualRank;
565 }
566
567 template <typename MhloOp>
568 class ScalarPointwiseToStandardConverter : public OpConversionPattern<MhloOp> {
569 public:
570 using OpConversionPattern<MhloOp>::OpConversionPattern;
571
matchAndRewrite(MhloOp mhloOp,ConversionPatternRewriter & rewriter) const572 LogicalResult matchAndRewrite(
573 MhloOp mhloOp, ConversionPatternRewriter& rewriter) const final {
574 auto loc = mhloOp.getLoc();
575 auto argType =
576 mhloOp.getOperand(0).getType().template dyn_cast<ShapedType>();
577 if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
578 (argType.getRank() != 0)) {
579 return failure();
580 }
581
582 // Create two loads from the input.
583 auto lhs = rewriter.create<memref::LoadOp>(loc, mhloOp.lhs());
584 auto rhs = rewriter.create<memref::LoadOp>(loc, mhloOp.rhs());
585 Value opResult = mhlo::MhloOpToStdScalarOp::mapOp(
586 mhloOp, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
587 &rewriter);
588 rewriter.create<memref::StoreOp>(loc, opResult, mhloOp.out());
589 rewriter.eraseOp(mhloOp);
590 return success();
591 }
592 };
593
594 /// Base class for lowering HLO operations that have one operand and one result,
595 /// and are semantically equivalent to a copy of the input to the output (like
596 /// transpose, some reshape, etc.). The derived classes need to provide a method
597 /// `getIndexingMaps` that returns AffineMaps for the index maps of the input
598 /// and the output.
599 template <typename Derived, typename OpTy>
600 class DataMovementOpConverter : public OpConversionPattern<OpTy> {
601 public:
602 using OpConversionPattern<OpTy>::OpConversionPattern;
603
matchAndRewrite(OpTy op,typename OpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter) const604 LogicalResult matchAndRewrite(
605 OpTy op, typename OpTy::Adaptor adaptor,
606 ConversionPatternRewriter& rewriter) const final {
607 if (!verifyHloOpBufferOrTensorSemantics(op)) return failure();
608 auto resultType = getHloOpResultType(op);
609 resultType = this->typeConverter->convertType(resultType)
610 .template cast<ShapedType>();
611
612 SmallVector<AffineMap, 2> indexingMaps =
613 Derived::getIndexingMaps(op, &rewriter);
614 if (indexingMaps.empty()) return failure();
615
616 auto nloops = resultType.getRank();
617 auto loc = op.getLoc();
618 auto linalgOp = rewriter.create<linalg::GenericOp>(
619 loc,
620 /*resultTensorTypes=*/resultType,
621 /*inputs=*/adaptor.getOperands().front(),
622 /*outputBuffers=*/
623
624 ValueRange{getInitTensorFor(rewriter, loc, resultType, op,
625 adaptor.getOperands())},
626 indexingMaps, getNParallelLoopsAttrs(nloops),
627 [&](OpBuilder& nestedBuilder, Location /*nested_loc*/,
628 ValueRange args) {
629 nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
630 },
631 pruneAttributeList(op));
632 rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
633 return success();
634 }
635 };
636
637 /// Pattern to convert BroadcastOp to Linalg ops.
638 template <typename OpTy>
639 class BroadcastConverter
640 : public DataMovementOpConverter<BroadcastConverter<OpTy>, OpTy> {
641 public:
642 using DataMovementOpConverter<BroadcastConverter,
643 OpTy>::DataMovementOpConverter;
644
getIndexingMaps(OpTy broadcastOp,Builder * b)645 static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
646 Builder* b) {
647 ShapedType inputType =
648 broadcastOp.operand().getType().template cast<ShapedType>();
649 unsigned inputRank = inputType.getRank();
650 unsigned nloops = getHloOpResultType(broadcastOp).getRank();
651
652 // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
653 // the input's dimensions.
654 unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes());
655 SmallVector<AffineExpr, 4> inputDimExprs;
656 inputDimExprs.reserve(inputRank);
657 for (unsigned i = 0; i < inputRank; ++i) {
658 inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i));
659 }
660
661 AffineMap inputMap;
662 MLIRContext* context = b->getContext();
663 if (inputDimExprs.empty()) {
664 // The input is a scalar, i.e. this is a scalar broadcast op.
665 inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context);
666 } else {
667 inputMap =
668 AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
669 }
670 return {inputMap, b->getMultiDimIdentityMap(nloops)};
671 }
672 };
673
674 class HloBroadcastInDimConverter
675 : public DataMovementOpConverter<HloBroadcastInDimConverter,
676 mhlo::BroadcastInDimOp> {
677 public:
678 using DataMovementOpConverter<
679 HloBroadcastInDimConverter,
680 mhlo::BroadcastInDimOp>::DataMovementOpConverter;
681
getIndexingMaps(mhlo::BroadcastInDimOp broadcastOp,Builder * b)682 static SmallVector<AffineMap, 2> getIndexingMaps(
683 mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
684 auto resultType = getHloOpResultType(broadcastOp);
685 auto operandType =
686 broadcastOp.operand().getType().template cast<ShapedType>();
687 unsigned nloops = resultType.getRank();
688
689 // The input is a scalar, i.e. this is a scalar broadcast op.
690 if (operandType.getRank() == 0) {
691 return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
692 b->getMultiDimIdentityMap(nloops)};
693 }
694
695 auto operandShape = operandType.getShape();
696 SmallVector<AffineExpr, 4> dimExprs;
697 dimExprs.reserve(nloops);
698
699 if (broadcastOp.broadcast_dimensions()) {
700 for (const auto& broadcastDim :
701 enumerate(broadcastOp.broadcast_dimensions().getValues<APInt>())) {
702 int size = broadcastDim.value().getSExtValue();
703 bool expansionNeeded = operandShape[broadcastDim.index()] == 1 &&
704 resultType.getShape()[size] != 1;
705 dimExprs.push_back(expansionNeeded ? b->getAffineConstantExpr(0)
706 : b->getAffineDimExpr(size));
707 }
708 }
709 return {
710 AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
711 b->getMultiDimIdentityMap(nloops)};
712 }
713 };
714
715 // If the input has a static shape we know exactly when the broadcast must
716 // expand (the dimension is 1, which also trivially expands to 1) or will never
717 // expand (the dimension is not 1). We can also source the information from the
718 // optionally provided attrbibutes on statically known broadcasting behavior.
719 // This means we can lower the broadcast just as we would lower a fully static
720 // broadcast and go directly to `linalg.generic`.
721
722 // This also covers the important case of broadcasting a scalar. Ideally the
723 // pattern (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be
724 // converted to a tensor dialect op similar to TF's `ConstantLikeOp`.
725 class HloDynamicBroadcastInDimConverter
726 : public OpConversionPattern<mhlo::DynamicBroadcastInDimOp> {
727 public:
728 using OpConversionPattern<mhlo::DynamicBroadcastInDimOp>::OpConversionPattern;
729
matchAndRewrite(mhlo::DynamicBroadcastInDimOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const730 LogicalResult matchAndRewrite(
731 mhlo::DynamicBroadcastInDimOp op, OpAdaptor adaptor,
732 ConversionPatternRewriter& rewriter) const final {
733 Value operand = adaptor.operand();
734 auto operandType = operand.getType().dyn_cast<RankedTensorType>();
735 if (!operandType) return failure();
736 auto resultType =
737 typeConverter->convertType(op.getType()).dyn_cast<RankedTensorType>();
738 if (!resultType) return failure();
739
740 // Determine dimension expressions based on whether the dimension is
741 // expanding (0) or non-expanding (identity), and fail if we cannot decide
742 // this.
743 SmallVector<AffineExpr> dimExprs(operandType.getRank(), nullptr);
744
745 // Use static type info.
746 auto bcastDims = llvm::to_vector(
747 llvm::map_range(op.broadcast_dimensions(), [](const APInt& d) {
748 return static_cast<int64_t>(d.getLimitedValue());
749 }));
750 for (const auto& it : llvm::enumerate(operandType.getShape())) {
751 if (ShapedType::isDynamic(it.value())) continue;
752 bool isExpanding = it.value() == 1;
753 dimExprs[it.index()] =
754 isExpanding ? rewriter.getAffineConstantExpr(0)
755 : rewriter.getAffineDimExpr(bcastDims[it.index()]);
756 }
757
758 // Use annotated expansion behavior, if available.
759 if (op.known_expanding_dimensions()) {
760 for (const auto& it :
761 op.known_expanding_dimensions()->getValues<APInt>()) {
762 auto i = it.getLimitedValue();
763 dimExprs[i] = rewriter.getAffineConstantExpr(0);
764 }
765 }
766 if (op.known_nonexpanding_dimensions()) {
767 for (const auto& it :
768 op.known_nonexpanding_dimensions()->getValues<APInt>()) {
769 auto i = it.getLimitedValue();
770 dimExprs[i] = rewriter.getAffineDimExpr(bcastDims[i]);
771 }
772 }
773
774 // Fail if unknown expansion behavior remains.
775 if (!llvm::all_of(dimExprs, [](AffineExpr expr) { return expr; }))
776 return failure();
777
778 // Materialize `linalg.generic` op.
779 Location loc = op.getLoc();
780 int64_t nloops = resultType.getRank();
781 Value init =
782 getInitTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
783 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
784 op, TypeRange{init.getType()}, ValueRange{operand},
785 /*outputBuffers=*/ValueRange{init},
786 llvm::makeArrayRef(
787 {AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, dimExprs,
788 rewriter.getContext()),
789 rewriter.getMultiDimIdentityMap(nloops)}),
790 getNParallelLoopsAttrs(nloops),
791 [&](OpBuilder& nestedBuilder, Location /*nested_loc*/,
792 ValueRange args) {
793 nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
794 },
795 pruneAttributeList(op));
796 return success();
797 }
798 };
799
800 template <typename OpTy>
801 class TransposeConverter
802 : public DataMovementOpConverter<TransposeConverter<OpTy>, OpTy> {
803 public:
804 using DataMovementOpConverter<TransposeConverter<OpTy>,
805 OpTy>::DataMovementOpConverter;
getIndexingMaps(OpTy op,Builder * b)806 static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
807 auto resultType = getHloOpResultType(op).template cast<ShapedType>();
808 auto nloops = resultType.getRank();
809 SmallVector<AffineExpr, 2> inputExprs;
810 inputExprs.resize(resultType.getRank());
811 for (const auto& permutation : llvm::enumerate(op.permutation())) {
812 inputExprs[permutation.value().getZExtValue()] =
813 b->getAffineDimExpr(permutation.index());
814 }
815 return {
816 AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
817 b->getMultiDimIdentityMap(nloops)};
818 }
819 };
820
821 // Lowers mhlo.RealDynamicSliceOp to tensor.extract_slice and other
822 // arith/tensor dialect ops.
823 class RealDynamicSliceConverter
824 : public OpConversionPattern<mhlo::RealDynamicSliceOp> {
825 public:
826 using OpConversionPattern<mhlo::RealDynamicSliceOp>::OpConversionPattern;
827
828 // Computes size of a slice as
829 // size = ceil((limit - start)/stride)
computeSize(Location loc,Value start,Value limit,Value stride,ConversionPatternRewriter & b)830 static Value computeSize(Location loc, Value start, Value limit, Value stride,
831 ConversionPatternRewriter& b) {
832 Value delta = b.create<arith::SubIOp>(loc, limit, start);
833 Value ret = b.create<arith::CeilDivUIOp>(loc, delta, stride);
834 if (ret.getType().isIndex()) return ret;
835 return b.create<arith::IndexCastOp>(loc, b.getIndexType(), ret);
836 }
837
matchAndRewrite(mhlo::RealDynamicSliceOp realDynamicSliceOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const838 LogicalResult matchAndRewrite(
839 mhlo::RealDynamicSliceOp realDynamicSliceOp, OpAdaptor adaptor,
840 ConversionPatternRewriter& rewriter) const final {
841 Location loc = realDynamicSliceOp.getLoc();
842 auto argType = adaptor.operand().getType().dyn_cast<ShapedType>();
843 if (!argType || !argType.hasRank()) {
844 return rewriter.notifyMatchFailure(realDynamicSliceOp,
845 "require known-rank args");
846 }
847
848 Type dimElementType = getElementTypeOrSelf(adaptor.start_indices());
849 if (getElementTypeOrSelf(adaptor.limit_indices()) != dimElementType ||
850 getElementTypeOrSelf(adaptor.strides()) != dimElementType) {
851 return rewriter.notifyMatchFailure(
852 realDynamicSliceOp,
853 "requires same element type for all dimension specification");
854 }
855 Type arithType =
856 dimElementType.isIndex() ? rewriter.getI64Type() : dimElementType;
857 Type indexType = rewriter.getIndexType();
858
859 auto resultType =
860 this->typeConverter->convertType(realDynamicSliceOp.getType())
861 .cast<RankedTensorType>();
862 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
863 SmallVector<OpFoldResult, 4> offsets, sizes, strides;
864 SmallVector<Type, 3> clampType(3, arithType);
865 for (auto i : llvm::seq<unsigned>(0, argType.getRank())) {
866 Value dim = rewriter.create<arith::ConstantIndexOp>(loc, i);
867 Value start =
868 rewriter.create<tensor::ExtractOp>(loc, adaptor.start_indices(), dim);
869 Value limit =
870 rewriter.create<tensor::ExtractOp>(loc, adaptor.limit_indices(), dim);
871 Value stride =
872 rewriter.create<tensor::ExtractOp>(loc, adaptor.strides(), dim);
873
874 // Compute i-th dimension size of the result : size[i].
875 // If the i-th dimension of the result type is known, we go ahead with it
876 // else we compute it using limit, start and stride values.
877 int64_t resultDimSize = resultType.getDimSize(i);
878 Value size =
879 ShapedType::isDynamic(resultDimSize)
880 ? computeSize(loc, start, limit, stride, rewriter)
881 : rewriter.create<arith::ConstantIndexOp>(loc, resultDimSize);
882
883 // We can now convert start to index.
884 if (!start.getType().isIndex())
885 start = rewriter.create<arith::IndexCastOp>(
886 loc, rewriter.getIndexType(), start);
887
888 // Fetch i-th dimension size of the operand and calculate upper bound as
889 // ub = operand_dim[i] - size[i]
890 Value operandDimSize =
891 rewriter.createOrFold<tensor::DimOp>(loc, adaptor.operand(), dim);
892 Value upperBound =
893 rewriter.createOrFold<arith::SubIOp>(loc, operandDimSize, size);
894
895 // We clamp the start_index to keep it bounded as
896 // 0 <= start_index[i] <= ub
897 // Clamp does not support index type, so cast to integer type.
898 start = rewriter.create<arith::MaxSIOp>(loc, start, zero);
899 start = rewriter.create<arith::MinSIOp>(loc, start, upperBound);
900
901 offsets.push_back(start);
902 if (ShapedType::isDynamic(resultDimSize))
903 sizes.push_back(size);
904 else
905 sizes.push_back(IntegerAttr::get(indexType, resultDimSize));
906
907 if (!stride.getType().isIndex())
908 stride =
909 rewriter.createOrFold<arith::IndexCastOp>(loc, indexType, stride);
910 strides.push_back(stride);
911 }
912
913 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
914 realDynamicSliceOp, resultType, adaptor.operand(), offsets, sizes,
915 strides);
916 return success();
917 }
918 };
919
920 // Converts reshape ops that can be proven to be either a collapse of dimensions
921 // or expansion of dimensions of the operand.
922 class ReshapeOpConverter : public OpConversionPattern<mhlo::ReshapeOp> {
923 public:
924 using OpConversionPattern::OpConversionPattern;
925
matchAndRewrite(mhlo::ReshapeOp reshapeOp,mhlo::ReshapeOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const926 LogicalResult matchAndRewrite(
927 mhlo::ReshapeOp reshapeOp, mhlo::ReshapeOp::Adaptor adaptor,
928 ConversionPatternRewriter& rewriter) const final {
929 if (!verifyHloOpBufferOrTensorSemantics(reshapeOp)) return failure();
930 auto operand = adaptor.operand();
931 auto operandType = operand.getType().cast<ShapedType>();
932 auto elemType = operandType.getElementType();
933 auto resultType = reshapeOp.getType().cast<ShapedType>();
934
935 if (!resultType.hasStaticShape()) return failure();
936
937 resultType = typeConverter->convertType(resultType).cast<ShapedType>();
938
939 // Special case where the result is a scalar.
940 if (resultType.getRank() == 0 && !operandType.hasStaticShape()) {
941 // This means all dimensions of the operand need to be 1. We add a cast to
942 // cast the dynamic dimensions to 1.
943 auto staticType = RankedTensorType::get(
944 llvm::SmallVector<int64_t>(operandType.getRank(), 1), elemType);
945 operand = rewriter.create<tensor::CastOp>(reshapeOp.getLoc(), staticType,
946 operand);
947 rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
948 reshapeOp, resultType, operand, ArrayRef<ReassociationIndices>{});
949 return success();
950 }
951
952 // Compute the reassociation maps for the linalg operation. This will
953 // succeed if the reshape can be done with a single expand_shape or
954 // collapse_shape.
955 if (Optional<SmallVector<ReassociationIndices>> reassociationMap =
956 getReassociationIndicesForReshape(operandType, resultType)) {
957 if (resultType.getRank() < operandType.getRank()) {
958 // We have found a working reassociation map. If the operand is dynamic,
959 // we first need to cast all unknown dimensions in the input that get
960 // collapsed to a static-sized dimension in the output, to 1.
961 SmallVector<int64_t> shape(operandType.getShape().begin(),
962 operandType.getShape().end());
963 for (const auto& map : llvm::enumerate(*reassociationMap)) {
964 // If the result dim is dynamic, we do not mind dynamic entries in the
965 // source.
966 if (resultType.isDynamicDim(map.index())) continue;
967 for (auto targetDim : map.value()) {
968 if (shape[targetDim] == ShapedType::kDynamicSize)
969 shape[targetDim] = 1;
970 }
971 }
972 // Insert a cast if types are not the same (ignoring sparse encoding).
973 auto enc = sparse_tensor::getSparseTensorEncoding(operandType);
974 auto newOperandType = RankedTensorType::get(shape, elemType, enc);
975 if (newOperandType != operandType) {
976 operand = rewriter.create<tensor::CastOp>(reshapeOp.getLoc(),
977 newOperandType, operand);
978 }
979 // Generate collapse operation.
980 rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
981 reshapeOp, resultType, operand, *reassociationMap);
982 } else {
983 // Generate expand operation.
984 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
985 reshapeOp, resultType, operand, *reassociationMap);
986 }
987 return success();
988 }
989
990 Value collapsedOp = operand;
991 Location loc = reshapeOp.getLoc();
992 auto getIdentityExprs = [&rewriter](int64_t n) {
993 SmallVector<AffineExpr, 4> exprs;
994 for (int i = 0; i < n; ++i) exprs.push_back(rewriter.getAffineDimExpr(i));
995 return exprs;
996 };
997 // Otherwise, we need to first reduce all source dimensions into one and
998 // then expand to the destination dimensions. If there is only a single
999 // source dimension, the reduce step can be skipped. TensorCollapseShape
1000 // expects a different rank of operand and result.
1001 if (operandType.getRank() != 1) {
1002 SmallVector<ReassociationExprs, 4> collapsingMap = {
1003 // Use operand_type here because we need to collapse all operands
1004 // dimensions.
1005 getIdentityExprs(operandType.getRank())};
1006
1007 collapsedOp =
1008 rewriter.create<tensor::CollapseShapeOp>(loc, operand, collapsingMap);
1009 }
1010 // Cast to a known static type if the input has dynamic dimensions.
1011 int64_t totalElems = resultType.getNumElements();
1012 auto collapsedType = RankedTensorType::get({totalElems}, elemType);
1013 collapsedOp =
1014 rewriter.create<tensor::CastOp>(loc, collapsedType, collapsedOp);
1015 if (resultType.getRank() == 1) {
1016 rewriter.replaceOp(reshapeOp, collapsedOp);
1017 } else {
1018 SmallVector<ReassociationExprs, 4> expandingMap = {
1019 // Use resultType here because we need to expand to all result
1020 // dimensions.
1021 getIdentityExprs(resultType.getRank())};
1022 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1023 reshapeOp, resultType, collapsedOp, expandingMap);
1024 }
1025 return success();
1026 }
1027 };
1028
1029 template <typename OpTy>
1030 class IotaConverter : public OpConversionPattern<OpTy> {
1031 public:
1032 using OpConversionPattern<OpTy>::OpConversionPattern;
1033
matchAndRewrite(OpTy iotaOp,typename OpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter) const1034 LogicalResult matchAndRewrite(
1035 OpTy iotaOp, typename OpTy::Adaptor adaptor,
1036 ConversionPatternRewriter& rewriter) const final {
1037 ShapedType resultShapedType = getHloOpResultType(iotaOp);
1038 if (!resultShapedType) return failure();
1039 resultShapedType = this->typeConverter->convertType(resultShapedType)
1040 .template dyn_cast<ShapedType>();
1041
1042 Type resultElementType = resultShapedType.getElementType();
1043
1044 // Construct the indexing maps needed for linalg.generic ops.
1045 unsigned nloops = resultShapedType.getRank();
1046
1047 Location loc = iotaOp.getLoc();
1048 auto linalgOp = rewriter.create<linalg::GenericOp>(
1049 loc,
1050 /*resultTensorTypes=*/
1051 ArrayRef<Type>{resultShapedType},
1052 /*inputs=*/ValueRange{},
1053 /*outputBuffers=*/
1054
1055 ValueRange{getInitTensorFor(rewriter, loc, resultShapedType, iotaOp,
1056 adaptor.getOperands())},
1057 llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
1058 getNParallelLoopsAttrs(nloops),
1059 [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange /*args*/) {
1060 Value indexOp = nestedBuilder.create<linalg::IndexOp>(
1061 nestedLoc, iotaOp.iota_dimension());
1062 Type unwrappedResultElementType = resultElementType;
1063 if (auto complexType =
1064 unwrappedResultElementType.dyn_cast<ComplexType>())
1065 unwrappedResultElementType = complexType.getElementType();
1066 Value castOp = nestedBuilder.create<arith::IndexCastOp>(
1067 nestedLoc,
1068 nestedBuilder.getIntegerType(
1069 unwrappedResultElementType.getIntOrFloatBitWidth()),
1070 indexOp);
1071 castOp = mhlo::MhloOpToStdScalarOp::mapOpOfType<mhlo::ConvertOp>(
1072 nestedLoc, resultElementType, castOp.getType(), castOp,
1073 &nestedBuilder);
1074 nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
1075 },
1076 pruneAttributeList(iotaOp));
1077 rewriter.replaceOp(iotaOp, linalgOp.result_tensors());
1078 return success();
1079 }
1080 };
1081
1082 /// Converts mhlo.concatenate operation to a linalg.generic op.
1083 struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
1084 using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;
1085
matchAndRewritemlir::mhlo::__anon071fe3ef0111::ConcatenateConverter1086 LogicalResult matchAndRewrite(
1087 mhlo::ConcatenateOp op, OpAdaptor adaptor,
1088 ConversionPatternRewriter& rewriter) const override {
1089 // Shortcut the one-operand case, simplifies code below.
1090 if (adaptor.getOperands().size() == 1) {
1091 rewriter.replaceOp(op, adaptor.getOperands()[0]);
1092 return success();
1093 }
1094
1095 auto resultType = this->typeConverter->convertType(op.getResult().getType())
1096 .dyn_cast<RankedTensorType>();
1097 if (!resultType) return failure();
1098
1099 uint64_t dim = op.dimension();
1100 Location loc = op.getLoc();
1101 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1102
1103 // Allocate the output tensor with init_tensor.
1104 Value result =
1105 getInitTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
1106
1107 // Generate a generic op to gather the elements of the concatenate. This is
1108 // awkward standalone but allows fusion with other generic ops.
1109 int64_t nloops = resultType.getRank();
1110 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1111 op,
1112 /*resultTensorTypes=*/resultType,
1113 /*inputs=*/ValueRange{}, /*outputBuffers=*/result,
1114 llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
1115 getNParallelLoopsAttrs(nloops),
1116 [&](OpBuilder& nestedBuilder, Location loc, ValueRange) {
1117 OpBuilder b = nestedBuilder;
1118 Value concatDimSize = zero;
1119 Value result;
1120
1121 SmallVector<Value, 4> extractIndices;
1122 extractIndices.reserve(nloops);
1123 for (int64_t i = 0; i < nloops; i++) {
1124 extractIndices.push_back(b.create<linalg::IndexOp>(loc, i));
1125 }
1126
1127 Value indexOp = b.create<linalg::IndexOp>(loc, dim);
1128 for (auto& it : llvm::enumerate(adaptor.getOperands())) {
1129 Value arg = it.value();
1130 Value newConcatDimSize;
1131 scf::IfOp ifOp;
1132 if (it.index() != (adaptor.getOperands().size() - 1)) {
1133 // Calculate how far along we have iterated along the concatenate
1134 // dimension. That way we can tell which input to select.
1135 newConcatDimSize = b.create<arith::AddIOp>(
1136 loc, concatDimSize, b.create<tensor::DimOp>(loc, arg, dim));
1137 Value cmp = b.create<arith::CmpIOp>(loc, rewriter.getI1Type(),
1138 arith::CmpIPredicate::ult,
1139 indexOp, newConcatDimSize);
1140 ifOp = b.create<scf::IfOp>(loc, resultType.getElementType(), cmp,
1141 true);
1142 if (result) {
1143 b.create<scf::YieldOp>(loc, ifOp->getResults()[0]);
1144 } else {
1145 result = ifOp->getResults()[0];
1146 }
1147
1148 b = ifOp.getThenBodyBuilder(b.getListener());
1149 }
1150
1151 // Now adjust the index for the concatenated dimension to fit into
1152 // the selected tensor and do an extract at that position.
1153 extractIndices[dim] =
1154 b.create<arith::SubIOp>(loc, indexOp, concatDimSize);
1155 Value extract =
1156 b.create<tensor::ExtractOp>(loc, arg, extractIndices);
1157 b.create<scf::YieldOp>(loc, extract);
1158
1159 if (ifOp) {
1160 b = ifOp.getElseBodyBuilder(b.getListener());
1161 concatDimSize = newConcatDimSize;
1162 }
1163 }
1164 nestedBuilder.create<linalg::YieldOp>(loc, result);
1165 },
1166 pruneAttributeList(op));
1167 return success();
1168 }
1169 };
1170
1171 class ConstConverterTensor : public OpConversionPattern<mhlo::ConstantOp> {
1172 public:
1173 using OpConversionPattern::OpConversionPattern;
1174
matchAndRewrite(mhlo::ConstantOp constOp,OpAdaptor,ConversionPatternRewriter & rewriter) const1175 LogicalResult matchAndRewrite(
1176 mhlo::ConstantOp constOp, OpAdaptor /*adaptor*/,
1177 ConversionPatternRewriter& rewriter) const final {
1178 auto valueAttr = constOp.value().cast<DenseElementsAttr>();
1179 auto type =
1180 typeConverter->convertType(constOp.getType()).cast<ShapedType>();
1181 if (type != constOp.getType()) {
1182 // Signedness conversion.
1183 valueAttr = valueAttr.mapValues(type.getElementType(),
1184 [](const APInt& i) { return i; });
1185 }
1186 rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, type, valueAttr);
1187 return success();
1188 }
1189 };
1190
1191 // TODO(b/156787842): Support the lowering for dynamic shapes.
1192 class ReverseConverter
1193 : public DataMovementOpConverter<ReverseConverter, mhlo::ReverseOp> {
1194 public:
1195 using DataMovementOpConverter<ReverseConverter,
1196 mhlo::ReverseOp>::DataMovementOpConverter;
getIndexingMaps(mhlo::ReverseOp op,Builder * b)1197 static SmallVector<AffineMap, 2> getIndexingMaps(mhlo::ReverseOp op,
1198 Builder* b) {
1199 auto resultType = getHloOpResultType(op).cast<ShapedType>();
1200 auto nloops = resultType.getRank();
1201 SmallVector<AffineExpr, 2> inputExprs;
1202 inputExprs.reserve(nloops);
1203 for (int i = 0; i < nloops; ++i)
1204 inputExprs.push_back(b->getAffineDimExpr(i));
1205 for (auto dim : op.dimensions()) {
1206 int i = dim.getZExtValue();
1207 if (resultType.isDynamicDim(i)) return {};
1208 int n = resultType.getShape()[i];
1209 inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
1210 }
1211 return {
1212 AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
1213 b->getMultiDimIdentityMap(nloops)};
1214 }
1215 };
1216
1217 class SliceConverter : public OpConversionPattern<mhlo::SliceOp> {
1218 public:
1219 using OpConversionPattern::OpConversionPattern;
1220
matchAndRewrite(mhlo::SliceOp sliceOp,typename mhlo::SliceOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const1221 LogicalResult matchAndRewrite(
1222 mhlo::SliceOp sliceOp, typename mhlo::SliceOp::Adaptor adaptor,
1223 ConversionPatternRewriter& rewriter) const final {
1224 auto argType = adaptor.getOperands()[0].getType().dyn_cast<ShapedType>();
1225 if (!argType || !argType.hasRank()) {
1226 return rewriter.notifyMatchFailure(sliceOp, "expects known-rank args");
1227 }
1228
1229 SmallVector<OpFoldResult, 3> offsets, sizes, strides;
1230 for (int i = 0, e = argType.getRank(); i < e; ++i) {
1231 auto start = sliceOp.start_indices().getValues<int64_t>()[i];
1232 auto limit = sliceOp.limit_indices().getValues<int64_t>()[i];
1233 auto stride = sliceOp.strides().getValues<int64_t>()[i];
1234 offsets.push_back(rewriter.getI64IntegerAttr(start));
1235 // Say that there are k elements in total, we have condition:
1236 // start + (k - 1) * strides <= limit - 1
1237 // ->
1238 // k <= (limit - 1 - start) / strides + 1
1239 sizes.push_back(
1240 rewriter.getI64IntegerAttr((limit - 1 - start) / stride + 1));
1241 strides.push_back(rewriter.getI64IntegerAttr(stride));
1242 }
1243 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
1244 sliceOp, adaptor.getOperands()[0], offsets, sizes, strides);
1245 return success();
1246 }
1247 };
1248
1249 class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> {
1250 public:
1251 using OpConversionPattern<mhlo::DynamicSliceOp>::OpConversionPattern;
1252
matchAndRewrite(mhlo::DynamicSliceOp dynamicSliceOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1253 LogicalResult matchAndRewrite(
1254 mhlo::DynamicSliceOp dynamicSliceOp, OpAdaptor adaptor,
1255 ConversionPatternRewriter& rewriter) const final {
1256 auto loc = dynamicSliceOp.getLoc();
1257 auto argType = adaptor.operand().getType().dyn_cast<ShapedType>();
1258 if (!argType || !argType.hasRank()) {
1259 return rewriter.notifyMatchFailure(dynamicSliceOp,
1260 "require known-rank args");
1261 }
1262
1263 SmallVector<OpFoldResult, 3> startIndices, sizes;
1264 for (auto& en : llvm::enumerate(
1265 llvm::zip(adaptor.start_indices(),
1266 dynamicSliceOp.slice_sizes().getValues<int64_t>()))) {
1267 int64_t size = std::get<1>(en.value());
1268 sizes.push_back(rewriter.getI64IntegerAttr(size));
1269
1270 // By mhlo.DynamicSlice definition:
1271 // `start_indices[i] = clamp(start_indices[i],
1272 // 0, operand.dimension_size[i] - size_indices[i])`
1273 Value startIndex =
1274 rewriter.create<tensor::ExtractOp>(loc, std::get<0>(en.value()));
1275 startIndex = rewriter.createOrFold<arith::IndexCastOp>(
1276 loc, rewriter.getIndexType(), startIndex);
1277
1278 Value mn = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1279
1280 Value mx = rewriter.createOrFold<tensor::DimOp>(loc, adaptor.operand(),
1281 en.index());
1282 mx = rewriter.createOrFold<arith::SubIOp>(
1283 loc, mx, rewriter.create<arith::ConstantIndexOp>(loc, size));
1284
1285 startIndex = rewriter.create<arith::MaxSIOp>(loc, startIndex, mn);
1286 startIndex = rewriter.create<arith::MinSIOp>(loc, startIndex, mx);
1287
1288 startIndices.push_back(startIndex);
1289 }
1290
1291 int64_t rank = argType.getRank();
1292 SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));
1293
1294 auto resultType = this->typeConverter->convertType(dynamicSliceOp.getType())
1295 .cast<RankedTensorType>();
1296
1297 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
1298 dynamicSliceOp, resultType, adaptor.operand(), startIndices, sizes,
1299 strides);
1300 return success();
1301 }
1302 };
1303
1304 class DynamicUpdateSliceConverter
1305 : public OpConversionPattern<mhlo::DynamicUpdateSliceOp> {
1306 public:
1307 using OpConversionPattern<mhlo::DynamicUpdateSliceOp>::OpConversionPattern;
1308
matchAndRewrite(mhlo::DynamicUpdateSliceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1309 LogicalResult matchAndRewrite(
1310 mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor,
1311 ConversionPatternRewriter& rewriter) const final {
1312 auto loc = op.getLoc();
1313 auto operandType = adaptor.operand().getType().dyn_cast<RankedTensorType>();
1314 if (!operandType || !operandType.hasStaticShape()) {
1315 return rewriter.notifyMatchFailure(
1316 op, "require static ranked type for operand");
1317 }
1318
1319 auto updateType = adaptor.update().getType().dyn_cast<RankedTensorType>();
1320 if (!updateType || !updateType.hasStaticShape()) {
1321 return rewriter.notifyMatchFailure(
1322 op, "require static ranked type for operand");
1323 }
1324
1325     // We do not have to clamp sizes because the semantics of `update`
1326     // guarantee that it is always in bounds. See
1327 // https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice
1328 SmallVector<OpFoldResult, 3> sizes;
1329 for (auto size : updateType.getShape()) {
1330 sizes.push_back(rewriter.getIndexAttr(size));
1331 }
1332
1333 SmallVector<OpFoldResult, 3> startIndices;
1334 Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1335 for (auto& en : llvm::enumerate(adaptor.start_indices())) {
1336 // By mhlo.DynamicUpdateSlice definition:
1337 // `start_indices[i] = clamp(start_indices[i],
1338 // 0, operand.dimension_size[i] - update.dimension_size[i])`
1339 Value startIndex = rewriter.create<tensor::ExtractOp>(loc, en.value());
1340 if (!startIndex.getType().isIndex())
1341 startIndex = rewriter.create<arith::IndexCastOp>(
1342 loc, rewriter.getIndexType(), startIndex);
1343 Value ub = rewriter.create<arith::ConstantIndexOp>(
1344 loc, operandType.getDimSize(en.index()) -
1345 updateType.getDimSize(en.index()));
1346
1347 startIndex = rewriter.create<arith::MaxSIOp>(loc, startIndex, zero);
1348 startIndex = rewriter.create<arith::MinSIOp>(loc, startIndex, ub);
1349 startIndices.push_back(startIndex);
1350 }
1351
1352 int64_t rank = operandType.getRank();
1353 SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));
1354 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
1355 op, adaptor.update(), adaptor.operand(), startIndices, sizes, strides);
1356 return success();
1357 }
1358 };
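
// A hand-written sketch of the corresponding update lowering (the start
// index is clamped exactly as in the dynamic_slice sketch above, so that
// computation is elided here):
//
//   %r = "mhlo.dynamic_update_slice"(%operand, %update, %idx)
//       : (tensor<8xf32>, tensor<2xf32>, tensor<i32>) -> tensor<8xf32>
//
// becomes, with %i clamped into [0, 8 - 2],
//
//   %r = tensor.insert_slice %update into %operand[%i] [2] [1]
//       : tensor<2xf32> into tensor<8xf32>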
1359
1360 enum class DotOperationType {
1361 kVectorDot = 0,
1362 kMatrixVector,
1363 kVectorMatrix,
1364 kMatrixMatrix,
1365 kUnsupported
1366 };
1367
1368 DotOperationType getDotOperationType(mhlo::DotOp dotOp) {
1369 ArrayRef<int64_t> lhsShape =
1370 dotOp.lhs().getType().cast<ShapedType>().getShape();
1371 ArrayRef<int64_t> rhsShape =
1372 dotOp.rhs().getType().cast<ShapedType>().getShape();
1373 auto shapeMatches = [](int64_t a, int64_t b) {
1374 return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize ||
1375 a == b;
1376 };
1377 if (lhsShape.size() == 1 && rhsShape.size() == 1 &&
1378 shapeMatches(lhsShape[0], rhsShape[0])) {
1379 return DotOperationType::kVectorDot;
1380 }
1381 if (lhsShape.size() == 2 && rhsShape.size() == 1 &&
1382 shapeMatches(lhsShape[1], rhsShape[0])) {
1383 return DotOperationType::kMatrixVector;
1384 }
1385 if (lhsShape.size() == 1 && rhsShape.size() == 2 &&
1386 shapeMatches(lhsShape[0], rhsShape[0])) {
1387 return DotOperationType::kVectorMatrix;
1388 }
1389 if (lhsShape.size() == 2 && rhsShape.size() == 2 &&
1390 shapeMatches(lhsShape[1], rhsShape[0])) {
1391 return DotOperationType::kMatrixMatrix;
1392 }
1393 return DotOperationType::kUnsupported;
1394 }
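
// For example, lhs : tensor<4x8xf32> with rhs : tensor<8xf32> is classified
// as kMatrixVector, and lhs : tensor<8xf32> with rhs : tensor<8x4xf32> as
// kVectorMatrix; a dynamic size on either side of the contraction dimension
// is accepted optimistically by shapeMatches.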
1395
1396 SmallVector<Value, 2> getDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
1397 Value lhs, Value rhs,
1398 DotOperationType type) {
1399 SmallVector<Value, 2> dynShape;
1400 switch (type) {
1401 case DotOperationType::kMatrixMatrix: {
1402 if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
1403 dynShape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
1404 if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
1405 dynShape.push_back(b.create<tensor::DimOp>(loc, rhs, 1));
1406 break;
1407 }
1408 case DotOperationType::kMatrixVector: {
1409 if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
1410 dynShape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
1411 break;
1412 }
1413 case DotOperationType::kVectorMatrix: {
1414 if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
1415 dynShape.push_back(b.create<tensor::DimOp>(loc, rhs, 1));
1416 break;
1417 }
1418 case DotOperationType::kVectorDot:
1419 case DotOperationType::kUnsupported:
1420 break;
1421 }
1422 return dynShape;
1423 }
1424
1425 template <DotOperationType op_type, typename LinalgOp>
1426 class DotOpConversion : public OpConversionPattern<mhlo::DotOp> {
1427 public:
1428 using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
1429   LogicalResult matchAndRewrite(
1430 mhlo::DotOp op, mhlo::DotOp::Adaptor adaptor,
1431 ConversionPatternRewriter& rewriter) const final {
1432 if (!verifyHloOpBufferOrTensorSemantics(op)) {
1433 return failure();
1434 }
1435 if (getDotOperationType(op) != op_type) return failure();
1436
1437 Location loc = op.getLoc();
1438 // Convert unsigned to signed. This works because signed and unsigned
1439 // integer matmul is the same operation in two's complement.
1440 auto outputType =
1441 typeConverter->convertType(op.getType()).cast<ShapedType>();
1442 SmallVector<Value, 2> dynShape = getDotOpInitTensorDynSizes(
1443 rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type);
1444 auto initTensor = getInitTensor(rewriter, loc, outputType, dynShape);
1445 Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
1446 rewriter.replaceOpWithNewOp<LinalgOp>(
1447 op, TypeRange{outputType}, ValueRange{adaptor.lhs(), adaptor.rhs()},
1448 ValueRange{zeroTensor}, pruneAttributeList(op));
1449 return success();
1450 }
1451 };
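
// Hand-written sketch of this pattern instantiated with op_type =
// kMatrixMatrix and LinalgOp = linalg::MatmulOp (%init, %cst and %fill are
// illustrative names):
//
//   %r = "mhlo.dot"(%lhs, %rhs)
//       : (tensor<4x8xf32>, tensor<8x16xf32>) -> tensor<4x16xf32>
//
// becomes roughly
//
//   %init = linalg.init_tensor [4, 16] : tensor<4x16xf32>
//   %cst  = arith.constant 0.0 : f32
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4x16xf32>)
//           -> tensor<4x16xf32>
//   %r    = linalg.matmul ins(%lhs, %rhs : tensor<4x8xf32>, tensor<8x16xf32>)
//                         outs(%fill : tensor<4x16xf32>) -> tensor<4x16xf32>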
1452
1453 class DotGeneralBatchMatMulOpConversion
1454 : public OpConversionPattern<mhlo::DotGeneralOp> {
1455 public:
1456 using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
1457   LogicalResult matchAndRewrite(
1458 mhlo::DotGeneralOp op, OpAdaptor adaptor,
1459 ConversionPatternRewriter& rewriter) const final {
1460 if (!verifyHloOpBufferOrTensorSemantics(op)) {
1461 return failure();
1462 }
1463 if (op.getType().cast<RankedTensorType>().getRank() != 3) {
1464 return rewriter.notifyMatchFailure(op, "expected a batch matmul");
1465 }
1466
1467 mhlo::DotDimensionNumbersAttr dimNumbers = op.dot_dimension_numbers();
1468 auto lhsBatchingDims = dimNumbers.getLhsBatchingDimensions();
1469 auto rhsBatchingDims = dimNumbers.getRhsBatchingDimensions();
1470 auto lhsContractingDims = dimNumbers.getLhsContractingDimensions();
1471 auto rhsContractingDims = dimNumbers.getRhsContractingDimensions();
1472 if (lhsBatchingDims.size() != 1 || lhsBatchingDims[0] != 0) {
1473 return rewriter.notifyMatchFailure(
1474 op, "expected lhs batching dimensions exactly {0}");
1475 }
1476 if (rhsBatchingDims.size() != 1 || rhsBatchingDims[0] != 0) {
1477 return rewriter.notifyMatchFailure(
1478 op, "expected rhs batching dimensions exactly {0}");
1479 }
1480 if (lhsContractingDims.size() != 1 || lhsContractingDims[0] != 2) {
1481 return rewriter.notifyMatchFailure(
1482 op, "expected lhs contracting dimensions exactly {2}");
1483 }
1484 if (rhsContractingDims.size() != 1 || rhsContractingDims[0] != 1) {
1485 return rewriter.notifyMatchFailure(
1486 op, "expected rhs contracting dimensions exactly {1}");
1487 }
1488
1489 Location loc = op.getLoc();
1490 // Convert unsigned to signed. This works because signed and unsigned
1491 // integer matmul is the same operation in two's complement.
1492 auto outputType =
1493 typeConverter->convertType(op.getType()).cast<ShapedType>();
1494 auto initTensor =
1495 getInitTensorFor(rewriter, loc, outputType, op, adaptor.getOperands());
1496 Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
1497 Operation* linalgOp = rewriter.create<linalg::BatchMatmulOp>(
1498 loc, /*resultTensorTypes=*/TypeRange{outputType},
1499 /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
1500 /*outputBuffers=*/ValueRange{zeroTensor}, pruneAttributeList(op));
1501
1502 rewriter.replaceOp(op, linalgOp->getResults());
1503 return success();
1504 }
1505 };
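
// Hand-written sketch (attributes elided): a dot_general with batching
// dimensions {0}/{0} and contracting dimensions {2}/{1}, e.g.
//
//   %r = "mhlo.dot_general"(%lhs, %rhs) { ... }
//       : (tensor<2x4x8xf32>, tensor<2x8x16xf32>) -> tensor<2x4x16xf32>
//
// lowers, roughly, to a zero-filled init tensor consumed by
//
//   %r = linalg.batch_matmul
//       ins(%lhs, %rhs : tensor<2x4x8xf32>, tensor<2x8x16xf32>)
//       outs(%fill : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>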
1506
1507 class MapOpConverter : public OpConversionPattern<mhlo::MapOp> {
1508 public:
1509 using OpConversionPattern::OpConversionPattern;
1510   LogicalResult matchAndRewrite(
1511 mhlo::MapOp op, OpAdaptor adaptor,
1512 ConversionPatternRewriter& rewriter) const final {
1513 if (!verifyHloOpBufferOrTensorSemantics(op)) return failure();
1514
1515 auto resultType =
1516 typeConverter->convertType(op.getType()).cast<ShapedType>();
1517 assert(op.dimensions().size() == resultType.getRank() &&
1518 "Expected a pointwise map");
1519
1520 Location loc = op.getLoc();
1521 Value output =
1522 getInitTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
1523 SmallVector<AffineMap> indexingMaps(
1524 op.getNumOperands() + 1,
1525 rewriter.getMultiDimIdentityMap(resultType.getRank()));
1526
1527 auto linalgOp = rewriter.create<linalg::GenericOp>(
1528 loc, resultType, adaptor.getOperands(), output, indexingMaps,
1529 getNParallelLoopsAttrs(resultType.getRank()),
1530 /*bodyBuild=*/nullptr, pruneAttributeList(op));
1531
1532 // Convert the signature of the body. We scalarize the operands and add a
1533 // scalar operand representing the output tensor.
1534 Region& region = linalgOp.region();
1535 rewriter.inlineRegionBefore(op.computation(), region, region.end());
1536 TypeConverter::SignatureConversion signatureConverter(op.getNumOperands() +
1537 1);
1538
1539 for (const auto& it : llvm::enumerate(op.getOperation()->getOperands())) {
1540 signatureConverter.addInputs(
1541 it.index(),
1542 typeConverter->convertType(
1543 it.value().getType().cast<ShapedType>().getElementType()));
1544 }
1545 signatureConverter.addInputs(resultType.getElementType());
1546
1547     rewriter.applySignatureConversion(&region, signatureConverter,
1548 getTypeConverter());
1549 rewriter.replaceOp(op, linalgOp.getResults());
1550 return success();
1551 }
1552 };
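
// Hand-written sketch: an element-wise map such as
//
//   %r = "mhlo.map"(%a, %b) ({
//     ^bb0(%x: tensor<f32>, %y: tensor<f32>):
//       %s = mhlo.add %x, %y : tensor<f32>
//       "mhlo.return"(%s) : (tensor<f32>) -> ()
//   }) {dimensions = dense<0> : tensor<1xi64>}
//       : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32>
//
// becomes a linalg.generic over identity indexing maps whose region operates
// on scalars; assuming the body add and the mhlo.return terminator are
// legalized by the element-wise/region patterns, the result is roughly
// (#id = affine_map<(d0) -> (d0)>):
//
//   %r = linalg.generic {indexing_maps = [#id, #id, #id],
//                        iterator_types = ["parallel"]}
//       ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
//       outs(%init : tensor<8xf32>) {
//     ^bb0(%x: f32, %y: f32, %out: f32):
//       %s = arith.addf %x, %y : f32
//       linalg.yield %s : f32
//   } -> tensor<8xf32>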
1553
1554 bool isInBodyOfLinalgOps(Operation* op) {
1555 auto* parentOp = op->getParentRegion()->getParentOp();
1556 return parentOp->getDialect() ==
1557 parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>();
1558 }
1559
1560 SmallVector<Value, 8> getReduceOpInitTensorDynSizes(
1561 OpBuilder& b, Location loc, Value arg, ShapedType resultType,
1562 ArrayRef<int64_t> reductionDims) {
1563 llvm::SmallSetVector<int, 4> s;
1564 for (auto dim : reductionDims) s.insert(dim);
1565
1566 SmallVector<unsigned, 4> parallelDims;
1567 SmallVector<Value, 8> dynShape;
1568 int rank = arg.getType().cast<RankedTensorType>().getRank();
1569 for (int i = 0, j = 0; i < rank; ++i) {
1570 if (s.count(i)) continue;
1571 if (!resultType.isDynamicDim(j++)) continue;
1572 dynShape.push_back(b.create<tensor::DimOp>(loc, arg, i));
1573 }
1574
1575 return dynShape;
1576 }
1577
1578 class ReduceRegionReturnOpConversion
1579 : public OpConversionPattern<mhlo::ReturnOp> {
1580 public:
1581 using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern;
1582   LogicalResult matchAndRewrite(
1583 mhlo::ReturnOp op, OpAdaptor adaptor,
1584 ConversionPatternRewriter& rewriter) const final {
1585 if (!isInBodyOfLinalgOps(op)) {
1586 return failure();
1587 }
1588 SmallVector<Value, 4> operands(adaptor.getOperands());
1589 for (size_t i = 0; i < operands.size(); ++i) {
1590 if (operands[i].getType().isa<ShapedType>()) {
1591 auto loc = operands[i].getLoc();
1592 operands[i] = rewriter.create<tensor::ExtractOp>(loc, operands[i]);
1593 }
1594 }
1595 rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, operands);
1596 return success();
1597 }
1598 };
1599
1600 class ReduceConversion : public OpConversionPattern<mhlo::ReduceOp> {
1601 public:
1602 using OpConversionPattern<mhlo::ReduceOp>::OpConversionPattern;
1603   LogicalResult matchAndRewrite(
1604 mhlo::ReduceOp op, OpAdaptor adaptor,
1605 ConversionPatternRewriter& rewriter) const final {
1606 Location loc = op.getLoc();
1607
1608 int numOperands = static_cast<int>(adaptor.operands().size());
1609
1610 if (llvm::any_of(adaptor.operands(), [](Value v) {
1611 return !v.getType().cast<ShapedType>().getRank();
1612 })) {
1613 return rewriter.notifyMatchFailure(op, "expects known-rank args");
1614 }
1615 auto srcRank = adaptor.operands()[0].getType().cast<ShapedType>().getRank();
1616
1617 SmallVector<int64_t, 4> reductionDims = extract1DVector(op.dimensions());
1618
1619 SmallVector<Type> resultTypes;
1620 if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
1621 return failure();
1622
1623 SmallVector<Value> operands, outputs;
1624 SmallVector<AffineMap, 3> indexingMaps;
1625 for (auto values :
1626 llvm::zip(adaptor.operands(), adaptor.init_values(), resultTypes)) {
1627 // Check if init_value is constant. If so, inline the value into the
1628 // region.
1629 Value operand = std::get<0>(values);
1630 Value initValue = std::get<1>(values);
1631 Type resultType = std::get<2>(values);
1632 initValue = rewriter.createOrFold<tensor::ExtractOp>(loc, initValue);
1633
1634 operands.push_back(operand);
1635 SmallVector<Value, 8> dynShape = getReduceOpInitTensorDynSizes(
1636 rewriter, loc, operand, resultType, reductionDims);
1637 auto initTensor = getInitTensor(rewriter, loc, resultType, dynShape);
1638 Value filledTensor =
1639 rewriter.create<linalg::FillOp>(loc, initValue, initTensor).result();
1640 outputs.push_back(filledTensor);
1641 }
1642
1643     // Prepare indexing maps for linalg generic op. The elements are for src
1644     // and dst. Transpose `src` so that the reduction loops become the
1645     // innermost loops, which makes it easier to fully utilize processors.
1646 indexingMaps.append(
1647 numOperands, getTransposeMapForReduction(rewriter.getContext(),
1648 (int)srcRank, reductionDims));
1649
1650     // The indexing map of `dst` should drop the reduction loops. Since the
1651     // reduction loops are now innermost, drop the last
1652     // `reduction_dims.size()` dimensions. We don't need an inverse
1653     // permutation here because they are the same.
1654 SmallVector<AffineExpr, 4> exprs;
1655 for (int i = 0, e = srcRank - reductionDims.size(); i < e; ++i)
1656 exprs.push_back(rewriter.getAffineDimExpr(i));
1657 indexingMaps.append(numOperands,
1658 AffineMap::get(srcRank, /*symbolCount=*/0, exprs,
1659 rewriter.getContext()));
1660
1661 auto linalgOp = rewriter.create<linalg::GenericOp>(
1662 loc, /*resultTensorTypes=*/resultTypes, operands,
1663 /*outputBuffers=*/ValueRange{outputs}, indexingMaps,
1664 getParallelAndReductionIterators(srcRank, reductionDims.size()),
1665 /*bodyBuild=*/nullptr, pruneAttributeList(op));
1666
1667 // Convert the signature of the body. The reduce op region apply function
1668 // has a signature (lhs, rhs) -> output, all of the same tensor type t.
1669 // This is converted to a function with the same signature but with
1670 // element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
1671     // be converted to "(f32, f32) -> f32".
1672 Region& region = linalgOp.region();
1673 rewriter.inlineRegionBefore(op.body(), region, region.end());
1674 TypeConverter::SignatureConversion signatureConverter(numOperands * 2);
1675
1676     // Reduce requires that the seed be used as an LHS operand inside the
1677     // region, and the seed is encoded in linalg in the initial out value, so
1678     // modify the signature of the block and the value mappings so that the
1679     // output args correlate with the LHS and the inputs correlate with the RHS.
1680 for (const auto& [idx, val] : llvm::enumerate(op.init_values())) {
1681 signatureConverter.addInputs(
1682 idx + numOperands,
1683 typeConverter->convertType(
1684 val.getType().cast<ShapedType>().getElementType()));
1685 }
1686 for (const auto& [idx, val] : llvm::enumerate(op.operands())) {
1687 signatureConverter.addInputs(
1688 idx, typeConverter->convertType(
1689 val.getType().cast<ShapedType>().getElementType()));
1690 }
1691
1692     rewriter.applySignatureConversion(&region, signatureConverter,
1693 getTypeConverter());
1694 rewriter.replaceOp(op, linalgOp.getResults());
1695 return success();
1696 }
1697 };
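
// Hand-written sketch: summing away dimension 0 of a tensor<4x8xf32> (the
// mhlo reduction body is elided, and is assumed to be an add that the scalar
// patterns legalize to arith.addf),
//
//   %r = "mhlo.reduce"(%arg, %zero) ({ ... })
//       {dimensions = dense<0> : tensor<1xi64>}
//       : (tensor<4x8xf32>, tensor<f32>) -> tensor<8xf32>
//
// becomes a linalg.generic whose source map moves the reduction loop to the
// innermost position:
//
//   #src = affine_map<(d0, d1) -> (d1, d0)>
//   #dst = affine_map<(d0, d1) -> (d0)>
//   %r = linalg.generic {indexing_maps = [#src, #dst],
//                        iterator_types = ["parallel", "reduction"]}
//       ins(%arg : tensor<4x8xf32>) outs(%fill : tensor<8xf32>) {
//     ^bb0(%in: f32, %acc: f32):
//       %s = arith.addf %acc, %in : f32
//       linalg.yield %s : f32
//   } -> tensor<8xf32>
//
// where %fill is the result-shaped tensor filled with the extracted init
// value.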
1698
1699 // Decomposes a pad with negative edge padding into a pad without negative edge
1700 // padding and a tensor.extract_slice.
1701 struct PadOpNegativePaddingConversion
1702 : public OpConversionPattern<mhlo::PadOp> {
1703 using OpConversionPattern::OpConversionPattern;
1704
1705   LogicalResult matchAndRewrite(
1706 mhlo::PadOp op, OpAdaptor adaptor,
1707 ConversionPatternRewriter& rewriter) const override {
1708 SmallVector<int64_t, 4> padLow;
1709 SmallVector<int64_t, 4> padHigh;
1710 SmallVector<OpFoldResult, 4> sliceStarts;
1711
1712 bool hasNegativePadding = false;
1713 for (int64_t low : op.edge_padding_low().getValues<int64_t>()) {
1714 if (low >= 0) {
1715 padLow.push_back(low);
1716 sliceStarts.push_back(rewriter.getIndexAttr(0));
1717 } else {
1718 padLow.push_back(0);
1719 sliceStarts.push_back(rewriter.getIndexAttr(-low));
1720 hasNegativePadding = true;
1721 }
1722 }
1723
1724 for (int64_t high : op.edge_padding_high().getValues<int64_t>()) {
1725 if (high >= 0) {
1726 padHigh.push_back(high);
1727 } else {
1728 padHigh.push_back(-high);
1729 hasNegativePadding = true;
1730 }
1731 }
1732
1733 // If there's no negative edge padding we're done.
1734 if (!hasNegativePadding) return failure();
1735
1736 // Create a new pad op with the positive values.
1737 Value pad = rewriter.create<mhlo::PadOp>(
1738 op.getLoc(), adaptor.operand(), adaptor.padding_value(),
1739 rewriter.getI64TensorAttr(padLow), rewriter.getI64TensorAttr(padHigh),
1740 op.interior_padding());
1741
1742 // Then slice according to the negative edge padding. Static shapes only for
1743 // now.
1744 if (!op.getType().hasStaticShape()) return failure();
1745 SmallVector<OpFoldResult, 4> sizes(llvm::map_range(
1746 op.getType().getShape(),
1747 [&](int64_t dim) { return rewriter.getIndexAttr(dim); }));
1748 SmallVector<OpFoldResult, 4> strides(sliceStarts.size(),
1749 rewriter.getIndexAttr(1));
1750 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(op, pad, sliceStarts,
1751 sizes, strides);
1752 return success();
1753 }
1754 };
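
// Hand-written sketch: a pad with a negative low value on dimension 0,
//
//   %r = "mhlo.pad"(%arg, %pv)
//       {edge_padding_low = dense<[-1, 0]> : tensor<2xi64>,
//        edge_padding_high = dense<[2, 0]> : tensor<2xi64>,
//        interior_padding = dense<0> : tensor<2xi64>}
//       : (tensor<4x4xf32>, tensor<f32>) -> tensor<5x4xf32>
//
// is decomposed into a non-negative pad followed by a slice that drops the
// leading row again:
//
//   %p = "mhlo.pad"(%arg, %pv)
//       {edge_padding_low = dense<0> : tensor<2xi64>,
//        edge_padding_high = dense<[2, 0]> : tensor<2xi64>,
//        interior_padding = dense<0> : tensor<2xi64>}
//       : (tensor<4x4xf32>, tensor<f32>) -> tensor<6x4xf32>
//   %r = tensor.extract_slice %p[1, 0] [5, 4] [1, 1]
//       : tensor<6x4xf32> to tensor<5x4xf32>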
1755
1756 /// Converts mhlo.pad operation to tensor.pad or tensor.insert_slice.
1757 struct PadOpConversion : public OpConversionPattern<mhlo::PadOp> {
1758 using OpConversionPattern<mhlo::PadOp>::OpConversionPattern;
1759
1760   LogicalResult matchAndRewrite(
1761 mhlo::PadOp op, OpAdaptor adaptor,
1762 ConversionPatternRewriter& rewriter) const override {
1763 auto loc = op.getLoc();
1764 auto resultType = typeConverter->convertType(op.getResult().getType());
1765
1766 // Negative edge padding is decomposed separately.
1767 auto isNegative = [](const APInt& intVal) { return intVal.isNegative(); };
1768 if (llvm::any_of(op.edge_padding_low().getValues<APInt>(), isNegative) ||
1769 llvm::any_of(op.edge_padding_high().getValues<APInt>(), isNegative))
1770 return failure();
1771
1772 Value paddingVal =
1773 rewriter.createOrFold<tensor::ExtractOp>(loc, adaptor.padding_value());
1774
1775 SmallVector<OpFoldResult, 4> low(
1776 op.edge_padding_low().getValues<IntegerAttr>());
1777
1778     // If there is no interior padding, lower to tensor.pad directly.
1779 if (llvm::all_of(op.interior_padding().getValues<APInt>(),
1780 [](const APInt& intVal) { return intVal.isZero(); })) {
1781 SmallVector<OpFoldResult, 4> high(
1782 op.edge_padding_high().getValues<IntegerAttr>());
1783 auto padTensorOp = tensor::createPadScalarOp(
1784 resultType, adaptor.operand(), paddingVal, low, high,
1785 /*nofold=*/false, loc, rewriter);
1786 rewriter.replaceOp(op, padTensorOp.getResult());
1787 return success();
1788 }
1789
1790 // We have interior padding, which can be lowered to tensor.insert_slice.
1791 // Start by filling a result-sized tensor with the pad value.
1792 auto initTensor =
1793 getInitTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
1794 auto fill =
1795 rewriter.create<linalg::FillOp>(loc, paddingVal, initTensor).result();
1796
1797 // Get sizes of the original operand.
1798 auto operandType = adaptor.operand().getType().cast<ShapedType>();
1799 auto sizes = llvm::to_vector<4>(llvm::map_range(
1800 llvm::seq<int64_t>(0, operandType.getRank()),
1801 [&](int64_t dim) -> OpFoldResult {
1802 if (!operandType.isDynamicDim(dim))
1803 return rewriter.getIndexAttr(operandType.getDimSize(dim));
1804 return rewriter.create<tensor::DimOp>(loc, adaptor.operand(), dim)
1805 .getResult();
1806 }));
1807 // Map interior padding to strides.
1808 auto strides = llvm::to_vector<4>(
1809 llvm::map_range(op.interior_padding().getValues<IntegerAttr>(),
1810 [&](IntegerAttr stride) -> OpFoldResult {
1811 return rewriter.getIntegerAttr(stride.getType(),
1812 stride.getValue() + 1);
1813 }));
1814
1815 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
1816 op, adaptor.operand(), fill, low, sizes, strides);
1817 return success();
1818 }
1819 };
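
// Hand-written sketch of the interior-padding branch (%pv_scalar, %init and
// %fill are illustrative names for the extracted pad value, the init tensor
// and the filled tensor): for
//
//   %r = "mhlo.pad"(%arg, %pv)
//       {edge_padding_low = dense<1> : tensor<1xi64>,
//        edge_padding_high = dense<1> : tensor<1xi64>,
//        interior_padding = dense<1> : tensor<1xi64>}
//       : (tensor<3xf32>, tensor<f32>) -> tensor<7xf32>
//
// the result-sized tensor is filled with the padding value and the operand is
// inserted with stride interior_padding + 1 = 2:
//
//   %fill = linalg.fill ins(%pv_scalar : f32) outs(%init : tensor<7xf32>)
//           -> tensor<7xf32>
//   %r    = tensor.insert_slice %arg into %fill[1] [3] [2]
//         : tensor<3xf32> into tensor<7xf32>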
1820
1821 // Apply dilation and padding to the input of a convolution.
1822 Value applyConvolutionPadding(Location loc, Value input,
1823 DenseIntElementsAttr padding,
1824 DenseIntElementsAttr lhsDilation,
1825 llvm::ArrayRef<int64_t> dimMappings,
1826 OpBuilder& rewriter) {
1827 if ((!padding || isSplatValue(padding, 0)) &&
1828 (!lhsDilation || isSplatValue(lhsDilation, 1)))
1829 return input;
1830
1831 auto inputType = input.getType().cast<ShapedType>();
1832 auto rank = inputType.getRank();
1833
1834 // Translate window padding into low/high padding.
1835 SmallVector<int64_t, 8> padLow(rank, 0);
1836 SmallVector<int64_t, 8> padHigh(rank, 0);
1837 if (padding) {
1838 // The padding attribute contains two values per dimension, but excludes the
1839 // batch and feature dimensions.
1840 assert(rank * 2 == padding.size() + 4 &&
1841            "There should be 2 padding values per dimension, i.e. low and high.");
1842 for (auto i : llvm::seq<int64_t>(0, padding.size() / 2)) {
1843 auto dim = dimMappings[i];
1844 padLow[dim] = padding.getValues<int64_t>()[i * 2];
1845 padHigh[dim] = padding.getValues<int64_t>()[i * 2 + 1];
1846 }
1847 }
1848
1849 // Translate input dilation into interior padding.
1850 SmallVector<int64_t, 8> padInterior(rank, 0);
1851 if (lhsDilation) {
1852 assert(rank == lhsDilation.size() + 2);
1853 for (auto i : llvm::seq<int64_t>(0, lhsDilation.size())) {
1854 auto dim = dimMappings[i];
1855 padInterior[dim] = lhsDilation.getValues<int64_t>()[i] - 1;
1856 }
1857 }
1858
1859 auto indexType = rewriter.getIntegerType(64);
1860 auto attrType = RankedTensorType::get({rank}, indexType);
1861 Value zero = rewriter.create<arith::ConstantOp>(
1862 loc, rewriter.getZeroAttr(
1863 RankedTensorType::get({}, inputType.getElementType())));
1864 return rewriter.create<mhlo::PadOp>(
1865 loc, input, zero, DenseIntElementsAttr::get(attrType, padLow),
1866 DenseIntElementsAttr::get(attrType, padHigh),
1867 DenseIntElementsAttr::get(attrType, padInterior));
1868 }
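
// For instance (hand-written), with a single spatial dimension mapped to dim 1
// of an N x W x C input of type tensor<1x4x1xf32>, padding = [[1, 2]] and
// lhs_dilation = [2] fold into one mhlo.pad:
//
//   %p = "mhlo.pad"(%input, %zero)
//       {edge_padding_low = dense<[0, 1, 0]> : tensor<3xi64>,
//        edge_padding_high = dense<[0, 2, 0]> : tensor<3xi64>,
//        interior_padding = dense<[0, 1, 0]> : tensor<3xi64>}
//       : (tensor<1x4x1xf32>, tensor<f32>) -> tensor<1x10x1xf32>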
1869
1870 /// Converts mhlo.conv operation to linalg named op. This only covers normal
1871 /// convolution cases. The op must have canonical dimension numbers. Depthwise
1872 /// convolution and pointwise convolution are not handled in the conversion.
1873 struct NormalConvolutionOpConversion
1874 : public OpConversionPattern<mhlo::ConvolutionOp> {
1875 using OpConversionPattern<mhlo::ConvolutionOp>::OpConversionPattern;
1876
1877   LogicalResult matchAndRewrite(
1878 mhlo::ConvolutionOp op, OpAdaptor adaptor,
1879 ConversionPatternRewriter& rewriter) const override {
1880 if (!hasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
1881 if (op.feature_group_count() != 1u) return failure();
1882 if (op.batch_group_count() != 1u) return failure();
1883
1884 Location loc = op.getLoc();
1885 Value input = adaptor.lhs();
1886 Value filter = adaptor.rhs();
1887 auto resultType =
1888 typeConverter->convertType(op.getResult().getType()).cast<ShapedType>();
1889 int64_t rank = resultType.getRank();
1890
1891 // The output shape is N spatial_dims F.
1892 SmallVector<Value, 8> dynSizes;
1893 if (resultType.isDynamicDim(0)) {
1894 dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
1895 }
1896 for (int64_t i = 1, e = rank - 1; i < e; ++i) {
1897 if (resultType.isDynamicDim(i)) {
1898 return rewriter.notifyMatchFailure(
1899 op, "expected output spatial dims to be static shapes");
1900 }
1901 }
1902 if (resultType.isDynamicDim(rank - 1)) {
1903 dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, filter, rank - 1));
1904 }
1905 Value initTensor = rewriter.create<linalg::InitTensorOp>(
1906 loc, dynSizes, resultType.getShape(), resultType.getElementType());
1907 Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
1908 linalg::LinalgOp res;
1909 Attribute strides = op.window_stridesAttr();
1910 Attribute dilations = op.rhs_dilationAttr();
1911
1912 // Apply padding and input dilation.
1913 llvm::SmallVector<int64_t> spatialDimMapping(rank - 2);
1914 std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1);
1915 input = applyConvolutionPadding(loc, input, op.paddingAttr(),
1916 op.lhs_dilationAttr(), spatialDimMapping,
1917 rewriter);
1918
1919 switch (rank) {
1920 case 2: {
1921 res = rewriter.create<linalg::MatmulOp>(
1922 loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
1923 pruneAttributeList(op));
1924 break;
1925 }
1926 case 3: {
1927 res = rewriter.create<linalg::Conv1DNwcWcfOp>(
1928 loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
1929 strides, dilations, pruneAttributeList(op));
1930 break;
1931 }
1932 case 4: {
1933 res = rewriter.create<linalg::Conv2DNhwcHwcfOp>(
1934 loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
1935 strides, dilations, pruneAttributeList(op));
1936 break;
1937 }
1938 case 5: {
1939 res = rewriter.create<linalg::Conv3DNdhwcDhwcfOp>(
1940 loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
1941 strides, dilations, pruneAttributeList(op));
1942 break;
1943 }
1944 default:
1945 return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
1946 }
1947 rewriter.replaceOp(op, res.getOperation()->getResults());
1948 return success();
1949 }
1950 };
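
// Hand-written sketch of the rank-4 (NHWC) case with unit strides/dilations
// and no padding (%init, %cst and %fill are illustrative names):
//
//   %r = mhlo.convolution(%input, %filter) ...
//       : (tensor<1x8x8x3xf32>, tensor<3x3x3x16xf32>) -> tensor<1x6x6x16xf32>
//
// becomes roughly
//
//   %init = linalg.init_tensor [1, 6, 6, 16] : tensor<1x6x6x16xf32>
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x6x6x16xf32>)
//           -> tensor<1x6x6x16xf32>
//   %r    = linalg.conv_2d_nhwc_hwcf
//       {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
//       ins(%input, %filter : tensor<1x8x8x3xf32>, tensor<3x3x3x16xf32>)
//       outs(%fill : tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32>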
1951
1952 /// Handles all possible inputs for the mhlo::ConvolutionOp
1953 struct ConvolutionOpGeneralConversion
1954 : public OpConversionPattern<mhlo::ConvolutionOp> {
1955 using OpConversionPattern<mhlo::ConvolutionOp>::OpConversionPattern;
1956
1957 /// This lowering proceeds with the following steps:
1958 /// 1. Handle padding and dilation of the input
1959 /// 2. Handle padding and dilation of the window
1960 /// 3. Handle reversal of the window
1961 /// 4. If feature_group_count != 1:
1962 /// - Reshape the input feature dimension, kernel output feature dimension,
1963 /// and output feature dimension.
1964 /// - Create the AffineExpr for the new dimension
1965 /// - Conceptually, this splits the input feature and both output feature
1966 /// dimensions and computes sets of convolutions with these partial views
1967 /// of the values as if they were multiple convolutions combined in a
1968 /// batch.
1969 /// 5. If batch_group_count != 1:
1970 /// - Reshape the input batch dimension, kernel output feature dimension,
1971 /// and output feature dimension.
1972 /// - Create the AffineExpr for the new dimension
1973 /// - Conceptually, this splits the input batch and both output feature
1974 /// dimensions and computes sets of convolutions with these partial views
1975 /// of the values as if they were multiple convolutions combined in a
1976 /// batch.
1977 /// 6. For all dimensions not newly created by a reshape, create the
1978 /// appropriate parallel and reduction dimensions to create a convolution.
1979 /// 7. Create the linalg.generic that computes the multiply-add
1980 /// 8. Reshape the output to the original shape if it was reshaped by the
1981 /// feature or group count attributes.
1982   LogicalResult matchAndRewrite(
1983 mhlo::ConvolutionOp op, OpAdaptor adaptor,
1984 ConversionPatternRewriter& rewriter) const override {
1985 auto loc = op.getLoc();
1986 auto* ctx = op.getContext();
1987
1988 auto resultType =
1989 typeConverter->convertType(op.getResult().getType()).cast<ShapedType>();
1990 auto reshapedResultShape = resultType.getShape().vec();
1991 if (!resultType.hasStaticShape()) return failure();
1992
1993 auto dimensionNumbers = op.dimension_numbers();
1994 auto inputBatchDimension = dimensionNumbers.getInputBatchDimension();
1995 auto inputFeatureDimension = dimensionNumbers.getInputFeatureDimension();
1996 auto inputSpatialDimensions = dimensionNumbers.getInputSpatialDimensions();
1997
1998 auto kernelInputFeatureDimension =
1999 dimensionNumbers.getKernelInputFeatureDimension();
2000 auto kernelOutputFeatureDimension =
2001 dimensionNumbers.getKernelOutputFeatureDimension();
2002 auto kernelSpatialDimensions =
2003 dimensionNumbers.getKernelSpatialDimensions();
2004
2005 auto outputFeatureDimension = dimensionNumbers.getOutputFeatureDimension();
2006 auto outputSpatialDimensions =
2007 dimensionNumbers.getOutputSpatialDimensions();
2008
2009 auto featureGroupCount = op.feature_group_count();
2010 auto batchGroupCount = op.batch_group_count();
2011
2012 if (op.feature_group_count() != 1 && op.batch_group_count() != 1) {
2013 return rewriter.notifyMatchFailure(
2014 op, "only one of feature and batch group counts can be non-one");
2015 }
2016
2017 // Decompose the convolution into an initial padding
2018 Value modifiedLhs = applyConvolutionPadding(
2019 op.getLoc(), adaptor.lhs(), adaptor.paddingAttr(),
2020 adaptor.lhs_dilationAttr(),
2021 op.dimension_numbers().getInputSpatialDimensions(), rewriter);
2022 Value modifiedRhs = applyConvolutionPadding(
2023 op.getLoc(), adaptor.rhs(), nullptr, adaptor.rhs_dilationAttr(),
2024 op.dimension_numbers().getKernelSpatialDimensions(), rewriter);
2025
2026 // Decompose the reversal dims into its own step
2027 auto reversals = op.window_reversal();
2028 if (reversals.value()) {
2029 llvm::SmallVector<int64_t> reversedDims;
2030 for (auto& idxAndBool :
2031 llvm::enumerate(reversals.value().getValues<bool>()))
2032 if (idxAndBool.value())
2033 reversedDims.push_back(
2034 op.dimension_numbers()
2035 .getKernelSpatialDimensions()[idxAndBool.index()]);
2036
2037 modifiedRhs = rewriter.create<mhlo::ReverseOp>(
2038 loc, modifiedRhs,
2039 mlir::DenseIntElementsAttr::get(
2040 RankedTensorType::get(reversedDims.size(),
2041 rewriter.getIntegerType(64)),
2042 reversedDims));
2043 }
2044
2045     // Non-one values for feature or batch group counts will result in reshaped
2046     // inputs and outputs. These mappings are used to keep track of the new
2047     // index after reshaping has possibly inserted new dimensions.
2048 auto paddedLhsType = modifiedLhs.getType().cast<ShapedType>();
2049 auto paddedRhsType = modifiedRhs.getType().cast<ShapedType>();
2050 SmallVector<int64_t> lhsIndexMapping(paddedLhsType.getRank());
2051 std::iota(lhsIndexMapping.begin(), lhsIndexMapping.end(), 0);
2052 SmallVector<int64_t> rhsIndexMapping(paddedRhsType.getRank());
2053 std::iota(rhsIndexMapping.begin(), rhsIndexMapping.end(), 0);
2054 SmallVector<int64_t> resultIndexMapping(resultType.getRank());
2055 std::iota(resultIndexMapping.begin(), resultIndexMapping.end(), 0);
2056 auto updateDimMappingFromOffset =
2057 [](llvm::SmallVectorImpl<int64_t>& mapping, int64_t offset) {
2058 for (auto i = offset; i < mapping.size(); ++i) {
2059 mapping[i] += 1;
2060 }
2061 };
2062
2063 // The rest of this code prepares the inputs and a single linalg::GenericOp
2064 // to execute the convolution. The final linalg::GenericOp will be iterated
2065 // through based on the following eventual maps.
2066 SmallVector<AffineExpr, 2> srcExprs(paddedLhsType.getRank());
2067 SmallVector<AffineExpr, 2> windowExprs(paddedRhsType.getRank());
2068 SmallVector<AffineExpr, 2> dstExprs(reshapedResultShape.size());
2069 int64_t nextDim = 0;
2070 int64_t rank = resultType.getRank();
2071
2072 auto reshapeShapeVector = [](llvm::ArrayRef<int64_t> oldShape,
2073 llvm::SmallVectorImpl<int64_t>& newShape,
2074 int64_t reshapedDim, int64_t factor) {
2075 newShape.reserve(oldShape.size() + 1);
2076 for (int i = 0; i < oldShape.size(); ++i) {
2077 if (i == reshapedDim) {
2078 newShape.push_back(factor);
2079 newShape.push_back(oldShape[reshapedDim] / factor);
2080 } else {
2081 newShape.push_back(oldShape[i]);
2082 }
2083 }
2084 };
2085
2086     // If batch or feature group counts are not one, represent this by
2087     // reshaping the inputs to carry an additional dimension along which the
2088     // groups lie, and iterate over that dimension in the generic op.
2089 SmallVector<StringRef, 3> iterationLoops;
2090 if (featureGroupCount != 1) {
2091 auto parallelDim = mlir::getAffineDimExpr(nextDim++, ctx);
2092 iterationLoops.push_back(getParallelIteratorTypeName());
2093 // Reshape LHS
2094 {
2095 srcExprs.insert(srcExprs.begin() + inputFeatureDimension, parallelDim);
2096 auto prevDimsRef = paddedLhsType.getShape();
2097 llvm::SmallVector<int64_t> newShape;
2098 reshapeShapeVector(prevDimsRef, newShape, inputFeatureDimension,
2099 featureGroupCount);
2100 updateDimMappingFromOffset(lhsIndexMapping, inputFeatureDimension);
2101 modifiedLhs = rewriter.create<mhlo::ReshapeOp>(
2102 op.getLoc(),
2103 RankedTensorType::get(newShape, paddedLhsType.getElementType()),
2104 modifiedLhs);
2105 }
2106
2107 // Reshape RHS
2108 {
2109 windowExprs.insert(windowExprs.begin() + kernelOutputFeatureDimension,
2110 parallelDim);
2111 auto prevDimsRef = paddedRhsType.getShape();
2112 llvm::SmallVector<int64_t> newShape;
2113 reshapeShapeVector(prevDimsRef, newShape, kernelOutputFeatureDimension,
2114 featureGroupCount);
2115 updateDimMappingFromOffset(rhsIndexMapping,
2116 kernelOutputFeatureDimension);
2117 modifiedRhs = rewriter.create<mhlo::ReshapeOp>(
2118 op.getLoc(),
2119 RankedTensorType::get(newShape, paddedRhsType.getElementType()),
2120 modifiedRhs);
2121 }
2122 // Prepare reshaped output shape
2123 {
2124 dstExprs.insert(dstExprs.begin() + outputFeatureDimension, parallelDim);
2125 updateDimMappingFromOffset(resultIndexMapping, outputFeatureDimension);
2126 reshapedResultShape.insert(
2127 reshapedResultShape.begin() + outputFeatureDimension,
2128 featureGroupCount);
2129 reshapedResultShape[outputFeatureDimension + 1] /= featureGroupCount;
2130 }
2131 }
2132
2133 if (batchGroupCount != 1) {
2134 iterationLoops.push_back(getParallelIteratorTypeName());
2135 auto parallelDim = mlir::getAffineDimExpr(nextDim++, ctx);
2136 // Reshape LHS
2137 {
2138 srcExprs.insert(srcExprs.begin() + inputBatchDimension, parallelDim);
2139 auto prevDimsRef = paddedLhsType.getShape();
2140 llvm::SmallVector<int64_t> newShape;
2141 reshapeShapeVector(prevDimsRef, newShape, inputBatchDimension,
2142 batchGroupCount);
2143 updateDimMappingFromOffset(lhsIndexMapping, inputBatchDimension);
2144 modifiedLhs = rewriter.create<mhlo::ReshapeOp>(
2145 op.getLoc(),
2146 RankedTensorType::get(newShape, paddedLhsType.getElementType()),
2147 modifiedLhs);
2148 }
2149
2150 // Reshape RHS
2151 {
2152 windowExprs.insert(windowExprs.begin() + kernelOutputFeatureDimension,
2153 parallelDim);
2154 auto prevDimsRef = paddedRhsType.getShape();
2155 llvm::SmallVector<int64_t> newShape;
2156 reshapeShapeVector(prevDimsRef, newShape, kernelOutputFeatureDimension,
2157 batchGroupCount);
2158 updateDimMappingFromOffset(rhsIndexMapping,
2159 kernelOutputFeatureDimension);
2160 modifiedRhs = rewriter.create<mhlo::ReshapeOp>(
2161 op.getLoc(),
2162 RankedTensorType::get(newShape, paddedRhsType.getElementType()),
2163 modifiedRhs);
2164 }
2165 // Prepare reshaped output shape
2166 {
2167 auto outputFeatureDim = resultIndexMapping[outputFeatureDimension];
2168 dstExprs.insert(dstExprs.begin() + outputFeatureDim, parallelDim);
2169 updateDimMappingFromOffset(resultIndexMapping, outputFeatureDimension);
2170 reshapedResultShape.insert(
2171 reshapedResultShape.begin() + outputFeatureDim, batchGroupCount);
2172 reshapedResultShape[outputFeatureDim + 1] /= batchGroupCount;
2173 }
2174 }
2175
2176 // Handle input feature dimension
2177 {
2178 iterationLoops.push_back(getReductionIteratorTypeName());
2179 auto inputFeatureDim = mlir::getAffineDimExpr(nextDim++, ctx);
2180 srcExprs[lhsIndexMapping[inputFeatureDimension]] = inputFeatureDim;
2181 windowExprs[rhsIndexMapping[kernelInputFeatureDimension]] =
2182 inputFeatureDim;
2183 }
2184
2185 // Handle output feature dimension
2186 {
2187 iterationLoops.push_back(getParallelIteratorTypeName());
2188 auto outputFeatureDim = mlir::getAffineDimExpr(nextDim++, ctx);
2189 dstExprs[resultIndexMapping[outputFeatureDimension]] = outputFeatureDim;
2190 windowExprs[rhsIndexMapping[kernelOutputFeatureDimension]] =
2191 outputFeatureDim;
2192 }
2193
2194 // Handle spatial Dimensions
2195 int64_t numSpatialDims = rank - 2;
2196 for (int64_t i = 0; i < numSpatialDims; i++) {
2197 iterationLoops.push_back(getParallelIteratorTypeName());
2198 iterationLoops.push_back(getReductionIteratorTypeName());
2199 auto dim0 = mlir::getAffineDimExpr(nextDim++, ctx);
2200 auto dim1 = mlir::getAffineDimExpr(nextDim++, ctx);
2201
2202 auto stride = dim0;
2203 if (op.window_strides().value())
2204 stride = stride * op.window_strides().value().getValues<int64_t>()[i];
2205 AffineExpr srcExpr = stride + dim1;
2206
2207 srcExprs[lhsIndexMapping[inputSpatialDimensions[i]]] = srcExpr;
2208 dstExprs[resultIndexMapping[outputSpatialDimensions[i]]] = dim0;
2209 windowExprs[rhsIndexMapping[kernelSpatialDimensions[i]]] = dim1;
2210 }
2211
2212 // Handle batch dimension
2213 {
2214 iterationLoops.push_back(getParallelIteratorTypeName());
2215 auto batchDim = mlir::getAffineDimExpr(nextDim++, ctx);
2216
2217       srcExprs[lhsIndexMapping[inputBatchDimension]] = batchDim;
2218       dstExprs[resultIndexMapping[dimensionNumbers.getOutputBatchDimension()]] = batchDim;
2219 }
2220
2221 // Finally, create the computation
2222 auto inferredMaps =
2223 AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs});
2224
2225 Value initTensor = rewriter.create<linalg::InitTensorOp>(
2226 loc, reshapedResultShape, resultType.getElementType());
2227 Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
2228
2229 Value convolved =
2230 rewriter
2231 .create<linalg::GenericOp>(
2232 loc,
2233 /*resultTensors=*/
2234 llvm::makeArrayRef<Type>(zeroTensor.getType()),
2235 /*inputs=*/
2236 llvm::makeArrayRef<Value>({modifiedLhs, modifiedRhs}),
2237 /*outputs=*/llvm::makeArrayRef<Value>(zeroTensor), inferredMaps,
2238 iterationLoops,
2239 /*bodyBuild=*/
2240 [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange) {
2241 ImplicitLocOpBuilder builder(nestedLoc, nestedBuilder);
2242 linalg::Conv2DOp::regionBuilder(
2243 builder, *builder.getInsertionBlock(), {});
2244 },
2245 pruneAttributeList(op))
2246 .getResult(0);
2247 rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(op, resultType, convolved);
2248
2249 return success();
2250 }
2251 };
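
// A hand-worked shape example for the feature-grouped path above: with
// feature_group_count = 2, an NHWC input of tensor<1x8x8x4xf32>, an HWIO
// kernel of tensor<3x3x2x8xf32> and an NHWC output of tensor<1x6x6x8xf32>,
// the reshapes split the grouped dimensions as
//
//   input:  tensor<1x8x8x4xf32>  -> tensor<1x8x8x2x2xf32>
//   kernel: tensor<3x3x2x8xf32>  -> tensor<3x3x2x2x4xf32>
//   output: tensor<1x6x6x8xf32>  -> tensor<1x6x6x2x4xf32>
//
// so the group index becomes an extra parallel loop of the linalg.generic,
// and the final mhlo.reshape collapses the output back to tensor<1x6x6x8xf32>.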
2252
2253 /// Converts mhlo.convolution operation to a depthwise convolution linalg
2254 /// named op, e.g. linalg.depthwise_conv_2d_nhwc_hwcm (channel multiplier > 1)
2255 /// or linalg.depthwise_conv_2d_nhwc_hwc (channel multiplier == 1).
2256 struct DepthwiseConvolutionOpConversion
2257 : public OpConversionPattern<mhlo::ConvolutionOp> {
2258 using OpConversionPattern<mhlo::ConvolutionOp>::OpConversionPattern;
2259
2260   LogicalResult matchAndRewrite(
2261 mhlo::ConvolutionOp op, OpAdaptor adaptor,
2262 ConversionPatternRewriter& rewriter) const override {
2263 if (op.batch_group_count() != 1) return failure();
2264 // Fall into the normal convolution cases.
2265 if (op.feature_group_count() == 1) return failure();
2266
2267 const mhlo::ConvDimensionNumbersAttr& dimensionNumbers =
2268 op.dimension_numbers();
2269 const auto spatialRank =
2270 llvm::size(dimensionNumbers.getInputSpatialDimensions());
2271 if (spatialRank == 0 || spatialRank > 3) {
2272 return rewriter.notifyMatchFailure(op, "only support up to 3D for now");
2273 }
2274
2275 // Make sure that this is depthwise convolution.
2276 int64_t inputFeatureDim = dimensionNumbers.getInputFeatureDimension();
2277 int64_t inputFeatureCount =
2278 op.lhs().getType().cast<ShapedType>().getDimSize(inputFeatureDim);
2279 if (op.feature_group_count() != inputFeatureCount) {
2280 return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
2281 }
2282
2283 // Make sure that this convolution has a canonical form.
2284 if (!hasCanonicalDimensionNumbers(dimensionNumbers)) {
2285 return rewriter.notifyMatchFailure(op, "does not have canonical form");
2286 }
2287
2288 Attribute windowStrides;
2289 if (op.window_strides()) {
2290 windowStrides = op.window_strides().value();
2291 } else {
2292 windowStrides = SplatElementsAttr::get(
2293 VectorType::get({spatialRank}, rewriter.getI64Type()),
2294 rewriter.getI64IntegerAttr(1));
2295 }
2296
2297 Attribute rhsDilation;
2298 if (op.rhs_dilation()) {
2299 rhsDilation = op.rhs_dilation().value();
2300 } else {
2301 rhsDilation = SplatElementsAttr::get(
2302 VectorType::get({spatialRank}, rewriter.getI64Type()),
2303 rewriter.getI64IntegerAttr(1));
2304 }
2305
2306 Location loc = op.getLoc();
2307 Value input = adaptor.lhs();
2308 Value filter = adaptor.rhs();
2309 auto resultType = typeConverter->convertType(op.getResult().getType())
2310 .cast<RankedTensorType>();
2311 if (!resultType.hasStaticShape()) {
2312 return rewriter.notifyMatchFailure(op,
2313 "expected output has static shapes");
2314 }
2315
2316 // Apply padding and input dilation.
2317 llvm::SmallVector<int64_t> spatialDimMapping(spatialRank);
2318 std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1);
2319 input = applyConvolutionPadding(loc, input, op.paddingAttr(),
2320 op.lhs_dilationAttr(), spatialDimMapping,
2321 rewriter);
2322
2323 auto filterDims =
2324 llvm::to_vector<4>(op.rhs().getType().cast<ShapedType>().getShape());
2325
2326 auto getReassociationIndicesToCollapseLastTwoDims = [](Value v) {
2327 SmallVector<ReassociationIndices> reassociations;
2328 int64_t rank = v.getType().cast<ShapedType>().getRank();
2329 for (int64_t i = 0; i < rank - 1; ++i) reassociations.emplace_back(1, i);
2330 reassociations.back().push_back(rank - 1);
2331 return reassociations;
2332 };
2333
2334 int64_t kernelInputFeatureDimension =
2335 dimensionNumbers.getKernelInputFeatureDimension();
2336 int64_t kernelOutputFeatureDimension =
2337 dimensionNumbers.getKernelOutputFeatureDimension();
2338 if (filterDims[kernelInputFeatureDimension] *
2339 filterDims[kernelOutputFeatureDimension] !=
2340 op.feature_group_count()) {
2341 // For cases where channel multiplier != 1
2342
2343 // Reshaping filter shape
2344 // [filter_height, filter_width, 1, kernel-output-feature].
2345 // to
2346 // [filter_height, filter_width, feature_group_count,
2347 // kernel-output-feature/feature_group_count ]
2348 SmallVector<int64_t> reshapedFilterDims;
2349 reshapedFilterDims.assign(filterDims.begin(), filterDims.end());
2350 auto reshapedFilter = filter;
2351 if (filterDims[kernelInputFeatureDimension] == 1) {
2352 reshapedFilterDims[kernelInputFeatureDimension] =
2353 op.feature_group_count();
2354 reshapedFilterDims[kernelOutputFeatureDimension] /=
2355 op.feature_group_count();
2356 auto reshapedFilterType = RankedTensorType::get(
2357 reshapedFilterDims,
2358 op.rhs().getType().cast<RankedTensorType>().getElementType());
2359
2360 reshapedFilter =
2361 rewriter.create<mhlo::ReshapeOp>(loc, reshapedFilterType, filter);
2362 }
2363
2364 auto outputDims = resultType.getShape();
2365 auto channelMultiplier = reshapedFilterDims.back();
2366 SmallVector<int64_t> reshapedOutputDims;
2367 reshapedOutputDims.assign(outputDims.begin(), outputDims.end());
2368 reshapedOutputDims.push_back(channelMultiplier);
2369 reshapedOutputDims[reshapedOutputDims.size() - 2] /= channelMultiplier;
2370
2371 Value initTensor = rewriter.create<linalg::InitTensorOp>(
2372 loc, reshapedOutputDims, resultType.getElementType());
2373 Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
2374
2375 auto reshapedOutputType = RankedTensorType::get(
2376 reshapedOutputDims, resultType.getElementType());
2377 Value conv;
2378 switch (spatialRank) {
2379 case 1:
2380 conv =
2381 rewriter
2382 .create<linalg::DepthwiseConv1DNwcWcmOp>(
2383 loc, reshapedOutputType,
2384 ValueRange{input, reshapedFilter}, ValueRange{zeroTensor},
2385 windowStrides, rhsDilation, pruneAttributeList(op))
2386 .getResult(0);
2387 break;
2388 case 2:
2389 conv =
2390 rewriter
2391 .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
2392 loc, reshapedOutputType,
2393 ValueRange{input, reshapedFilter}, ValueRange{zeroTensor},
2394 windowStrides, rhsDilation, pruneAttributeList(op))
2395 .getResult(0);
2396 break;
2397 case 3:
2398 conv =
2399 rewriter
2400 .create<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
2401 loc, reshapedOutputType,
2402 ValueRange{input, reshapedFilter}, ValueRange{zeroTensor},
2403 windowStrides, rhsDilation, pruneAttributeList(op))
2404 .getResult(0);
2405 break;
2406 }
2407
2408       // Create a tensor.collapse_shape op that converts the output from 5
2409       // dimensions into 4 dimensions (by collapsing the last two dimensions).
2410       // This is needed because the channel-multiplier variants of the
2411       // depthwise convolution ops keep the multiplier as a trailing output dim.
2412 rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
2413 op, resultType, conv,
2414 getReassociationIndicesToCollapseLastTwoDims(conv));
2415 } else {
2416 // For cases where channel multiplier == 1
2417 Value initTensor = rewriter.create<linalg::InitTensorOp>(
2418 loc, resultType.getShape(), resultType.getElementType());
2419 Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
2420
2421       // Create a tensor.collapse_shape op that converts the filter from 4
2422       // dimensions into 3 dimensions (by dropping the unit channel-multiplier
2423       // dimension). This is needed because the multiplier-free depthwise
2424       // convolution ops expect a filter without that dimension.
2425
2426 filterDims[filterDims.size() - 2] =
2427 static_cast<int64_t>(op.feature_group_count());
2428 filterDims.pop_back();
2429
2430 RankedTensorType filterShape =
2431 RankedTensorType::get(filterDims, op.getType().getElementType());
2432
2433 Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
2434 loc, filterShape, filter,
2435 getReassociationIndicesToCollapseLastTwoDims(filter));
2436
2437 switch (spatialRank) {
2438 case 1:
2439 rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcOp>(
2440 op, resultType, ValueRange{input, reshapedFilter},
2441 ValueRange{zeroTensor}, windowStrides, rhsDilation,
2442 pruneAttributeList(op));
2443 break;
2444 case 2:
2445 rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcOp>(
2446 op, resultType, ValueRange{input, reshapedFilter},
2447 ValueRange{zeroTensor}, windowStrides, rhsDilation,
2448 pruneAttributeList(op));
2449 break;
2450 case 3:
2451 rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(
2452 op, resultType, ValueRange{input, reshapedFilter},
2453 ValueRange{zeroTensor}, windowStrides, rhsDilation,
2454 pruneAttributeList(op));
2455 break;
2456 }
2457 }
2458
2459 return success();
2460 }
2461 };
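
// Hand-written sketch of the channel-multiplier == 1 case: with
// feature_group_count = 4 and
//
//   (tensor<1x8x8x4xf32>, tensor<3x3x1x4xf32>) -> tensor<1x6x6x4xf32>
//
// the trailing unit dimension of the filter is collapsed away and the op
// lowers roughly to
//
//   %f = tensor.collapse_shape %filter [[0], [1], [2, 3]]
//       : tensor<3x3x1x4xf32> into tensor<3x3x4xf32>
//   %r = linalg.depthwise_conv_2d_nhwc_hwc
//       {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
//       ins(%input, %f : tensor<1x8x8x4xf32>, tensor<3x3x4xf32>)
//       outs(%fill : tensor<1x6x6x4xf32>) -> tensor<1x6x6x4xf32>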
2462
2463 struct ReduceWindowOpOnTensorsGenericConversion
2464 : public OpConversionPattern<mhlo::ReduceWindowOp> {
2465 using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;
2466   LogicalResult matchAndRewrite(
2467 mhlo::ReduceWindowOp op, OpAdaptor adaptor,
2468 ConversionPatternRewriter& rewriter) const override {
2469 MLIRContext* ctx = op->getContext();
2470 Location loc = op.getLoc();
2471 llvm::SmallVector<Value> initValues = adaptor.init_values();
2472 llvm::SmallVector<Type> resultTypes = llvm::to_vector(op.getResultTypes());
2473 auto numOperands = initValues.size();
2474
2475 llvm::SmallVector<int64_t> windowDimensions =
2476 extract1DVector(op.window_dimensions());
2477
2478 llvm::SmallVector<int64_t> padding;
2479 if (op.padding()) {
2480 padding = extract1DVector(*op.padding());
2481 }
2482
2483 llvm::SmallVector<int64_t> baseDilations;
2484 if (op.base_dilations()) {
2485 baseDilations = extract1DVector(*op.base_dilations());
2486 }
2487
2488 llvm::SmallVector<int64_t> windowStrides(windowDimensions.size(), 1);
2489 if (op.window_strides()) {
2490 windowStrides = extract1DVector(*op.window_strides());
2491 }
2492
2493 llvm::SmallVector<int64_t> windowDilations(windowDimensions.size(), 1);
2494 if (op.window_dilations()) {
2495 windowDilations = extract1DVector(*op.window_dilations());
2496 }
2497
2498 auto rank = static_cast<int64_t>(windowDimensions.size());
2499 SmallVector<AffineExpr, 2> srcExprs;
2500 SmallVector<AffineExpr, 2> windowExprs;
2501 SmallVector<AffineExpr, 2> dstExprs;
2502 SmallVector<int64_t> filteredWindowDims;
2503
2504 int windowDim = 0;
2505 for (int64_t i = 0; i < rank; i++) {
2506 AffineExpr srcExpr = mlir::getAffineDimExpr(i, ctx);
2507
2508 if (windowStrides[i] != 1) srcExpr = srcExpr * windowStrides[i];
2509
2510 if (windowDimensions[i] != 1) {
2511 filteredWindowDims.push_back(windowDimensions[i]);
2512 AffineExpr windowExpr = mlir::getAffineDimExpr(rank + windowDim, ctx);
2513 windowExprs.push_back(windowExpr);
2514
2515 if (windowDilations[i] != 1)
2516 windowExpr = windowExpr * windowDilations[i];
2517
2518 srcExpr = srcExpr + windowExpr;
2519 windowDim++;
2520 }
2521
2522 srcExprs.push_back(srcExpr);
2523 dstExprs.push_back(mlir::getAffineDimExpr(i, ctx));
2524 }
2525
2526 SmallVector<AffineMap, 4> inferredMaps(3, AffineMap::get(ctx));
2527 if (rank > 0)
2528 inferredMaps =
2529 AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs});
2530
2531 SmallVector<AffineMap, 4> indexingMaps;
2532
2533 indexingMaps.append(numOperands, inferredMaps[0]);
2534 indexingMaps.append(1, inferredMaps[1]);
2535 indexingMaps.append(numOperands, inferredMaps[2]);
2536
2537 // Setup the initial values.
2538 llvm::SmallVector<Value> broadcastValues;
2539 for (uint64_t i = 0, s = initValues.size(); i < s; i++) {
2540 Value initValue = initValues[i];
2541 auto resultTy = resultTypes[i].cast<ShapedType>();
2542 if (!resultTy.hasStaticShape()) return failure();
2543
2544 auto broadcastSizes = rewriter.getI64TensorAttr(resultTy.getShape());
2545 broadcastValues.push_back(rewriter.create<mhlo::BroadcastOp>(
2546 loc, resultTy, initValue, broadcastSizes));
2547 }
2548
2549 llvm::SmallVector<Value> inputs = llvm::to_vector(adaptor.operands());
2550
2551 // Pad as necessary.
2552 if (llvm::any_of(padding, [](int64_t v) { return v != 0; }) ||
2553 llvm::any_of(baseDilations, [](int64_t v) { return v != 1; })) {
2554 llvm::SmallVector<int64_t> staticLows(rank, 0);
2555 llvm::SmallVector<int64_t> staticHighs(rank, 0);
2556 for (int i = 0; i < padding.size(); i += 2) {
2557 staticLows[i / 2] = padding[i];
2558 staticHighs[i / 2] = padding[i + 1];
2559 }
2560 // Translate base dilation into interior padding.
2561 llvm::SmallVector<int64_t> staticInteriors(rank, 0);
2562 for (const auto& dilation : llvm::enumerate(baseDilations)) {
2563 staticInteriors[dilation.index()] = dilation.value() - 1;
2564 }
2565
2566 auto padAttrType =
2567 RankedTensorType::get({rank}, rewriter.getIntegerType(64));
2568 auto padLows = DenseIntElementsAttr::get(padAttrType, staticLows);
2569 auto padHighs = DenseIntElementsAttr::get(padAttrType, staticHighs);
2570 auto padInteriors =
2571 DenseIntElementsAttr::get(padAttrType, staticInteriors);
2572
2573 for (auto values : llvm::zip(inputs, initValues)) {
2574 auto& input = std::get<0>(values);
2575 auto& initValue = std::get<1>(values);
2576 input = rewriter.create<mhlo::PadOp>(loc, input, initValue, padLows,
2577 padHighs, padInteriors);
2578 }
2579 }
2580
2581 // Add the extra input for the reduction dimension.
2582 inputs.push_back(rewriter.create<linalg::InitTensorOp>(
2583 loc, filteredWindowDims, rewriter.getF32Type()));
2584
2585 auto linalgOp = rewriter.create<linalg::GenericOp>(
2586 loc, /*resultTensors=*/resultTypes,
2587 /*inputs=*/inputs,
2588 /*outputs=*/broadcastValues, indexingMaps,
2589 getParallelAndReductionIterators(rank + filteredWindowDims.size(),
2590 filteredWindowDims.size()),
2591 /*bodyBuild=*/nullptr, pruneAttributeList(op));
2592
2593 // Convert the signature of the body. This includes converting scalar
2594 // tensors to their scalar values and inserting an additional block arg for
2595 // the window arg.
2596 Region& region = linalgOp.region();
2597 rewriter.cloneRegionBefore(op.body(), region, region.end());
2598
2599 TypeConverter::SignatureConversion signatureConverter(
2600 inputs.size() + op->getNumResults() - 1);
2601
2602     // ReduceWindow requires that the seed be used as an LHS operand inside
2603     // the region, and the seed is encoded in linalg in the initial out value.
2604     // Modify the signature of the block and the value mappings so that the
2605     // output args correlate with the LHS and the inputs correlate with the RHS.
2606 for (const auto& [i, type] : llvm::enumerate(resultTypes)) {
2607 auto idx = inputs.size() + i - 1;
2608 signatureConverter.addInputs(idx,
2609 type.cast<ShapedType>().getElementType());
2610 }
2611
2612 signatureConverter.addInputs(
2613 inputs.back().getType().cast<ShapedType>().getElementType());
2614
2615 for (const auto& [i, input] :
2616 llvm::enumerate(ArrayRef<Value>(inputs).drop_back())) {
2617 signatureConverter.addInputs(
2618 i, input.getType().cast<ShapedType>().getElementType());
2619 }
2620
2621     rewriter.applySignatureConversion(&region, signatureConverter,
2622 getTypeConverter());
2623 rewriter.replaceOp(op, linalgOp.getResults());
2624 return success();
2625 }
2626 };
2627
2628 struct ReduceWindowOpConversion
2629 : public OpConversionPattern<mhlo::ReduceWindowOp> {
2630 using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;
2631
2632 /// mhlo.reduce_window is mapped to a linalg.pooling operation. The type of
2633 /// the pooling is determined based on the body of the reduce window
2634 /// operation. This class enumerates the different variants.
2635 enum class PoolingType {
2636 kInvalid,
2637 k2DMin,
2638 k3DMin,
2639 k2DMax,
2640 k3DMax,
2641 k2DAdd,
2642 k3DAdd,
2643 };
2644
2645   static PoolingType getPoolingType(mhlo::ReduceWindowOp reduceOp,
2646 int resultIndex) {
2647 auto rank =
2648 reduceOp.getResultTypes()[resultIndex].cast<ShapedType>().getRank();
2649 if (Operation* op = reduceOp.getReductionOp(resultIndex)) {
2650 if (isa<mhlo::MinOp>(*op) && rank == 4) return PoolingType::k2DMin;
2651 if (isa<mhlo::MinOp>(*op) && rank == 5) return PoolingType::k3DMin;
2652 if (isa<mhlo::MaxOp>(*op) && rank == 4) return PoolingType::k2DMax;
2653 if (isa<mhlo::MaxOp>(*op) && rank == 5) return PoolingType::k3DMax;
2654 if (isa<mhlo::AddOp>(*op) && rank == 4) return PoolingType::k2DAdd;
2655 if (isa<mhlo::AddOp>(*op) && rank == 5) return PoolingType::k3DAdd;
2656 }
2657 return PoolingType::kInvalid;
2658 }
2659
2660   LogicalResult matchAndRewrite(
2661 mhlo::ReduceWindowOp op, OpAdaptor adaptor,
2662 ConversionPatternRewriter& rewriter) const override {
2663 auto loc = op.getLoc();
2664 int rank = op.getResultTypes()[0].cast<ShapedType>().getRank();
2665 if (rank != 4 && rank != 5) {
2666 return rewriter.notifyMatchFailure(
2667 op, "expected NHWC/NDHWC pooling-based op");
2668 }
2669
2670 if (op.padding() && !isSplatValue(*op.padding(), 0)) {
2671 return rewriter.notifyMatchFailure(op, "require paddings are all zero");
2672 }
2673
2674 if (op.base_dilations() && !isSplatValue(*op.base_dilations(), 1)) {
2675 return rewriter.notifyMatchFailure(op, "expected undilated base");
2676 }
2677
2678 int lastDim = rank - 1;
2679 SmallVector<int64_t, 2> fakeWindowShapes;
2680 for (int i = 1; i < lastDim; ++i) {
2681 fakeWindowShapes.push_back(
2682 op.window_dimensions().getValues<int64_t>()[i]);
2683 }
2684
2685 if (op.window_strides() &&
2686 (op.window_strides().value().getValues<int64_t>()[0] != 1 ||
2687 op.window_strides().value().getValues<int64_t>()[lastDim] != 1)) {
2688 return rewriter.notifyMatchFailure(
2689 op, "expected window_strides to be [1,x,y,(z),1]");
2690 }
2691 if (op.window_dimensions() &&
2692 (op.window_dimensions().getValues<int64_t>()[0] != 1 ||
2693 op.window_dimensions().getValues<int64_t>()[lastDim] != 1)) {
2694 return rewriter.notifyMatchFailure(
2695 op, "expected window_dimensions to be [1,x,y,(z),1]");
2696 }
2697
2698 Attribute strides;
2699 SmallVector<int64_t> vec;
2700 if (op.window_stridesAttr()) {
2701 for (int i = 1; i < lastDim; ++i) {
2702 vec.push_back(op.window_strides().value().getValues<int64_t>()[i]);
2703 }
2704 } else {
2705 vec.assign(rank - 2, 1);
2706 }
2707 strides = rewriter.getI64VectorAttr(vec);
2708
2709 Attribute dilations;
2710 vec.clear();
2711 if (op.window_dilations()) {
2712 for (int i = 1; i < lastDim; ++i) {
2713 vec.push_back(op.window_dilations().value().getValues<int64_t>()[i]);
2714 }
2715 } else {
2716 vec.assign(rank - 2, 1);
2717 }
2718 dilations = rewriter.getI64VectorAttr(vec);
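// Sketch with assumed attribute values: window_strides = [1, 2, 2, 1] and
// absent window_dilations on a rank-4 op yield strides = [2, 2] and
// dilations = [1, 1], i.e. i64 vector attributes covering only the spatial
// dimensions.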
2719
2720 SmallVector<Value> poolingOps;
2721
2722 ValueRange operands = adaptor.operands();
2723 ValueRange initValues = adaptor.init_values();
2724 for (auto it : llvm::zip(op.getResults(), operands, initValues)) {
2725 OpResult result = std::get<0>(it);
2726 Value input = std::get<1>(it);
2727 Value initValue = std::get<2>(it);
2728 auto resultType = result.getType().cast<ShapedType>();
2729 if (!input.getType().cast<ShapedType>().getElementType().isF32()) {
2730 return rewriter.notifyMatchFailure(op,
2731 "expected element type to be f32");
2732 }
2733
2734 // Create a fake window dimension.
2735 auto fakeWindowDims = rewriter.create<linalg::InitTensorOp>(
2736 loc, fakeWindowShapes, resultType.getElementType());
2737
2738 SmallVector<Value> resultDynamicDims;
2739 for (auto& en : llvm::enumerate(resultType.getShape())) {
2740 if (en.value() != ShapedType::kDynamicSize) continue;
2741 Value dimSize = rewriter.create<tensor::DimOp>(loc, input, en.index());
2742 if (en.index() == 0 || static_cast<int64_t>(en.index()) == rank - 1) {
2743 // batch dims and channel dims can be derived from input dims
2744 // directly.
2745 resultDynamicDims.push_back(dimSize);
2746 } else {
2747 auto i = en.index() - 1;
2748 auto stride =
2749 strides.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
2750 auto dilation =
2751 dilations.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
2752 // let j = i * stride
2753 // output[i] = reduce( input[j, j + window_size * dilation) )
2754 Value offset = rewriter.create<arith::ConstantIndexOp>(
2755 loc, fakeWindowShapes[i] * dilation);
2756 dimSize = rewriter.create<arith::SubIOp>(loc, dimSize, offset);
2757 dimSize = rewriter.create<arith::DivUIOp>(
2758 loc, dimSize,
2759 rewriter.create<arith::ConstantIndexOp>(loc, stride));
2760 dimSize = rewriter.create<arith::AddIOp>(
2761 loc, dimSize, rewriter.create<arith::ConstantIndexOp>(loc, 1));
2762 resultDynamicDims.push_back(dimSize);
2763 }
2764 }
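// Arithmetic sketch with made-up sizes: for a spatial input dim of size 10,
// window extent 3, dilation 1 and stride 2, the ops above compute
// (10 - 3 * 1) / 2 + 1 = 3 + 1 = 4 output elements along that dimension
// (the division is unsigned integer division).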
2765 Value initTensor = rewriter.create<linalg::InitTensorOp>(
2766 loc, resultDynamicDims, resultType.getShape(),
2767 resultType.getElementType());
2768
2769 initValue = rewriter.create<tensor::ExtractOp>(loc, initValue);
2770 Value filledInitTensor =
2771 rewriter.create<linalg::FillOp>(loc, initValue, initTensor)
2772 .getResult(0);
2773 auto createOp = [&](auto* typePtr) -> linalg::LinalgOp {
2774 return cast<linalg::LinalgOp>(
2775 rewriter
2776 .create<std::remove_pointer_t<decltype(typePtr)>>(
2777 loc, ArrayRef<Type>{resultType},
2778 ValueRange{input, fakeWindowDims.getResult()},
2779 filledInitTensor, strides, dilations,
2780 pruneAttributeList(op))
2781 .getOperation());
2782 };
2783 linalg::LinalgOp poolingOp;
2784 PoolingType poolingType = getPoolingType(op, result.getResultNumber());
2785 switch (poolingType) {
2786 case PoolingType::k2DMin: {
2787 poolingOp = createOp(static_cast<linalg::PoolingNhwcMinOp*>(nullptr));
2788 break;
2789 }
2790 case PoolingType::k3DMin: {
2791 poolingOp =
2792 createOp(static_cast<linalg::PoolingNdhwcMinOp*>(nullptr));
2793 break;
2794 }
2795 case PoolingType::k2DMax: {
2796 poolingOp = createOp(static_cast<linalg::PoolingNhwcMaxOp*>(nullptr));
2797 break;
2798 }
2799 case PoolingType::k3DMax: {
2800 poolingOp =
2801 createOp(static_cast<linalg::PoolingNdhwcMaxOp*>(nullptr));
2802 break;
2803 }
2804 case PoolingType::k2DAdd: {
2805 poolingOp = createOp(static_cast<linalg::PoolingNhwcSumOp*>(nullptr));
2806 break;
2807 }
2808 case PoolingType::k3DAdd: {
2809 poolingOp =
2810 createOp(static_cast<linalg::PoolingNdhwcSumOp*>(nullptr));
2811 break;
2812 }
2813 case PoolingType::kInvalid:
2814 return rewriter.notifyMatchFailure(op, "unknown reduction operation");
2815 }
2816 poolingOps.push_back(poolingOp->getResult(0));
2817 }
2818 rewriter.replaceOp(op, poolingOps);
2819 return success();
2820 }
2821 };
2822
2823 /// Converts mhlo.torch_index_select op to a linalg.generic op.
2824 struct TorchIndexSelectOpConversion
2825 : public OpConversionPattern<mhlo::TorchIndexSelectOp> {
2826 using OpConversionPattern<mhlo::TorchIndexSelectOp>::OpConversionPattern;
2827
2828 LogicalResult matchAndRewrite(
2829 mhlo::TorchIndexSelectOp op, OpAdaptor adaptor,
2830 ConversionPatternRewriter& rewriter) const final {
2831 int axis = static_cast<int>(op.dim());
2832 int batch = static_cast<int>(op.batch_dims());
2833 auto indexShapedType = adaptor.index().getType().cast<ShapedType>();
2834 int numIndices = static_cast<int>(indexShapedType.getRank());
2835 auto operandShapedType = adaptor.operand().getType().cast<ShapedType>();
2836 if (axis < 0) axis += static_cast<int>(operandShapedType.getRank());
2837 if (batch < 0) batch += numIndices;
2838
2839 Location loc = op.getLoc();
2840 auto resultType = this->typeConverter->convertType(op.getResult().getType())
2841 .cast<ShapedType>();
2842 int rank = static_cast<int>(resultType.getRank());
2843
2844 // The output shape is
2845 // `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
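// Illustrative shapes (hypothetical): operand = 4x7x16, index = 2x3, dim = 1,
// batch_dims = 0 give axis = 1, numIndices = 2, batch = 0 and a 4x2x3x16
// result; result dims 0, 1-2 and 3 come from operand dim 0, index dims 0-1
// and operand dim 2 respectively, which is exactly how the dynamic sizes are
// looked up below.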
2846 SmallVector<Value, 4> dynSizes;
2847 for (int i = 0; i < rank; ++i) {
2848 if (!resultType.isDynamicDim(i)) continue;
2849 if (i < axis) {
2850 dynSizes.push_back(
2851 rewriter.create<tensor::DimOp>(loc, adaptor.operand(), i));
2852 } else if (i < (axis + numIndices - batch)) {
2853 int idx = i - axis + batch;
2854 dynSizes.push_back(
2855 rewriter.create<tensor::DimOp>(loc, adaptor.index(), idx));
2856 } else {
2857 int idx = i - (axis + numIndices - batch) + axis + 1;
2858 dynSizes.push_back(
2859 rewriter.create<tensor::DimOp>(loc, adaptor.operand(), idx));
2860 }
2861 }
2862
2863 // Generate dummy tensor to preserve slice shape information.
2864 SmallVector<int64_t> sliceShape;
2865 SmallVector<Value, 4> dynSliceSizes;
2866 SmallVector<AffineExpr, 4> sliceExprs;
2867 auto resultShape = resultType.getShape();
2868 for (int i = 0; i < axis; ++i) {
2869 sliceExprs.push_back(rewriter.getAffineDimExpr(i));
2870 sliceShape.push_back(resultShape[i]);
2871 if (!resultType.isDynamicDim(i)) continue;
2872 dynSliceSizes.push_back(
2873 rewriter.create<tensor::DimOp>(loc, adaptor.operand(), i));
2874 }
2875 for (int i = axis + numIndices - batch; i < rank; ++i) {
2876 sliceExprs.push_back(rewriter.getAffineDimExpr(i));
2877 sliceShape.push_back(resultShape[i]);
2878 if (!resultType.isDynamicDim(i)) continue;
2879 int idx = i - (axis + numIndices - batch) + axis + 1;
2880 dynSliceSizes.push_back(
2881 rewriter.create<tensor::DimOp>(loc, adaptor.operand(), idx));
2882 }
2883
2884 // Setup AffineMap for operand tensor.
2885 SmallVector<AffineExpr, 4> exprs;
2886 for (int i = 0; i < batch; ++i) {
2887 exprs.push_back(rewriter.getAffineDimExpr(i));
2888 }
2889 for (int i = 0, e = numIndices - batch; i < e; ++i) {
2890 exprs.push_back(rewriter.getAffineDimExpr(axis + i));
2891 }
2892
2893 SmallVector<AffineMap, 2> indexingMaps;
2894 indexingMaps.emplace_back(
2895 AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
2896 indexingMaps.emplace_back(AffineMap::get(
2897 rank, /*symbolCount=*/0, sliceExprs, rewriter.getContext()));
2898 indexingMaps.emplace_back(rewriter.getMultiDimIdentityMap(rank));
2899
2900 Value sliceOp = rewriter.create<linalg::InitTensorOp>(
2901 loc, dynSliceSizes, sliceShape, resultType.getElementType());
2902
2903 Value initOp = rewriter.create<linalg::InitTensorOp>(
2904 loc, dynSizes, resultType.getShape(), resultType.getElementType());
2905 auto linalgOp = rewriter.create<linalg::GenericOp>(
2906 loc, /*resultTensors=*/ArrayRef<Type>{resultType},
2907 /*inputs=*/ValueRange{adaptor.index(), sliceOp},
2908 /*outputs=*/initOp, indexingMaps, getNParallelLoopsAttrs(rank),
2909 /*bodyBuild=*/nullptr, pruneAttributeList(op));
2910
2911 SmallVector<Type, 4> bodyArgTypes;
2912 SmallVector<Value, 2> linalgOpArgs = {adaptor.index(), sliceOp};
2913 // Add a block to the region.
2914 auto* region = &linalgOp.region();
2915 auto* block = rewriter.createBlock(region, region->end());
2916 for (auto blockArgs : linalgOpArgs) {
2917 bodyArgTypes.push_back(
2918 blockArgs.getType().cast<ShapedType>().getElementType());
2919 }
2920 block->addArguments(bodyArgTypes,
2921 SmallVector<Location>(bodyArgTypes.size(), loc));
2922 block->addArguments(resultType.getElementType(), loc);
2923 OpBuilder::InsertionGuard guard(rewriter);
2924 rewriter.setInsertionPointToEnd(block);
2925
2926 Value castedValue = rewriter.create<arith::IndexCastOp>(
2927 loc, rewriter.getIndexType(), block->getArgument(0));
2928
2929 SmallVector<Value, 4> indices;
2930 for (int i = 0; i < axis; ++i) {
2931 indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
2932 }
2933 indices.push_back(castedValue);
2934 for (int i = axis + numIndices - batch; i < rank; ++i) {
2935 indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
2936 }
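// Continuing the illustrative 4x7x16 operand / 2x3 index shapes from above,
// the body reads operand[d0, index[d1, d2], d3] for output position
// (d0, d1, d2, d3).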
2937 Value res =
2938 rewriter.create<tensor::ExtractOp>(loc, adaptor.operand(), indices);
2939 rewriter.create<linalg::YieldOp>(loc, res);
2940
2941 rewriter.replaceOp(op, linalgOp.getResults());
2942 return success();
2943 }
2944 };
2945
2946 /// This lowering encompasses the full range of the Gather operation and
2947 /// therefore is very general and just loops over the output and calculates the
2948 /// corresponding input index. It follows the explanation at
2949 /// https://www.tensorflow.org/xla/operation_semantics#gather. The compiler
2950 /// should be able to optimize that a bit, but in order to get efficient
2951 /// lowerings, special cases of gather should be extracted into separate
2952 /// lowerings, and ideally encapsulated as separate ops or canonicalization
2953 /// patterns.
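/// As an illustrative configuration (hypothetical, not taken from a test):
/// operand : tensor<3x4xf32>, start_indices : tensor<2x1xi32>,
/// offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0],
/// index_vector_dim = 1, slice_sizes = [1, 4]. The result is tensor<2x4xf32>,
/// and the emitted linalg.generic body computes, for output position (b, o),
/// the row r = start_indices[b, 0], clamps it to [0, 2] and yields
/// operand[r, o].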
2954 struct GatherConversion : public OpConversionPattern<mhlo::GatherOp> {
2955 using OpConversionPattern<mhlo::GatherOp>::OpConversionPattern;
2956
2957 LogicalResult matchAndRewrite(
2958 mhlo::GatherOp gatherOp, OpAdaptor adaptor,
2959 ConversionPatternRewriter& rewriter) const final {
2960 Location loc = gatherOp.getLoc();
2961
2962 Value startIndices = adaptor.start_indices();
2963 Value operand = adaptor.operand();
2964
2965 auto resultType = typeConverter->convertType(gatherOp.getType())
2966 .dyn_cast<RankedTensorType>();
2967 RankedTensorType startIndicesType =
2968 startIndices.getType().dyn_cast<RankedTensorType>();
2969 // We could actually deal with an unranked result by inferring the result
2970 // rank, but the current reifyReturnTypes doesn't support unranked either.
2971 if (!resultType || !startIndicesType)
2972 return rewriter.notifyMatchFailure(gatherOp,
2973 "unranked start indices or result");
2974
2975 int resultRank = resultType.getRank();
2976 // slice_sizes has to have the same size as operand.rank, and doing it this
2977 // way permits an unranked operand.
2978 int operandRank = gatherOp.slice_sizes().getNumElements();
2979
2980 int64_t indexVectorDim = gatherOp.dimension_numbers().getIndexVectorDim();
2981
2982 ArrayRef<int64_t> offsetDims = gatherOp.dimension_numbers().getOffsetDims();
2983 ArrayRef<int64_t> collapsedSliceDims =
2984 gatherOp.dimension_numbers().getCollapsedSliceDims();
2985 ArrayRef<int64_t> startIndexMap =
2986 gatherOp.dimension_numbers().getStartIndexMap();
2987
2988 auto extractAsIndex = [&](Value input, ArrayRef<Value> index) -> Value {
2989 return rewriter.create<arith::IndexCastOp>(
2990 loc, rewriter.getIndexType(),
2991 rewriter.create<tensor::ExtractOp>(loc, input, index));
2992 };
2993
2994 // We'll need these constants later; creating them on demand would end up
2995 // with duplicates, which also makes lit tests really hard to write.
2996 SmallVector<Value> constants;
2997 for (int i = 0; i < std::max({resultRank, operandRank, 2}); ++i) {
2998 constants.push_back(
2999 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i)));
3000 }
3001
3002 // Create ops to calculate the dynamic dimensions of the return shape, which
3003 // are needed for the init tensor.
3004 SmallVector<Value> dynDimSizes;
3005 if (!resultType.hasStaticShape()) {
3006 SmallVector<Value> returnShapes;
3007 if (failed(gatherOp.reifyReturnTypeShapes(rewriter, adaptor.getOperands(),
3008 returnShapes)))
3009 return rewriter.notifyMatchFailure(gatherOp,
3010 "could not reify return shape");
3011 assert(returnShapes.size() == 1);
3012 Value returnShape = returnShapes[0];
3013
3014 for (int i = 0; i < resultRank; ++i)
3015 if (resultType.isDynamicDim(i))
3016 dynDimSizes.push_back(extractAsIndex(returnShape, constants[i]));
3017 }
3018
3019 Value initOp = rewriter.create<linalg::InitTensorOp>(
3020 loc, dynDimSizes, resultType.getShape(), resultType.getElementType());
3021
3022 ValueRange ins;
3023 SmallVector<AffineMap, 1> indexingMaps(
3024 {rewriter.getMultiDimIdentityMap(resultRank)});
3025 auto linalgOp = rewriter.create<linalg::GenericOp>(
3026 loc, /*resultTensorTypes=*/resultType,
3027 /*inputs=*/ins,
3028 /*outputs=*/initOp, indexingMaps, getNParallelLoopsAttrs(resultRank),
3029 /*bodyBuild=*/nullptr, pruneAttributeList(gatherOp));
3030
3031 // Now populate the linalg generic region
3032 auto* region = &linalgOp.region();
3033 auto* block = rewriter.createBlock(region, region->end());
3034 block->addArguments(resultType.getElementType(), loc);
3035 OpBuilder::InsertionGuard guard(rewriter);
3036 rewriter.setInsertionPointToEnd(block);
3037
3038 // Dimensions in the result that aren't offset dimensions are called batch.
3039 SmallVector<int64_t> batchDims;
3040 for (int dim = 0; dim < resultRank; ++dim)
3041 if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim);
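// Purely illustrative: with resultRank = 3 and offset_dims = [2], the batch
// dimensions are [0, 1].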
3042
3043 // Same as with the constants. Creating these all up front is easier than
3044 // potentially getting duplicates later.
3045 SmallVector<Value> linalgIndices;
3046 for (int i = 0; i < resultRank; ++i)
3047 linalgIndices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
3048
3049 // Now the complicated part. For a given output dimension we build up an
3050 // index into the input. It's composed of two parts: the index coming from
3051 // start_indices, and the offset from that index along the offset
3052 // dimensions. Both parts involve some dimension shuffling and remapping,
3053 // because gather is defined to allow arbitrary input layouts via its
3054 // dimension-number attributes.
3055
3056 // The base gather index (`G` in the documentation) points to a place in
3057 // start_indices along the batch dimensions.
3058 SmallVector<Value> gatherIndex;
3059 for (auto dim : batchDims) gatherIndex.push_back(linalgIndices[dim]);
3060
3061 SmallVector<Value> indexFromStartIndices;
3062 for (unsigned i = 0; i < startIndexMap.size(); ++i) {
3063 // The index along the index_vector dimension of start_indices varies.
3064 // Basically indexFromStartIndices indexes into a "row" along
3065 // index_vector_dim, where the row is selected by the current output
3066 // index.
3067 // But if index_vector_dim is equal to start_indices.rank, then
3068 // start_indices gets a trailing 1 dimension added. So the row we're
3069 // extracting always has length 1 and the index into it is always 0, so we
3070 // just use the gather index directly
3071 SmallVector<Value> gCombine(gatherIndex);
3072 if (indexVectorDim != startIndicesType.getRank()) {
3073 assert(indexVectorDim <= static_cast<int64_t>(gCombine.size()));
3074 gCombine.insert(gCombine.begin() + indexVectorDim, constants[i]);
3075 }
3076
3077 indexFromStartIndices.push_back(extractAsIndex(startIndices, gCombine));
3078 }
3079
3080 // But then start indices are shuffled by the start index map. To make a
3081 // full index into the operand, all missing indices are zeroes.
3082 SmallVector<Value> remappedIndexFromIndices(operandRank, constants[0]);
3083 for (auto& it : llvm::enumerate(startIndexMap))
3084 remappedIndexFromIndices[it.value()] = indexFromStartIndices[it.index()];
3085
3086 // Now we construct the index based on the offset. First we need to remap
3087 // the offset dimensions by dropping the collapsed indices.
3088 SmallVector<unsigned> remappedOffsetDims;
3089 for (int i = 0; i < operandRank; ++i)
3090 if (!llvm::is_contained(collapsedSliceDims, i))
3091 remappedOffsetDims.push_back(i);
3092
3093 assert(remappedOffsetDims.size() == offsetDims.size());
3094
3095 // Clamp out of bounds indices.
3096 for (int i = 0, operandIndexDim = 0; i < operandRank; ++i) {
3097 // Compute the size of the output shape dimension corresponding to this
3098 // index dimension. If it's collapsed set it to 1.
3099 Value outputDimSize = constants[1];
3100 if (!llvm::is_contained(collapsedSliceDims, i)) {
3101 outputDimSize = rewriter.createOrFold<tensor::DimOp>(
3102 loc, initOp, offsetDims[operandIndexDim++]);
3103 }
3104
3105 // If this is a skipped dimension, we're done and don't have to clamp.
3106 if (remappedIndexFromIndices[i] == constants[0]) continue;
3107
3108 Value operandDimSize =
3109 rewriter.createOrFold<tensor::DimOp>(loc, operand, i);
3110 Value largestValidIndex = rewriter.createOrFold<arith::SubIOp>(
3111 loc, operandDimSize, outputDimSize);
3112
3113 // Clamp the start index to [0, operand_dim - output_dim].
3114 Value clamp = rewriter.create<arith::MinSIOp>(
3115 loc,
3116 rewriter.create<arith::MaxSIOp>(loc, constants[0],
3117 remappedIndexFromIndices[i]),
3118 largestValidIndex);
3119 remappedIndexFromIndices[i] = clamp;
3120 }
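// Numeric sketch: if the operand dimension has size 7 and the corresponding
// output (slice) dimension has size 3, valid start indices are [0, 4], so an
// out-of-range start index such as 6 is clamped down to 4.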
3121
3122 // For the (remapped) offset dimensions, the index is the current index in
3123 // the output. As before, this is expanded to a full index into the operand
3124 // by using zeroes for the missing indices.
3125 SmallVector<Value> indexFromOffset(operandRank, constants[0]);
3126 for (unsigned k = 0; k < offsetDims.size(); ++k)
3127 indexFromOffset[remappedOffsetDims[k]] = linalgIndices[offsetDims[k]];
3128
3129 // Now we add together our two indices to get the final index into the
3130 // operand.
3131 SmallVector<Value> combinedIndex;
3132 for (int i = 0; i < operandRank; ++i)
3133 combinedIndex.push_back(rewriter.createOrFold<arith::AddIOp>(
3134 loc, rewriter.getIndexType(), remappedIndexFromIndices[i],
3135 indexFromOffset[i]));
3136
3137 Value element =
3138 rewriter.create<tensor::ExtractOp>(loc, operand, combinedIndex);
3139 rewriter.create<linalg::YieldOp>(loc, element);
3140
3141 rewriter.replaceOp(gatherOp, linalgOp.getResults());
3142
3143 return success();
3144 }
3145 };
3146
3147 class DotGeneralOpConversion : public OpConversionPattern<mhlo::DotGeneralOp> {
3148 public:
3149 using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
3150 LogicalResult matchAndRewrite(
3151 mhlo::DotGeneralOp op, OpAdaptor adaptor,
3152 ConversionPatternRewriter& rewriter) const final {
3153 if (!verifyHloOpBufferOrTensorSemantics(op)) {
3154 return failure();
3155 }
3156
3157 // Get various dimension iterator information
3158 mhlo::DotDimensionNumbersAttr dimNumbers = op.dot_dimension_numbers();
3159 auto lhsBatchingDims = dimNumbers.getLhsBatchingDimensions();
3160 auto rhsBatchingDims = dimNumbers.getRhsBatchingDimensions();
3161 auto lhsContractingDims = dimNumbers.getLhsContractingDimensions();
3162 auto rhsContractingDims = dimNumbers.getRhsContractingDimensions();
3163
3164 // Get shape information and initialize output
3165 assert(lhsContractingDims.size() == rhsContractingDims.size() &&
3166 "number of contracting dims must be equal");
3167 auto numContracting = lhsContractingDims.size();
3168 // Convert unsigned to signed. This works because signed and unsigned
3169 // integer matmul is the same operation in two's complement.
3170 auto outputType =
3171 typeConverter->convertType(op.getType()).cast<ShapedType>();
3172 auto targetRank = outputType.getRank();
3173 auto totalLoopCount = numContracting + targetRank;
3174
3175 auto lhsRank = adaptor.lhs().getType().cast<ShapedType>().getRank();
3176 auto lhsExtraDims =
3177 lhsRank - lhsBatchingDims.size() - lhsContractingDims.size();
3178 auto rhsRank = adaptor.rhs().getType().cast<ShapedType>().getRank();
3179
3180 Location loc = op.getLoc();
3181 auto initTensor =
3182 getInitTensorFor(rewriter, loc, outputType, op, adaptor.getOperands());
3183 Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
3184 SmallVector<AffineMap, 3> indexingMaps;
3185
3186 auto getMap = [&](int64_t rank, ArrayRef<int64_t> batchingDims,
3187 ArrayRef<int64_t> contractingDims, size_t extraDims) {
3188 llvm::SmallVector<AffineExpr> indices(rank);
3189 for (const auto& i : llvm::enumerate(batchingDims)) {
3190 indices[i.value()] = rewriter.getAffineDimExpr(i.index());
3191 }
3192 for (const auto& i : llvm::enumerate(contractingDims)) {
3193 indices[i.value()] = rewriter.getAffineDimExpr(i.index() + targetRank);
3194 }
3195 for (int i = 0; i < rank; ++i) {
3196 if (!indices[i]) {
3197 indices[i] = rewriter.getAffineDimExpr(extraDims++);
3198 }
3199 }
3200 indexingMaps.push_back(AffineMap::get(/*dimCount=*/totalLoopCount,
3201 /*symbolCount=*/0, indices,
3202 op->getContext()));
3203 };
3204 getMap(lhsRank, lhsBatchingDims, lhsContractingDims,
3205 lhsBatchingDims.size());
3206 getMap(rhsRank, rhsBatchingDims, rhsContractingDims,
3207 rhsBatchingDims.size() + lhsExtraDims);
3208
3209 {
3210 SmallVector<AffineExpr, 4> dimExprs;
3211 dimExprs.reserve(targetRank);
3212 for (unsigned i = 0; i < targetRank; ++i)
3213 dimExprs.push_back(rewriter.getAffineDimExpr(i));
3214 indexingMaps.push_back(AffineMap::get(/*dimCount=*/totalLoopCount,
3215 /*symbolCount=*/0, dimExprs,
3216 op.getContext()));
3217 }
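// Illustrative example (hypothetical batched matmul): lhs = [b, m, k] with
// lhs_batching_dimensions = [0] and lhs_contracting_dimensions = [2];
// rhs = [b, k, n] with rhs_batching_dimensions = [0] and
// rhs_contracting_dimensions = [1]; result = [b, m, n]. Then
// totalLoopCount = 4 and the three maps built above are
//   lhs:    (d0, d1, d2, d3) -> (d0, d1, d3)
//   rhs:    (d0, d1, d2, d3) -> (d0, d3, d2)
//   output: (d0, d1, d2, d3) -> (d0, d1, d2)
// with d0..d2 parallel loops and d3 the single reduction loop.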
3218
3219 Operation* linalgOp = rewriter.create<linalg::GenericOp>(
3220 loc, /*resultTensorTypes=*/TypeRange{outputType},
3221 /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
3222 /*outputBuffers=*/ValueRange{zeroTensor}, indexingMaps,
3223 getParallelAndReductionIterators(
3224 /*nLoops=*/totalLoopCount,
3225 /*nReduction=*/numContracting),
3226 [](OpBuilder& b, Location loc, ValueRange) {
3227 ImplicitLocOpBuilder builder(loc, b);
3228 linalg::MatmulOp::regionBuilder(builder, *b.getInsertionBlock(), {});
3229 },
3230 pruneAttributeList(op));
3231
3232 rewriter.replaceOp(op, linalgOp->getResults());
3233 return success();
3234 }
3235 };
3236
3237 struct HloLegalizeToLinalgPass
3238 : public mhlo::HloLegalizeToLinalgPassBase<HloLegalizeToLinalgPass> {
3239 void getDependentDialects(DialectRegistry& registry) const override {
3240 registry.insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
3241 scf::SCFDialect, complex::ComplexDialect, math::MathDialect,
3242 memref::MemRefDialect, shape::ShapeDialect>();
3243 }
3244
3245 void runOnOperation() override {
3246 MLIRContext& ctx = getContext();
3247 RewritePatternSet patterns(&ctx);
3248 ConversionTarget target(ctx);
3249 target.addLegalDialect<
3250 bufferization::BufferizationDialect, arith::ArithmeticDialect,
3251 complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect,
3252 tensor::TensorDialect, sparse_tensor::SparseTensorDialect,
3253 scf::SCFDialect, shape::ShapeDialect>();
3254
3255 target.addLegalOp<UnrealizedConversionCastOp>();
3256
3257 auto typeConverter = createHloToLinalgTypeConverter();
3258 auto func = getOperation();
3259 mhlo::populateHloToLinalgConversionPattern(&ctx, *typeConverter, &patterns);
3260 if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
3261 signalPassFailure();
3262 }
3263 }
3264 };
3265
3266 } // namespace
3267
3268 void populateHloToLinalgConversionPattern(MLIRContext* context,
3269 TypeConverter& typeConverter,
3270 RewritePatternSet* patterns) {
3271 // clang-format off
3272 patterns->add<
3273 BroadcastConverter<mhlo::BroadcastOp>, ConcatenateConverter,
3274 ConstConverterTensor, HloDynamicBroadcastInDimConverter,
3275 HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp>,
3276 EinsumToLinalgConverter,
3277 IotaConverter<mhlo::DynamicIotaOp>,
3278 MapOpConverter,
3279 PointwiseToLinalgConverter<mhlo::AbsOp>,
3280 PointwiseToLinalgConverter<mhlo::AddOp>,
3281 PointwiseToLinalgConverter<mhlo::AndOp>,
3282 PointwiseToLinalgConverter<mhlo::Atan2Op>,
3283 PointwiseToLinalgConverter<mhlo::BitcastConvertOp>,
3284 PointwiseToLinalgConverter<mhlo::CbrtOp>,
3285 PointwiseToLinalgConverter<mhlo::CeilOp>,
3286 PointwiseToLinalgConverter<mhlo::ClampOp>,
3287 PointwiseToLinalgConverter<mhlo::ClzOp>,
3288 PointwiseToLinalgConverter<mhlo::CompareOp>,
3289 PointwiseToLinalgConverter<mhlo::ComplexOp>,
3290 PointwiseToLinalgConverter<mhlo::ConvertOp>,
3291 PointwiseToLinalgConverter<mhlo::CopyOp>,
3292 PointwiseToLinalgConverter<mhlo::CosineOp>,
3293 PointwiseToLinalgConverter<mhlo::DivOp>,
3294 PointwiseToLinalgConverter<mhlo::ExpOp>,
3295 PointwiseToLinalgConverter<mhlo::Expm1Op>,
3296 PointwiseToLinalgConverter<mhlo::FloorOp>,
3297 PointwiseToLinalgConverter<mhlo::ImagOp>,
3298 PointwiseToLinalgConverter<mhlo::IsFiniteOp>,
3299 PointwiseToLinalgConverter<mhlo::LogOp>,
3300 PointwiseToLinalgConverter<mhlo::LogisticOp>,
3301 PointwiseToLinalgConverter<mhlo::Log1pOp>,
3302 PointwiseToLinalgConverter<mhlo::MaxOp>,
3303 PointwiseToLinalgConverter<mhlo::MinOp>,
3304 PointwiseToLinalgConverter<mhlo::MulOp>,
3305 PointwiseToLinalgConverter<mhlo::NegOp>,
3306 PointwiseToLinalgConverter<mhlo::NotOp>,
3307 PointwiseToLinalgConverter<mhlo::OrOp>,
3308 PointwiseToLinalgConverter<mhlo::PopulationCountOp>,
3309 PointwiseToLinalgConverter<mhlo::PowOp>,
3310 PointwiseToLinalgConverter<mhlo::RealOp>,
3311 PointwiseToLinalgConverter<mhlo::RemOp>,
3312 PointwiseToLinalgConverter<mhlo::RoundOp>,
3313 PointwiseToLinalgConverter<mhlo::RsqrtOp>,
3314 PointwiseToLinalgConverter<mhlo::SelectOp>,
3315 PointwiseToLinalgConverter<mhlo::ShiftLeftOp>,
3316 PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp>,
3317 PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp>,
3318 PointwiseToLinalgConverter<mhlo::SignOp>,
3319 PointwiseToLinalgConverter<mhlo::SineOp>,
3320 PointwiseToLinalgConverter<mhlo::SqrtOp>,
3321 PointwiseToLinalgConverter<mhlo::SubtractOp>,
3322 PointwiseToLinalgConverter<mhlo::TanhOp>,
3323 PointwiseToLinalgConverter<mhlo::XorOp>,
3324 PointwiseToLinalgConverter<mhlo::ReducePrecisionOp>,
3325 RealDynamicSliceConverter,
3326 ReshapeOpConverter,
3327 ReverseConverter,
3328 SliceConverter,
3329 DynamicSliceConverter,
3330 DynamicUpdateSliceConverter,
3331 TransposeConverter<mhlo::TransposeOp>,
3332 GatherConversion,
3333 PadOpConversion,
3334 PadOpNegativePaddingConversion,
3335 ReduceConversion,
3336 ReduceWindowOpOnTensorsGenericConversion,
3337 ReduceWindowOpConversion,
3338 RngUniformConversion,
3339 TorchIndexSelectOpConversion,
3340 ReduceRegionReturnOpConversion>(typeConverter, context);
3341 // Ensure specialized patterns are higher priority than their generic
3342 // versions.
3343 patterns->add<
3344 NormalConvolutionOpConversion,
3345 DepthwiseConvolutionOpConversion,
3346 DotOpConversion<DotOperationType::kMatrixMatrix, linalg::MatmulOp>,
3347 DotOpConversion<DotOperationType::kMatrixVector, linalg::MatvecOp>,
3348 DotOpConversion<DotOperationType::kVectorMatrix, linalg::VecmatOp>,
3349 DotOpConversion<DotOperationType::kVectorDot, linalg::DotOp>,
3350 DotGeneralBatchMatMulOpConversion>(typeConverter, context,
3351 PatternBenefit(2));
3352 patterns->add<
3353 ConvolutionOpGeneralConversion,
3354 DotGeneralOpConversion>(typeConverter, context, PatternBenefit(1));
3355 // clang-format on
3356 }
3357
3358 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeHloToLinalgPass() {
3359 return std::make_unique<HloLegalizeToLinalgPass>();
3360 }
3361
3362 std::unique_ptr<TypeConverter> createHloToLinalgTypeConverter() {
3363 return std::make_unique<LinalgTypeConverter>();
3364 }
3365
3366 } // namespace mhlo
3367 } // namespace mlir
3368