1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <numeric>
21 #include <string>
22 #include <utility>
23 
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/BitVector.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/ADT/iterator_range.h"
30 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
31 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
32 #include "mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h"
33 #include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h"
34 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
35 #include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
36 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
37 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
38 #include "mlir/Dialect/Complex/IR/Complex.h"
39 #include "mlir/Dialect/Func/IR/FuncOps.h"
40 #include "mlir/Dialect/Linalg/IR/Linalg.h"
41 #include "mlir/Dialect/Math/IR/Math.h"
42 #include "mlir/Dialect/MemRef/IR/MemRef.h"
43 #include "mlir/Dialect/SCF/IR/SCF.h"
44 #include "mlir/Dialect/Shape/IR/Shape.h"
45 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
46 #include "mlir/Dialect/Tensor/IR/Tensor.h"
47 #include "mlir/Dialect/Tensor/Utils/Utils.h"
48 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
49 #include "mlir/IR/AffineExpr.h"
50 #include "mlir/IR/Attributes.h"
51 #include "mlir/IR/BlockAndValueMapping.h"
52 #include "mlir/IR/Builders.h"
53 #include "mlir/IR/BuiltinAttributes.h"
54 #include "mlir/IR/BuiltinOps.h"
55 #include "mlir/IR/BuiltinTypes.h"
56 #include "mlir/IR/Location.h"
57 #include "mlir/IR/MLIRContext.h"
58 #include "mlir/IR/Operation.h"
59 #include "mlir/IR/OperationSupport.h"
60 #include "mlir/IR/PatternMatch.h"
61 #include "mlir/IR/TypeUtilities.h"
62 #include "mlir/Pass/Pass.h"
63 #include "mlir/Support/LLVM.h"
64 #include "mlir/Support/LogicalResult.h"
65 #include "mlir/Transforms/DialectConversion.h"
66 
67 namespace mlir {
68 namespace mhlo {
69 namespace {
70 
getResultValue(Operation * op)71 Value getResultValue(Operation* op) { return op->getResult(0); }
72 
getHloOpResultType(Operation * op)73 ShapedType getHloOpResultType(Operation* op) {
74   return getResultValue(op).getType().cast<ShapedType>();
75 }
76 
verifyHloOpBufferOrTensorSemantics(Operation * op)77 bool verifyHloOpBufferOrTensorSemantics(Operation* op) {
78   auto verifyType = [&](Value val) -> bool {
79     return val.getType().isa<RankedTensorType>();
80   };
81   if (!llvm::all_of(op->getOperands(), verifyType)) return false;
82   return llvm::all_of(op->getResults(), verifyType);
83 }
84 
fillTensorWithZeros(OpBuilder & builder,Location loc,Value tensor)85 Value fillTensorWithZeros(OpBuilder& builder, Location loc, Value tensor) {
86   auto type = tensor.getType().cast<ShapedType>();
87   Value zero;
88   // Complex numbers are a special case.
89   if (auto complexType = type.getElementType().dyn_cast<ComplexType>()) {
90     auto zeroElement = builder.getZeroAttr(complexType.getElementType());
91     auto zeroAttr = builder.getArrayAttr({zeroElement, zeroElement});
92     zero = builder.create<complex::ConstantOp>(loc, complexType, zeroAttr);
93   } else {
94     auto zeroAttr = builder.getZeroAttr(type.getElementType());
95     zero = builder.create<arith::ConstantOp>(loc, zeroAttr);
96   }
97   return builder.create<linalg::FillOp>(loc, zero, tensor).result();
98 }
99 
extract1DVector(DenseIntElementsAttr elements)100 SmallVector<int64_t, 4> extract1DVector(DenseIntElementsAttr elements) {
101   SmallVector<int64_t, 4> ret;
102   for (const APInt& element : elements) {
103     ret.push_back(element.getLimitedValue());
104   }
105   return ret;
106 }
107 
108 /// Returns a permutation AffineMap that puts all reduction dimensions to the
109 /// last. The order of parallel loops and reduction loops are all sorted. E.g.,
110 /// if `rank` is 4 and `reductionDims` is {1, 3}, then
111 /// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" is used. The inverse permutation of
112 /// the AffineMap is returned.
getTransposeMapForReduction(MLIRContext * context,int rank,ArrayRef<int64_t> reductionDims)113 AffineMap getTransposeMapForReduction(MLIRContext* context, int rank,
114                                       ArrayRef<int64_t> reductionDims) {
115   llvm::SmallSetVector<int, 4> s;
116   for (auto dim : reductionDims) s.insert(dim);
117 
118   SmallVector<unsigned, 4> permutation;
119   for (int i = 0; i < rank; ++i)
120     if (!s.count(i)) permutation.push_back(i);
121   for (auto dim : reductionDims) permutation.push_back(dim);
122 
123   auto map = AffineMap::getPermutationMap(permutation, context);
124   return inversePermutation(map);
125 }
126 
127 /// Returns true if the given `attr` is a splat of the given `value`.
isSplatValue(DenseIntElementsAttr attr,uint64_t value)128 bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
129   return attr.isSplat() && attr.getSplatValue<uint64_t>() == value;
130 }
131 
132 /// Returns true if the given `dimensionNumbers` from a mhlo.convolution op
133 /// follows a canonical form:
134 ///
135 /// * Input dimensions have order: (batch_count, spatial_dims,
136 ///   input_channel_count).
137 /// * Filter dimensions have order: (spatial_dims, input_channel_count,
138 ///   output_channel_count).
139 /// * Output dimensions have order: (batch_count, spatial_dims,
140 ///   output_channel_count).
hasCanonicalDimensionNumbers(mhlo::ConvDimensionNumbersAttr dimensionNumbers)141 static bool hasCanonicalDimensionNumbers(
142     mhlo::ConvDimensionNumbersAttr dimensionNumbers) {
143   const int inputSpatialRank =
144       llvm::size(dimensionNumbers.getInputSpatialDimensions());
145   // The dimensions for input should follow the order of
146   // batch_count, spatial_dims..., input_feature_count.
147   if (dimensionNumbers.getInputBatchDimension() != 0 ||
148       dimensionNumbers.getInputFeatureDimension() != (inputSpatialRank + 1)) {
149     return false;
150   }
151 
152   const int kernelSpatialRank =
153       llvm::size(dimensionNumbers.getKernelSpatialDimensions());
154   // The dimensions for filter should follow the order of
155   // spatial_dims..., input_feature_count, num_output_feature_count.
156   if (dimensionNumbers.getKernelInputFeatureDimension() != kernelSpatialRank ||
157       dimensionNumbers.getKernelOutputFeatureDimension() !=
158           (kernelSpatialRank + 1)) {
159     return false;
160   }
161 
162   const int outputSpatialRank =
163       llvm::size(dimensionNumbers.getOutputSpatialDimensions());
164   // The dimensions for output should follow the order of
165   // batch_count, spatial_dims.., output_feature_count.
166   if (dimensionNumbers.getOutputBatchDimension() != 0 ||
167       dimensionNumbers.getOutputFeatureDimension() != (outputSpatialRank + 1)) {
168     return false;
169   }
170 
171   if (inputSpatialRank != outputSpatialRank ||
172       inputSpatialRank != kernelSpatialRank) {
173     return false;
174   }
175 
176   const auto* inputSpatialDim =
177       dimensionNumbers.getInputSpatialDimensions().begin();
178   const auto* kernelSpatialDim =
179       dimensionNumbers.getKernelSpatialDimensions().begin();
180   const auto* outputSpatialDim =
181       dimensionNumbers.getOutputSpatialDimensions().begin();
182   // Check spatial dims are ordered correctly.
183   for (int i = 0; i < inputSpatialRank; ++i) {
184     const int dim = i + 1;
185     if ((*inputSpatialDim++) != dim || (*outputSpatialDim++) != dim ||
186         (*kernelSpatialDim++) != i) {
187       return false;
188     }
189   }
190 
191   return true;
192 }
193 
194 //===----------------------------------------------------------------------===//
195 // mhlo.RngOp conversion patterns.
196 //===----------------------------------------------------------------------===//
197 
198 // Pass to lower from rng to stateless pseudo RNG with LCG
199 // algorithm
200 struct RngUniformConversion : public OpConversionPattern<mhlo::RngOp> {
201   using OpConversionPattern<mhlo::RngOp>::OpConversionPattern;
202 
matchAndRewritemlir::mhlo::__anon071fe3ef0111::RngUniformConversion203   LogicalResult matchAndRewrite(
204       mhlo::RngOp op, OpAdaptor adaptor,
205       ConversionPatternRewriter& rewriter) const final {
206     // We only handle uniform distributions
207     if (op.rng_distribution() != ::mlir::mhlo::RngDistribution::UNIFORM) {
208       return failure();
209     }
210     // TODO(raikonenfnu): Handle other element types as well.
211     auto minTy = adaptor.getOperands()[0].getType().dyn_cast<ShapedType>();
212     auto maxTy = adaptor.getOperands()[0].getType().dyn_cast<ShapedType>();
213     if (!minTy.getElementType().dyn_cast<FloatType>() ||
214         !maxTy.getElementType().dyn_cast<FloatType>()) {
215       return rewriter.notifyMatchFailure(
216           op, "expected min/max for rng op to be FloatType");
217     }
218     auto targetTy = this->typeConverter->convertType(op.getResult().getType())
219                         .cast<ShapedType>();
220     if (!targetTy) {
221       return rewriter.notifyMatchFailure(
222           op, "expected target shape of rng op to be ShapedType");
223     }
224     auto loc = op.getLoc();
225     Value initTensor =
226         getInitTensorFor(rewriter, loc, targetTy, op, adaptor.getOperands());
227     // Creates index map using target matrix's rank.
228     auto targetRank = targetTy.getRank();
229     SmallVector<AffineMap, 3> indexingMaps(
230         2, AffineMap::get(targetRank, /*symbolCount=*/0,
231                           SmallVector<AffineExpr>({}), rewriter.getContext()));
232     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(targetRank));
233     const int kInitialSeed = 0;
234     // Generic region with LCG Algorithm that make use of element index from:
235     // https://reviews.llvm.org/D101364
236     auto linalgOp = rewriter.create<linalg::GenericOp>(
237         loc, /*resultTensors=*/targetTy,
238         /*inputs=*/
239         ValueRange{adaptor.getOperands()[0], adaptor.getOperands()[1]},
240         /*outputs=*/initTensor, indexingMaps,
241         getParallelAndReductionIterators(/*nLoops=*/targetRank,
242                                          /*nReduction=*/0),
243         [&](OpBuilder& b, Location loc, ValueRange args) {
244           llvm::SmallVector<Value> updateVec = {b.create<arith::ConstantOp>(
245               loc, b.getI32IntegerAttr(kInitialSeed))};
246           Value multiplier =
247               b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1103515245));
248           Value incrementStep =
249               b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(12345));
250           // For output matrix with rank N:
251           // temp1 = (cast(I32, index(D.0)) + seed) * mult + incr
252           // ...
253           // tempN = (cast(I32, index(D.(N))) + tempN_1) * mult + incr
254           for (int i = 0; i < targetRank; i++) {
255             Value update = updateVec.back();
256             Value ind = b.create<linalg::IndexOp>(loc, i);
257             Value castInd =
258                 b.create<arith::IndexCastOp>(loc, b.getI32Type(), ind);
259             Value addRes = b.create<arith::AddIOp>(loc, castInd, update);
260             Value multRes = b.create<arith::MulIOp>(loc, addRes, multiplier);
261             Value incRes = b.create<arith::AddIOp>(loc, multRes, incrementStep);
262             updateVec.push_back(incRes);
263           }
264           // Scaling = (max - min) * const(F64, 2.3283064E-10)
265           // which is derived from rand(min,max) = rand()/(RAND_MAX/(max-min)).
266           Value epsilon = b.create<arith::ConstantOp>(
267               loc, b.getFloatAttr(args[0].getType(), 2.3283064E-10));
268           Value range = b.create<arith::SubFOp>(loc, args[1], args[0]);
269           Value scale = b.create<arith::MulFOp>(loc, range, epsilon);
270           // Res = cast(T, cast(F64, tempN) * scaling + min)
271           Value updateCast = b.create<arith::UIToFPOp>(
272               loc, targetTy.getElementType(), updateVec.back());
273           Value scaleUpdate = b.create<arith::MulFOp>(loc, updateCast, scale);
274           Value res = b.create<arith::AddFOp>(loc, scaleUpdate, args[0]);
275           b.create<linalg::YieldOp>(loc, res);
276         },
277         pruneAttributeList(op));
278     rewriter.replaceOp(op, linalgOp.getResults());
279     return success();
280   }
281 };
282 
283 //===----------------------------------------------------------------------===//
284 // mhlo.Einsum conversion patterns.
285 //===----------------------------------------------------------------------===//
286 
287 // Looks through a set of dimension that has been marked as reduction axes,
288 // if it is found within the set, then we set it as "reduction", otherwise
289 // we can label it as "parallel".
getEinsumLoopsAttrs(const llvm::SmallSetVector<StringRef,4> & inputInd,const llvm::SmallSetVector<StringRef,4> & reductionDims)290 SmallVector<StringRef, 3> getEinsumLoopsAttrs(
291     const llvm::SmallSetVector<StringRef, 4>& inputInd,
292     const llvm::SmallSetVector<StringRef, 4>& reductionDims) {
293   SmallVector<StringRef, 3> res;
294   for (StringRef dim : inputInd) {
295     if (!reductionDims.contains(dim)) {
296       res.push_back(getParallelIteratorTypeName());
297     } else {
298       res.push_back(getReductionIteratorTypeName());
299     }
300   }
301   return res;
302 }
303 
extractDynamicEinsumSizes(OpBuilder & b,Location loc,Value lhs,Value rhs,const SmallVector<std::string> & lhsLoopVec,const SmallVector<std::string> & rhsLoopVec,const SmallVector<std::string> & outputLoopVec)304 SmallVector<Value, 2> extractDynamicEinsumSizes(
305     OpBuilder& b, Location loc, Value lhs, Value rhs,
306     const SmallVector<std::string>& lhsLoopVec,
307     const SmallVector<std::string>& rhsLoopVec,
308     const SmallVector<std::string>& outputLoopVec) {
309   SmallVector<Value, 2> dynSizes;
310   for (const std::string& dimInd : outputLoopVec) {
311     Value dimSize;
312     const auto* dimIndIt =
313         std::find(lhsLoopVec.begin(), lhsLoopVec.end(), dimInd);
314     if (dimIndIt != lhsLoopVec.end()) {
315       // Query from lhs vars.
316       auto dimIndPos = dimIndIt - lhsLoopVec.begin();
317       auto lhsShape = lhs.getType().dyn_cast<RankedTensorType>().getShape();
318       if (lhsShape[dimIndPos] != ShapedType::kDynamicSize) continue;
319       dimSize = b.create<tensor::DimOp>(loc, lhs, dimIndPos);
320     } else {
321       // query from rhs vars.
322       dimIndIt = std::find(rhsLoopVec.begin(), rhsLoopVec.end(), dimInd);
323       auto dimIndPos = dimIndIt - rhsLoopVec.begin();
324       auto rhsShape = rhs.getType().dyn_cast<RankedTensorType>().getShape();
325       if (rhsShape[dimIndPos] != ShapedType::kDynamicSize) continue;
326       dimSize = b.create<tensor::DimOp>(loc, rhs, dimIndPos);
327     }
328     dynSizes.push_back(dimSize);
329   }
330   return dynSizes;
331 }
332 
333 // Adds indices/axes that are missing from output set.
findSummationAxes(const llvm::SmallSetVector<StringRef,4> & inputSet,const llvm::SmallSetVector<StringRef,4> & outputSet)334 llvm::SmallSetVector<StringRef, 4> findSummationAxes(
335     const llvm::SmallSetVector<StringRef, 4>& inputSet,
336     const llvm::SmallSetVector<StringRef, 4>& outputSet) {
337   llvm::SmallSetVector<StringRef, 4> summationAxes;
338   for (StringRef ind : inputSet) {
339     if (!outputSet.contains(ind)) summationAxes.insert(ind);
340   }
341   return summationAxes;
342 }
343 
344 // Given a 1:1 map from std::string -> affine dimension expression
345 // we can get the affine expression of dimensions that an
346 // operand will access based on the input_str of einsum_config.
347 // For example:
348 // let string_dim_umap = {'a' : d0, 'b' : d1, 'c' : d2}
349 // for einsum_config "abc,cb->acb"
350 // first_input_operand will get umap[{"a","b","c"}] -> (d0, d1, d2).
351 // second_input_operand will get umap[{"c","b"}] -> (d2, d1).
352 // output_operand will get umap[{"a","c","b"}] -> (d0, d2, d1).
getExprFromConfig(const SmallVector<std::string> & loopDims,const DenseMap<StringRef,AffineExpr> & strAffineDimUmap)353 SmallVector<AffineExpr> getExprFromConfig(
354     const SmallVector<std::string>& loopDims,
355     const DenseMap<StringRef, AffineExpr>& strAffineDimUmap) {
356   SmallVector<AffineExpr> exprs;
357   for (const auto& dim : loopDims) {
358     exprs.push_back(strAffineDimUmap.lookup(dim));
359   }
360   return exprs;
361 }
362 
363 // Convert mhlo.einsum op into linalg.generic.
364 // Algorithm in general 3 steps:
365 
366 // Step1) Dissect entire einsum_config to different operands
367 // e.g f("abc,cd->abd") = {lhs:["abc"], rhs:["cd"], out:["abd"]}.
368 
369 // Step2) Split up the string into vector of the elements
370 // e.g {lhs:["abc"], rhs:["cd"], out:["abd"]} = {lhs:["a","b","c"],
371 // rhs:["c","d"], out:["a","b","d"]}.
372 
373 // Step3) Convert the vector into data access
374 // patern represented by affineMaps with affineDimensions e.g
375 // {lhs:["a","b","c"], rhs:["c","d"], out:["a","b","d"]} = {lhs:[d0,d1,d2],
376 // rhs:[d2,d3], out:[d0,d1,d3]}.
377 class EinsumToLinalgConverter : public OpConversionPattern<mhlo::EinsumOp> {
378  public:
379   using OpConversionPattern<mhlo::EinsumOp>::OpConversionPattern;
380 
matchAndRewrite(mhlo::EinsumOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const381   LogicalResult matchAndRewrite(
382       mhlo::EinsumOp op, OpAdaptor adaptor,
383       ConversionPatternRewriter& rewriter) const final {
384     auto getRank = [](Value v) {
385       return v.getType().cast<ShapedType>().getRank();
386     };
387     auto einsumConfig = op.einsum_config();
388 
389     // With the assumption of binary input operand and single output
390     // get the inputs and output operands' indices.
391     // einsum_config = "lhs_loop,rhs_loop->out_loop"
392     std::size_t posArrow = einsumConfig.find(kArrow);
393     std::size_t posComma = einsumConfig.find(kComma);
394 
395     StringRef lhsLoop = einsumConfig.substr(0, posComma);
396     StringRef rhsLoop = einsumConfig.substr(
397         posComma + kComma.size(), posArrow - (posComma + kComma.size()));
398     StringRef outLoop = einsumConfig.substr(posArrow + kArrow.size());
399 
400     // Check for Invalid Configs.
401     // 1.Check that there is only maximum 2 inputs
402     // 2.Check that there is only maximum 1 output
403     // 3.Check that there is 1 kArrow
404     if (rhsLoop.find(kComma) != std::string::npos ||
405         outLoop.find(kComma) != std::string::npos ||
406         outLoop.find(kArrow) != std::string::npos) {
407       return rewriter.notifyMatchFailure(op, "Invalid einsum config!");
408     }
409 
410     // Find result type, if on tensors.
411     auto resultTy = this->typeConverter->convertType(getHloOpResultType(op))
412                         .dyn_cast<RankedTensorType>();
413 
414     // Check result type compatibility.
415     if (!resultTy || !(resultTy.getElementType().isSignlessIntOrFloat())) {
416       return rewriter.notifyMatchFailure(op, "Invalid result type");
417     }
418 
419     // Convert the representation to vector<string>.
420     SmallVector<std::string> lhsEin =
421         getEinsumConfigAsVector(lhsLoop, getRank(adaptor.lhs()));
422     SmallVector<std::string> rhsEin =
423         getEinsumConfigAsVector(rhsLoop, getRank(adaptor.rhs()));
424     SmallVector<std::string> outEin =
425         getEinsumConfigAsVector(outLoop, resultTy.getRank());
426 
427     if (!checkBatchHasEqualRank(lhsEin.size(), lhsLoop, rhsEin.size(), rhsLoop,
428                                 outEin.size(), outLoop)) {
429       return rewriter.notifyMatchFailure(
430           op, "Invalid elipsis('...') within einsum config!");
431     }
432 
433     // Find all unique indices in the input and output.
434     llvm::SmallSetVector<StringRef, 4> inputInd;
435     llvm::SmallSetVector<StringRef, 4> outputInd;
436 
437     inputInd.insert(lhsEin.begin(), lhsEin.end());
438     inputInd.insert(rhsEin.begin(), rhsEin.end());
439     outputInd.insert(outEin.begin(), outEin.end());
440 
441     llvm::SmallSetVector<StringRef, 4> reductionAxe =
442         findSummationAxes(inputInd, outputInd);
443 
444     // Find input/output values and types.
445     auto loc = op.getLoc();
446 
447     // Prepare init tensor for linalg.generic op.
448     auto dynSizes = extractDynamicEinsumSizes(
449         rewriter, loc, adaptor.lhs(), adaptor.rhs(), lhsEin, rhsEin, outEin);
450     Value output = getInitTensor(rewriter, loc, resultTy, dynSizes);
451     if (!reductionAxe.empty()) {
452       output = fillTensorWithZeros(rewriter, loc, output);
453     }
454 
455     // Create indexing maps.
456     // Create a 1:1 map from f:strDimension -> affineDimension.
457     int64_t nloops = inputInd.size();
458     DenseMap<StringRef, AffineExpr> strAffineDimUmap;
459     for (auto& it : llvm::enumerate(inputInd)) {
460       strAffineDimUmap[it.value()] = rewriter.getAffineDimExpr(it.index());
461     }
462 
463     // From einsum_config of each operand in vector<string>, generate
464     // the equivalent vector<AffineExpr>.
465     SmallVector<AffineMap, 4> maps;
466     for (const SmallVector<std::string>& loopOperand :
467          {lhsEin, rhsEin, outEin}) {
468       auto exprs = getExprFromConfig(loopOperand, strAffineDimUmap);
469       maps.push_back(AffineMap::get(nloops, 0, exprs, rewriter.getContext()));
470     }
471 
472     auto linalgOp = rewriter.create<linalg::GenericOp>(
473         loc, resultTy ? resultTy : TypeRange{}, adaptor.getOperands(), output,
474         maps, getEinsumLoopsAttrs(inputInd, reductionAxe),
475         [reductionAxe](OpBuilder& b, Location nestedLoc, ValueRange args) {
476           Value resultVal =
477               b.create<mlir::arith::MulFOp>(nestedLoc, args[0], args[1]);
478           if (!reductionAxe.empty()) {
479             resultVal =
480                 b.create<mlir::arith::AddFOp>(nestedLoc, args[2], resultVal);
481           }
482           b.create<linalg::YieldOp>(nestedLoc, resultVal);
483         },
484         pruneAttributeList(op));
485     rewriter.replaceOp(op, linalgOp.getResults());
486     return success();
487   }
488 
489  private:
490   static constexpr StringRef kArrow = "->";
491   static constexpr StringRef kComma = ",";
492   static constexpr StringRef kEllipsis = "...";
493 
494   static bool checkBatchHasEqualRank(size_t lhsRank, StringRef lhsLoop,
495                                      size_t rhsRank, StringRef rhsLoop,
496                                      size_t outRank, StringRef outLoop);
497   static SmallVector<std::string> getEinsumConfigAsVector(StringRef loop,
498                                                           size_t operandRank);
499 };
500 
501 // Definition of util const member variables.
502 constexpr StringRef EinsumToLinalgConverter::kArrow;
503 constexpr StringRef EinsumToLinalgConverter::kComma;
504 constexpr StringRef EinsumToLinalgConverter::kEllipsis;
505 
506 // Convert the representation from string/vector<char> to vector<string>.
507 // i.e ("abc") -> {"a", "b", "c"}. For cases with ellipsis with batch rank 3:
508 // get loop_dim = f("ab...cde") = {"a","b","0","1","2","c","d","e"}
getEinsumConfigAsVector(StringRef loop,size_t operandRank)509 SmallVector<std::string> EinsumToLinalgConverter::getEinsumConfigAsVector(
510     StringRef loop, size_t operandRank) {
511   SmallVector<std::string> loopDim;
512   size_t preElip = loop.find(kEllipsis);
513   bool hasElip = preElip != std::string::npos;
514   if (!hasElip) preElip = loop.size();
515   // Add the dimension until the end or up to ellipsis if it exist.
516   for (int64_t preElipInd = 0; preElipInd < static_cast<int64_t>(preElip);
517        preElipInd++) {
518     loopDim.push_back(loop.substr(preElipInd, 1).str());
519   }
520   if (!hasElip) return loopDim;
521   // Case where Ellipsis presence:
522   size_t nonBatchRank = loop.size() - kEllipsis.size();
523   size_t batchRank = operandRank - nonBatchRank;
524   // Add the batch dimension ("0",...,"N") where N is rank of batch into the
525   // loop.
526   for (int64_t batchInd = 0; batchInd < static_cast<int64_t>(batchRank);
527        batchInd++) {
528     loopDim.push_back(std::to_string(batchInd));
529   }
530   // Add the dimension after ellipsis into the loop.
531   int postElip = preElip + kEllipsis.size();
532   for (int64_t postElipInd = postElip;
533        postElipInd < static_cast<int64_t>(loop.size()); postElipInd++) {
534     loopDim.push_back(loop.substr(postElipInd, 1).str());
535   }
536   return loopDim;
537 }
538 
539 // Returns true if all operand's batch has same rank.
checkBatchHasEqualRank(size_t lhsRank,StringRef lhsLoop,size_t rhsRank,StringRef rhsLoop,size_t outRank,StringRef outLoop)540 bool EinsumToLinalgConverter::checkBatchHasEqualRank(
541     size_t lhsRank, StringRef lhsLoop, size_t rhsRank, StringRef rhsLoop,
542     size_t outRank, StringRef outLoop) {
543   SmallVector<int, 3> batchRankVec;
544   if (lhsRank != lhsLoop.size()) {
545     size_t lhsBatchRank = lhsRank - (lhsLoop.size() - kEllipsis.size());
546     batchRankVec.push_back(lhsBatchRank);
547   }
548   if (rhsRank != rhsLoop.size()) {
549     size_t rhsBatchRank = rhsRank - (rhsLoop.size() - kEllipsis.size());
550     batchRankVec.push_back(rhsBatchRank);
551   }
552   if (outRank != outLoop.size()) {
553     size_t outBatchRank = outRank - (outLoop.size() - kEllipsis.size());
554     batchRankVec.push_back(outBatchRank);
555   }
556   bool batchHasEqualRank = true;
557 
558   // Condition is valid if only 1 operand or less have batches.
559   if (batchRankVec.size() < 2) return batchHasEqualRank;
560   if (!std::equal(batchRankVec.begin() + 1, batchRankVec.end(),
561                   batchRankVec.begin()) &&
562       batchRankVec.size() > 1)
563     batchHasEqualRank = false;
564   return batchHasEqualRank;
565 }
566 
567 template <typename MhloOp>
568 class ScalarPointwiseToStandardConverter : public OpConversionPattern<MhloOp> {
569  public:
570   using OpConversionPattern<MhloOp>::OpConversionPattern;
571 
matchAndRewrite(MhloOp mhloOp,ConversionPatternRewriter & rewriter) const572   LogicalResult matchAndRewrite(
573       MhloOp mhloOp, ConversionPatternRewriter& rewriter) const final {
574     auto loc = mhloOp.getLoc();
575     auto argType =
576         mhloOp.getOperand(0).getType().template dyn_cast<ShapedType>();
577     if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
578         (argType.getRank() != 0)) {
579       return failure();
580     }
581 
582     // Create two loads from the input.
583     auto lhs = rewriter.create<memref::LoadOp>(loc, mhloOp.lhs());
584     auto rhs = rewriter.create<memref::LoadOp>(loc, mhloOp.rhs());
585     Value opResult = mhlo::MhloOpToStdScalarOp::mapOp(
586         mhloOp, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
587         &rewriter);
588     rewriter.create<memref::StoreOp>(loc, opResult, mhloOp.out());
589     rewriter.eraseOp(mhloOp);
590     return success();
591   }
592 };
593 
594 /// Base class for lowering HLO operations that have one operand and one result,
595 /// and are semantically equivalent to a copy of the input to the output (like
596 /// transpose, some reshape, etc.). The derived classes need to provide a method
597 /// `getIndexingMaps` that returns AffineMaps for the index maps of the input
598 /// and the output.
599 template <typename Derived, typename OpTy>
600 class DataMovementOpConverter : public OpConversionPattern<OpTy> {
601  public:
602   using OpConversionPattern<OpTy>::OpConversionPattern;
603 
matchAndRewrite(OpTy op,typename OpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter) const604   LogicalResult matchAndRewrite(
605       OpTy op, typename OpTy::Adaptor adaptor,
606       ConversionPatternRewriter& rewriter) const final {
607     if (!verifyHloOpBufferOrTensorSemantics(op)) return failure();
608     auto resultType = getHloOpResultType(op);
609     resultType = this->typeConverter->convertType(resultType)
610                      .template cast<ShapedType>();
611 
612     SmallVector<AffineMap, 2> indexingMaps =
613         Derived::getIndexingMaps(op, &rewriter);
614     if (indexingMaps.empty()) return failure();
615 
616     auto nloops = resultType.getRank();
617     auto loc = op.getLoc();
618     auto linalgOp = rewriter.create<linalg::GenericOp>(
619         loc,
620         /*resultTensorTypes=*/resultType,
621         /*inputs=*/adaptor.getOperands().front(),
622         /*outputBuffers=*/
623 
624         ValueRange{getInitTensorFor(rewriter, loc, resultType, op,
625                                     adaptor.getOperands())},
626         indexingMaps, getNParallelLoopsAttrs(nloops),
627         [&](OpBuilder& nestedBuilder, Location /*nested_loc*/,
628             ValueRange args) {
629           nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
630         },
631         pruneAttributeList(op));
632     rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
633     return success();
634   }
635 };
636 
637 /// Pattern to convert BroadcastOp to Linalg ops.
638 template <typename OpTy>
639 class BroadcastConverter
640     : public DataMovementOpConverter<BroadcastConverter<OpTy>, OpTy> {
641  public:
642   using DataMovementOpConverter<BroadcastConverter,
643                                 OpTy>::DataMovementOpConverter;
644 
getIndexingMaps(OpTy broadcastOp,Builder * b)645   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
646                                                    Builder* b) {
647     ShapedType inputType =
648         broadcastOp.operand().getType().template cast<ShapedType>();
649     unsigned inputRank = inputType.getRank();
650     unsigned nloops = getHloOpResultType(broadcastOp).getRank();
651 
652     // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
653     // the input's dimensions.
654     unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes());
655     SmallVector<AffineExpr, 4> inputDimExprs;
656     inputDimExprs.reserve(inputRank);
657     for (unsigned i = 0; i < inputRank; ++i) {
658       inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i));
659     }
660 
661     AffineMap inputMap;
662     MLIRContext* context = b->getContext();
663     if (inputDimExprs.empty()) {
664       // The input is a scalar, i.e. this is a scalar broadcast op.
665       inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context);
666     } else {
667       inputMap =
668           AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
669     }
670     return {inputMap, b->getMultiDimIdentityMap(nloops)};
671   }
672 };
673 
674 class HloBroadcastInDimConverter
675     : public DataMovementOpConverter<HloBroadcastInDimConverter,
676                                      mhlo::BroadcastInDimOp> {
677  public:
678   using DataMovementOpConverter<
679       HloBroadcastInDimConverter,
680       mhlo::BroadcastInDimOp>::DataMovementOpConverter;
681 
getIndexingMaps(mhlo::BroadcastInDimOp broadcastOp,Builder * b)682   static SmallVector<AffineMap, 2> getIndexingMaps(
683       mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
684     auto resultType = getHloOpResultType(broadcastOp);
685     auto operandType =
686         broadcastOp.operand().getType().template cast<ShapedType>();
687     unsigned nloops = resultType.getRank();
688 
689     // The input is a scalar, i.e. this is a scalar broadcast op.
690     if (operandType.getRank() == 0) {
691       return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
692               b->getMultiDimIdentityMap(nloops)};
693     }
694 
695     auto operandShape = operandType.getShape();
696     SmallVector<AffineExpr, 4> dimExprs;
697     dimExprs.reserve(nloops);
698 
699     if (broadcastOp.broadcast_dimensions()) {
700       for (const auto& broadcastDim :
701            enumerate(broadcastOp.broadcast_dimensions().getValues<APInt>())) {
702         int size = broadcastDim.value().getSExtValue();
703         bool expansionNeeded = operandShape[broadcastDim.index()] == 1 &&
704                                resultType.getShape()[size] != 1;
705         dimExprs.push_back(expansionNeeded ? b->getAffineConstantExpr(0)
706                                            : b->getAffineDimExpr(size));
707       }
708     }
709     return {
710         AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
711         b->getMultiDimIdentityMap(nloops)};
712   }
713 };
714 
715 // If the input has a static shape we know exactly when the broadcast must
716 // expand (the dimension is 1, which also trivially expands to 1) or will never
717 // expand (the dimension is not 1). We can also source the information from the
718 // optionally provided attrbibutes on statically known broadcasting behavior.
719 // This means we can lower the broadcast just as we would lower a fully static
720 // broadcast and go directly to `linalg.generic`.
721 
722 // This also covers the important case of broadcasting a scalar. Ideally the
723 // pattern (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be
724 // converted to a tensor dialect op similar to TF's `ConstantLikeOp`.
725 class HloDynamicBroadcastInDimConverter
726     : public OpConversionPattern<mhlo::DynamicBroadcastInDimOp> {
727  public:
728   using OpConversionPattern<mhlo::DynamicBroadcastInDimOp>::OpConversionPattern;
729 
matchAndRewrite(mhlo::DynamicBroadcastInDimOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const730   LogicalResult matchAndRewrite(
731       mhlo::DynamicBroadcastInDimOp op, OpAdaptor adaptor,
732       ConversionPatternRewriter& rewriter) const final {
733     Value operand = adaptor.operand();
734     auto operandType = operand.getType().dyn_cast<RankedTensorType>();
735     if (!operandType) return failure();
736     auto resultType =
737         typeConverter->convertType(op.getType()).dyn_cast<RankedTensorType>();
738     if (!resultType) return failure();
739 
740     // Determine dimension expressions based on whether the dimension is
741     // expanding (0) or non-expanding (identity), and fail if we cannot decide
742     // this.
743     SmallVector<AffineExpr> dimExprs(operandType.getRank(), nullptr);
744 
745     // Use static type info.
746     auto bcastDims = llvm::to_vector(
747         llvm::map_range(op.broadcast_dimensions(), [](const APInt& d) {
748           return static_cast<int64_t>(d.getLimitedValue());
749         }));
750     for (const auto& it : llvm::enumerate(operandType.getShape())) {
751       if (ShapedType::isDynamic(it.value())) continue;
752       bool isExpanding = it.value() == 1;
753       dimExprs[it.index()] =
754           isExpanding ? rewriter.getAffineConstantExpr(0)
755                       : rewriter.getAffineDimExpr(bcastDims[it.index()]);
756     }
757 
758     // Use annotated expansion behavior, if available.
759     if (op.known_expanding_dimensions()) {
760       for (const auto& it :
761            op.known_expanding_dimensions()->getValues<APInt>()) {
762         auto i = it.getLimitedValue();
763         dimExprs[i] = rewriter.getAffineConstantExpr(0);
764       }
765     }
766     if (op.known_nonexpanding_dimensions()) {
767       for (const auto& it :
768            op.known_nonexpanding_dimensions()->getValues<APInt>()) {
769         auto i = it.getLimitedValue();
770         dimExprs[i] = rewriter.getAffineDimExpr(bcastDims[i]);
771       }
772     }
773 
774     // Fail if unknown expansion behavior remains.
775     if (!llvm::all_of(dimExprs, [](AffineExpr expr) { return expr; }))
776       return failure();
777 
778     // Materialize `linalg.generic` op.
779     Location loc = op.getLoc();
780     int64_t nloops = resultType.getRank();
781     Value init =
782         getInitTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
783     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
784         op, TypeRange{init.getType()}, ValueRange{operand},
785         /*outputBuffers=*/ValueRange{init},
786         llvm::makeArrayRef(
787             {AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, dimExprs,
788                             rewriter.getContext()),
789              rewriter.getMultiDimIdentityMap(nloops)}),
790         getNParallelLoopsAttrs(nloops),
791         [&](OpBuilder& nestedBuilder, Location /*nested_loc*/,
792             ValueRange args) {
793           nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
794         },
795         pruneAttributeList(op));
796     return success();
797   }
798 };
799 
800 template <typename OpTy>
801 class TransposeConverter
802     : public DataMovementOpConverter<TransposeConverter<OpTy>, OpTy> {
803  public:
804   using DataMovementOpConverter<TransposeConverter<OpTy>,
805                                 OpTy>::DataMovementOpConverter;
getIndexingMaps(OpTy op,Builder * b)806   static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
807     auto resultType = getHloOpResultType(op).template cast<ShapedType>();
808     auto nloops = resultType.getRank();
809     SmallVector<AffineExpr, 2> inputExprs;
810     inputExprs.resize(resultType.getRank());
811     for (const auto& permutation : llvm::enumerate(op.permutation())) {
812       inputExprs[permutation.value().getZExtValue()] =
813           b->getAffineDimExpr(permutation.index());
814     }
815     return {
816         AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
817         b->getMultiDimIdentityMap(nloops)};
818   }
819 };
820 
821 // Lowers mhlo.RealDynamicSliceOp to tensor.extract_slice and other
822 // arith/tensor dialect ops.
823 class RealDynamicSliceConverter
824     : public OpConversionPattern<mhlo::RealDynamicSliceOp> {
825  public:
826   using OpConversionPattern<mhlo::RealDynamicSliceOp>::OpConversionPattern;
827 
828   // Computes size of a slice as
829   //   size = ceil((limit - start)/stride)
computeSize(Location loc,Value start,Value limit,Value stride,ConversionPatternRewriter & b)830   static Value computeSize(Location loc, Value start, Value limit, Value stride,
831                            ConversionPatternRewriter& b) {
832     Value delta = b.create<arith::SubIOp>(loc, limit, start);
833     Value ret = b.create<arith::CeilDivUIOp>(loc, delta, stride);
834     if (ret.getType().isIndex()) return ret;
835     return b.create<arith::IndexCastOp>(loc, b.getIndexType(), ret);
836   }
837 
matchAndRewrite(mhlo::RealDynamicSliceOp realDynamicSliceOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const838   LogicalResult matchAndRewrite(
839       mhlo::RealDynamicSliceOp realDynamicSliceOp, OpAdaptor adaptor,
840       ConversionPatternRewriter& rewriter) const final {
841     Location loc = realDynamicSliceOp.getLoc();
842     auto argType = adaptor.operand().getType().dyn_cast<ShapedType>();
843     if (!argType || !argType.hasRank()) {
844       return rewriter.notifyMatchFailure(realDynamicSliceOp,
845                                          "require known-rank args");
846     }
847 
848     Type dimElementType = getElementTypeOrSelf(adaptor.start_indices());
849     if (getElementTypeOrSelf(adaptor.limit_indices()) != dimElementType ||
850         getElementTypeOrSelf(adaptor.strides()) != dimElementType) {
851       return rewriter.notifyMatchFailure(
852           realDynamicSliceOp,
853           "requires same element type for all dimension specification");
854     }
855     Type arithType =
856         dimElementType.isIndex() ? rewriter.getI64Type() : dimElementType;
857     Type indexType = rewriter.getIndexType();
858 
859     auto resultType =
860         this->typeConverter->convertType(realDynamicSliceOp.getType())
861             .cast<RankedTensorType>();
862     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
863     SmallVector<OpFoldResult, 4> offsets, sizes, strides;
864     SmallVector<Type, 3> clampType(3, arithType);
865     for (auto i : llvm::seq<unsigned>(0, argType.getRank())) {
866       Value dim = rewriter.create<arith::ConstantIndexOp>(loc, i);
867       Value start =
868           rewriter.create<tensor::ExtractOp>(loc, adaptor.start_indices(), dim);
869       Value limit =
870           rewriter.create<tensor::ExtractOp>(loc, adaptor.limit_indices(), dim);
871       Value stride =
872           rewriter.create<tensor::ExtractOp>(loc, adaptor.strides(), dim);
873 
874       // Compute i-th dimension size of the result : size[i].
875       // If the i-th dimension of the result type is known, we go ahead with it
876       // else we compute it using limit, start and stride values.
877       int64_t resultDimSize = resultType.getDimSize(i);
878       Value size =
879           ShapedType::isDynamic(resultDimSize)
880               ? computeSize(loc, start, limit, stride, rewriter)
881               : rewriter.create<arith::ConstantIndexOp>(loc, resultDimSize);
882 
883       // We can now convert start to index.
884       if (!start.getType().isIndex())
885         start = rewriter.create<arith::IndexCastOp>(
886             loc, rewriter.getIndexType(), start);
887 
888       // Fetch i-th dimension size of the operand and calculate upper bound as
889       //   ub = operand_dim[i] - size[i]
890       Value operandDimSize =
891           rewriter.createOrFold<tensor::DimOp>(loc, adaptor.operand(), dim);
892       Value upperBound =
893           rewriter.createOrFold<arith::SubIOp>(loc, operandDimSize, size);
894 
895       // We clamp the start_index to keep it bounded as
896       //   0 <= start_index[i] <= ub
897       // Clamp does not support index type, so cast to integer type.
898       start = rewriter.create<arith::MaxSIOp>(loc, start, zero);
899       start = rewriter.create<arith::MinSIOp>(loc, start, upperBound);
900 
901       offsets.push_back(start);
902       if (ShapedType::isDynamic(resultDimSize))
903         sizes.push_back(size);
904       else
905         sizes.push_back(IntegerAttr::get(indexType, resultDimSize));
906 
907       if (!stride.getType().isIndex())
908         stride =
909             rewriter.createOrFold<arith::IndexCastOp>(loc, indexType, stride);
910       strides.push_back(stride);
911     }
912 
913     rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
914         realDynamicSliceOp, resultType, adaptor.operand(), offsets, sizes,
915         strides);
916     return success();
917   }
918 };
919 
920 // Converts reshape ops that can be proven to be either a collapse of dimensions
921 // or expansion of dimensions of the operand.
922 class ReshapeOpConverter : public OpConversionPattern<mhlo::ReshapeOp> {
923  public:
924   using OpConversionPattern::OpConversionPattern;
925 
matchAndRewrite(mhlo::ReshapeOp reshapeOp,mhlo::ReshapeOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const926   LogicalResult matchAndRewrite(
927       mhlo::ReshapeOp reshapeOp, mhlo::ReshapeOp::Adaptor adaptor,
928       ConversionPatternRewriter& rewriter) const final {
929     if (!verifyHloOpBufferOrTensorSemantics(reshapeOp)) return failure();
930     auto operand = adaptor.operand();
931     auto operandType = operand.getType().cast<ShapedType>();
932     auto elemType = operandType.getElementType();
933     auto resultType = reshapeOp.getType().cast<ShapedType>();
934 
935     if (!resultType.hasStaticShape()) return failure();
936 
937     resultType = typeConverter->convertType(resultType).cast<ShapedType>();
938 
939     // Special case where the result is a scalar.
940     if (resultType.getRank() == 0 && !operandType.hasStaticShape()) {
941       // This means all dimensions of the operand need to be 1. We add a cast to
942       // cast the dynamic dimensions to 1.
943       auto staticType = RankedTensorType::get(
944           llvm::SmallVector<int64_t>(operandType.getRank(), 1), elemType);
945       operand = rewriter.create<tensor::CastOp>(reshapeOp.getLoc(), staticType,
946                                                 operand);
947       rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
948           reshapeOp, resultType, operand, ArrayRef<ReassociationIndices>{});
949       return success();
950     }
951 
952     // Compute the reassociation maps for the linalg operation. This will
953     // succeed if the reshape can be done with a single expand_shape or
954     // collapse_shape.
955     if (Optional<SmallVector<ReassociationIndices>> reassociationMap =
956             getReassociationIndicesForReshape(operandType, resultType)) {
957       if (resultType.getRank() < operandType.getRank()) {
958         // We have found a working reassociation map. If the operand is dynamic,
959         // we first need to cast all unknown dimensions in the input that get
960         // collapsed to a static-sized dimension in the output, to 1.
961         SmallVector<int64_t> shape(operandType.getShape().begin(),
962                                    operandType.getShape().end());
963         for (const auto& map : llvm::enumerate(*reassociationMap)) {
964           // If the result dim is dynamic, we do not mind dynamic entries in the
965           // source.
966           if (resultType.isDynamicDim(map.index())) continue;
967           for (auto targetDim : map.value()) {
968             if (shape[targetDim] == ShapedType::kDynamicSize)
969               shape[targetDim] = 1;
970           }
971         }
972         // Insert a cast if types are not the same (ignoring sparse encoding).
973         auto enc = sparse_tensor::getSparseTensorEncoding(operandType);
974         auto newOperandType = RankedTensorType::get(shape, elemType, enc);
975         if (newOperandType != operandType) {
976           operand = rewriter.create<tensor::CastOp>(reshapeOp.getLoc(),
977                                                     newOperandType, operand);
978         }
979         // Generate collapse operation.
980         rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
981             reshapeOp, resultType, operand, *reassociationMap);
982       } else {
983         // Generate expand operation.
984         rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
985             reshapeOp, resultType, operand, *reassociationMap);
986       }
987       return success();
988     }
989 
990     Value collapsedOp = operand;
991     Location loc = reshapeOp.getLoc();
992     auto getIdentityExprs = [&rewriter](int64_t n) {
993       SmallVector<AffineExpr, 4> exprs;
994       for (int i = 0; i < n; ++i) exprs.push_back(rewriter.getAffineDimExpr(i));
995       return exprs;
996     };
997     // Otherwise, we need to first reduce all source dimensions into one and
998     // then expand to the destination dimensions. If there is only a single
999     // source dimension, the reduce step can be skipped. TensorCollapseShape
1000     // expects a different rank of operand and result.
1001     if (operandType.getRank() != 1) {
1002       SmallVector<ReassociationExprs, 4> collapsingMap = {
1003           // Use operand_type here because we need to collapse all operands
1004           // dimensions.
1005           getIdentityExprs(operandType.getRank())};
1006 
1007       collapsedOp =
1008           rewriter.create<tensor::CollapseShapeOp>(loc, operand, collapsingMap);
1009     }
1010     // Cast to a known static type if the input has dynamic dimensions.
1011     int64_t totalElems = resultType.getNumElements();
1012     auto collapsedType = RankedTensorType::get({totalElems}, elemType);
1013     collapsedOp =
1014         rewriter.create<tensor::CastOp>(loc, collapsedType, collapsedOp);
1015     if (resultType.getRank() == 1) {
1016       rewriter.replaceOp(reshapeOp, collapsedOp);
1017     } else {
1018       SmallVector<ReassociationExprs, 4> expandingMap = {
1019           // Use resultType here because we need to expand to all result
1020           // dimensions.
1021           getIdentityExprs(resultType.getRank())};
1022       rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1023           reshapeOp, resultType, collapsedOp, expandingMap);
1024     }
1025     return success();
1026   }
1027 };
1028 
1029 template <typename OpTy>
1030 class IotaConverter : public OpConversionPattern<OpTy> {
1031  public:
1032   using OpConversionPattern<OpTy>::OpConversionPattern;
1033 
matchAndRewrite(OpTy iotaOp,typename OpTy::Adaptor adaptor,ConversionPatternRewriter & rewriter) const1034   LogicalResult matchAndRewrite(
1035       OpTy iotaOp, typename OpTy::Adaptor adaptor,
1036       ConversionPatternRewriter& rewriter) const final {
1037     ShapedType resultShapedType = getHloOpResultType(iotaOp);
1038     if (!resultShapedType) return failure();
1039     resultShapedType = this->typeConverter->convertType(resultShapedType)
1040                            .template dyn_cast<ShapedType>();
1041 
1042     Type resultElementType = resultShapedType.getElementType();
1043 
1044     // Construct the indexing maps needed for linalg.generic ops.
1045     unsigned nloops = resultShapedType.getRank();
1046 
1047     Location loc = iotaOp.getLoc();
1048     auto linalgOp = rewriter.create<linalg::GenericOp>(
1049         loc,
1050         /*resultTensorTypes=*/
1051         ArrayRef<Type>{resultShapedType},
1052         /*inputs=*/ValueRange{},
1053         /*outputBuffers=*/
1054 
1055         ValueRange{getInitTensorFor(rewriter, loc, resultShapedType, iotaOp,
1056                                     adaptor.getOperands())},
1057         llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
1058         getNParallelLoopsAttrs(nloops),
1059         [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange /*args*/) {
1060           Value indexOp = nestedBuilder.create<linalg::IndexOp>(
1061               nestedLoc, iotaOp.iota_dimension());
1062           Type unwrappedResultElementType = resultElementType;
1063           if (auto complexType =
1064                   unwrappedResultElementType.dyn_cast<ComplexType>())
1065             unwrappedResultElementType = complexType.getElementType();
1066           Value castOp = nestedBuilder.create<arith::IndexCastOp>(
1067               nestedLoc,
1068               nestedBuilder.getIntegerType(
1069                   unwrappedResultElementType.getIntOrFloatBitWidth()),
1070               indexOp);
1071           castOp = mhlo::MhloOpToStdScalarOp::mapOpOfType<mhlo::ConvertOp>(
1072               nestedLoc, resultElementType, castOp.getType(), castOp,
1073               &nestedBuilder);
1074           nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
1075         },
1076         pruneAttributeList(iotaOp));
1077     rewriter.replaceOp(iotaOp, linalgOp.result_tensors());
1078     return success();
1079   }
1080 };
1081 
1082 /// Converts mhlo.concatenate operation to a linalg.generic op.
1083 struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
1084   using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;
1085 
matchAndRewritemlir::mhlo::__anon071fe3ef0111::ConcatenateConverter1086   LogicalResult matchAndRewrite(
1087       mhlo::ConcatenateOp op, OpAdaptor adaptor,
1088       ConversionPatternRewriter& rewriter) const override {
1089     // Shortcut the one-operand case, simplifies code below.
1090     if (adaptor.getOperands().size() == 1) {
1091       rewriter.replaceOp(op, adaptor.getOperands()[0]);
1092       return success();
1093     }
1094 
1095     auto resultType = this->typeConverter->convertType(op.getResult().getType())
1096                           .dyn_cast<RankedTensorType>();
1097     if (!resultType) return failure();
1098 
1099     uint64_t dim = op.dimension();
1100     Location loc = op.getLoc();
1101     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1102 
1103     // Allocate the output tensor with init_tensor.
1104     Value result =
1105         getInitTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
1106 
1107     // Generate a generic op to gather the elements of the concatenate. This is
1108     // awkward standalone but allows fusion with other generic ops.
1109     int64_t nloops = resultType.getRank();
1110     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1111         op,
1112         /*resultTensorTypes=*/resultType,
1113         /*inputs=*/ValueRange{}, /*outputBuffers=*/result,
1114         llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
1115         getNParallelLoopsAttrs(nloops),
1116         [&](OpBuilder& nestedBuilder, Location loc, ValueRange) {
1117           OpBuilder b = nestedBuilder;
1118           Value concatDimSize = zero;
1119           Value result;
1120 
1121           SmallVector<Value, 4> extractIndices;
1122           extractIndices.reserve(nloops);
1123           for (int64_t i = 0; i < nloops; i++) {
1124             extractIndices.push_back(b.create<linalg::IndexOp>(loc, i));
1125           }
1126 
1127           Value indexOp = b.create<linalg::IndexOp>(loc, dim);
1128           for (auto& it : llvm::enumerate(adaptor.getOperands())) {
1129             Value arg = it.value();
1130             Value newConcatDimSize;
1131             scf::IfOp ifOp;
1132             if (it.index() != (adaptor.getOperands().size() - 1)) {
1133               // Calculate how far along we have iterated along the concatenate
1134               // dimension. That way we can tell which input to select.
1135               newConcatDimSize = b.create<arith::AddIOp>(
1136                   loc, concatDimSize, b.create<tensor::DimOp>(loc, arg, dim));
1137               Value cmp = b.create<arith::CmpIOp>(loc, rewriter.getI1Type(),
1138                                                   arith::CmpIPredicate::ult,
1139                                                   indexOp, newConcatDimSize);
1140               ifOp = b.create<scf::IfOp>(loc, resultType.getElementType(), cmp,
1141                                          true);
1142               if (result) {
1143                 b.create<scf::YieldOp>(loc, ifOp->getResults()[0]);
1144               } else {
1145                 result = ifOp->getResults()[0];
1146               }
1147 
1148               b = ifOp.getThenBodyBuilder(b.getListener());
1149             }
1150 
1151             // Now adjust the index for the concatenated dimension to fit into
1152             // the selected tensor and do an extract at that position.
1153             extractIndices[dim] =
1154                 b.create<arith::SubIOp>(loc, indexOp, concatDimSize);
1155             Value extract =
1156                 b.create<tensor::ExtractOp>(loc, arg, extractIndices);
1157             b.create<scf::YieldOp>(loc, extract);
1158 
1159             if (ifOp) {
1160               b = ifOp.getElseBodyBuilder(b.getListener());
1161               concatDimSize = newConcatDimSize;
1162             }
1163           }
1164           nestedBuilder.create<linalg::YieldOp>(loc, result);
1165         },
1166         pruneAttributeList(op));
1167     return success();
1168   }
1169 };
1170 
1171 class ConstConverterTensor : public OpConversionPattern<mhlo::ConstantOp> {
1172  public:
1173   using OpConversionPattern::OpConversionPattern;
1174 
matchAndRewrite(mhlo::ConstantOp constOp,OpAdaptor,ConversionPatternRewriter & rewriter) const1175   LogicalResult matchAndRewrite(
1176       mhlo::ConstantOp constOp, OpAdaptor /*adaptor*/,
1177       ConversionPatternRewriter& rewriter) const final {
1178     auto valueAttr = constOp.value().cast<DenseElementsAttr>();
1179     auto type =
1180         typeConverter->convertType(constOp.getType()).cast<ShapedType>();
1181     if (type != constOp.getType()) {
1182       // Signedness conversion.
1183       valueAttr = valueAttr.mapValues(type.getElementType(),
1184                                       [](const APInt& i) { return i; });
1185     }
1186     rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, type, valueAttr);
1187     return success();
1188   }
1189 };
1190 
1191 // TODO(b/156787842): Support the lowering for dynamic shapes.
1192 class ReverseConverter
1193     : public DataMovementOpConverter<ReverseConverter, mhlo::ReverseOp> {
1194  public:
1195   using DataMovementOpConverter<ReverseConverter,
1196                                 mhlo::ReverseOp>::DataMovementOpConverter;
getIndexingMaps(mhlo::ReverseOp op,Builder * b)1197   static SmallVector<AffineMap, 2> getIndexingMaps(mhlo::ReverseOp op,
1198                                                    Builder* b) {
1199     auto resultType = getHloOpResultType(op).cast<ShapedType>();
1200     auto nloops = resultType.getRank();
1201     SmallVector<AffineExpr, 2> inputExprs;
1202     inputExprs.reserve(nloops);
1203     for (int i = 0; i < nloops; ++i)
1204       inputExprs.push_back(b->getAffineDimExpr(i));
1205     for (auto dim : op.dimensions()) {
1206       int i = dim.getZExtValue();
1207       if (resultType.isDynamicDim(i)) return {};
1208       int n = resultType.getShape()[i];
1209       inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
1210     }
1211     return {
1212         AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
1213         b->getMultiDimIdentityMap(nloops)};
1214   }
1215 };
1216 
1217 class SliceConverter : public OpConversionPattern<mhlo::SliceOp> {
1218  public:
1219   using OpConversionPattern::OpConversionPattern;
1220 
matchAndRewrite(mhlo::SliceOp sliceOp,typename mhlo::SliceOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const1221   LogicalResult matchAndRewrite(
1222       mhlo::SliceOp sliceOp, typename mhlo::SliceOp::Adaptor adaptor,
1223       ConversionPatternRewriter& rewriter) const final {
1224     auto argType = adaptor.getOperands()[0].getType().dyn_cast<ShapedType>();
1225     if (!argType || !argType.hasRank()) {
1226       return rewriter.notifyMatchFailure(sliceOp, "expects known-rank args");
1227     }
1228 
1229     SmallVector<OpFoldResult, 3> offsets, sizes, strides;
1230     for (int i = 0, e = argType.getRank(); i < e; ++i) {
1231       auto start = sliceOp.start_indices().getValues<int64_t>()[i];
1232       auto limit = sliceOp.limit_indices().getValues<int64_t>()[i];
1233       auto stride = sliceOp.strides().getValues<int64_t>()[i];
1234       offsets.push_back(rewriter.getI64IntegerAttr(start));
1235       // Say that there are k elements in total, we have condition:
1236       //   start + (k - 1) * strides <= limit - 1
1237       // ->
1238       //   k <= (limit - 1 - start) / strides + 1
1239       sizes.push_back(
1240           rewriter.getI64IntegerAttr((limit - 1 - start) / stride + 1));
1241       strides.push_back(rewriter.getI64IntegerAttr(stride));
1242     }
1243     rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
1244         sliceOp, adaptor.getOperands()[0], offsets, sizes, strides);
1245     return success();
1246   }
1247 };
1248 
1249 class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> {
1250  public:
1251   using OpConversionPattern<mhlo::DynamicSliceOp>::OpConversionPattern;
1252 
matchAndRewrite(mhlo::DynamicSliceOp dynamicSliceOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1253   LogicalResult matchAndRewrite(
1254       mhlo::DynamicSliceOp dynamicSliceOp, OpAdaptor adaptor,
1255       ConversionPatternRewriter& rewriter) const final {
1256     auto loc = dynamicSliceOp.getLoc();
1257     auto argType = adaptor.operand().getType().dyn_cast<ShapedType>();
1258     if (!argType || !argType.hasRank()) {
1259       return rewriter.notifyMatchFailure(dynamicSliceOp,
1260                                          "require known-rank args");
1261     }
1262 
1263     SmallVector<OpFoldResult, 3> startIndices, sizes;
1264     for (auto& en : llvm::enumerate(
1265              llvm::zip(adaptor.start_indices(),
1266                        dynamicSliceOp.slice_sizes().getValues<int64_t>()))) {
1267       int64_t size = std::get<1>(en.value());
1268       sizes.push_back(rewriter.getI64IntegerAttr(size));
1269 
1270       // By mhlo.DynamicSlice definition:
1271       //   `start_indices[i] = clamp(start_indices[i],
1272       //       0, operand.dimension_size[i] - size_indices[i])`
1273       Value startIndex =
1274           rewriter.create<tensor::ExtractOp>(loc, std::get<0>(en.value()));
1275       startIndex = rewriter.createOrFold<arith::IndexCastOp>(
1276           loc, rewriter.getIndexType(), startIndex);
1277 
1278       Value mn = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1279 
1280       Value mx = rewriter.createOrFold<tensor::DimOp>(loc, adaptor.operand(),
1281                                                       en.index());
1282       mx = rewriter.createOrFold<arith::SubIOp>(
1283           loc, mx, rewriter.create<arith::ConstantIndexOp>(loc, size));
1284 
1285       startIndex = rewriter.create<arith::MaxSIOp>(loc, startIndex, mn);
1286       startIndex = rewriter.create<arith::MinSIOp>(loc, startIndex, mx);
1287 
1288       startIndices.push_back(startIndex);
1289     }
1290 
1291     int64_t rank = argType.getRank();
1292     SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));
1293 
1294     auto resultType = this->typeConverter->convertType(dynamicSliceOp.getType())
1295                           .cast<RankedTensorType>();
1296 
1297     rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
1298         dynamicSliceOp, resultType, adaptor.operand(), startIndices, sizes,
1299         strides);
1300     return success();
1301   }
1302 };
1303 
1304 class DynamicUpdateSliceConverter
1305     : public OpConversionPattern<mhlo::DynamicUpdateSliceOp> {
1306  public:
1307   using OpConversionPattern<mhlo::DynamicUpdateSliceOp>::OpConversionPattern;
1308 
matchAndRewrite(mhlo::DynamicUpdateSliceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1309   LogicalResult matchAndRewrite(
1310       mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor,
1311       ConversionPatternRewriter& rewriter) const final {
1312     auto loc = op.getLoc();
1313     auto operandType = adaptor.operand().getType().dyn_cast<RankedTensorType>();
1314     if (!operandType || !operandType.hasStaticShape()) {
1315       return rewriter.notifyMatchFailure(
1316           op, "require static ranked type for operand");
1317     }
1318 
1319     auto updateType = adaptor.update().getType().dyn_cast<RankedTensorType>();
1320     if (!updateType || !updateType.hasStaticShape()) {
1321       return rewriter.notifyMatchFailure(
1322           op, "require static ranked type for operand");
1323     }
1324 
1325     // We do not have to clamp sizes because the semantic of `update`
1326     // guarantees that it is always in the bounds. See
1327     // https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice
1328     SmallVector<OpFoldResult, 3> sizes;
1329     for (auto size : updateType.getShape()) {
1330       sizes.push_back(rewriter.getIndexAttr(size));
1331     }
1332 
1333     SmallVector<OpFoldResult, 3> startIndices;
1334     Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1335     for (auto& en : llvm::enumerate(adaptor.start_indices())) {
1336       // By mhlo.DynamicUpdateSlice definition:
1337       //   `start_indices[i] = clamp(start_indices[i],
1338       //       0, operand.dimension_size[i] - update.dimension_size[i])`
1339       Value startIndex = rewriter.create<tensor::ExtractOp>(loc, en.value());
1340       if (!startIndex.getType().isIndex())
1341         startIndex = rewriter.create<arith::IndexCastOp>(
1342             loc, rewriter.getIndexType(), startIndex);
1343       Value ub = rewriter.create<arith::ConstantIndexOp>(
1344           loc, operandType.getDimSize(en.index()) -
1345                    updateType.getDimSize(en.index()));
1346 
1347       startIndex = rewriter.create<arith::MaxSIOp>(loc, startIndex, zero);
1348       startIndex = rewriter.create<arith::MinSIOp>(loc, startIndex, ub);
1349       startIndices.push_back(startIndex);
1350     }
1351 
1352     int64_t rank = operandType.getRank();
1353     SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));
1354     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
1355         op, adaptor.update(), adaptor.operand(), startIndices, sizes, strides);
1356     return success();
1357   }
1358 };
1359 
1360 enum class DotOperationType {
1361   kVectorDot = 0,
1362   kMatrixVector,
1363   kVectorMatrix,
1364   kMatrixMatrix,
1365   kUnsupported
1366 };
1367 
getDotOperationType(mhlo::DotOp dotOp)1368 DotOperationType getDotOperationType(mhlo::DotOp dotOp) {
1369   ArrayRef<int64_t> lhsShape =
1370       dotOp.lhs().getType().cast<ShapedType>().getShape();
1371   ArrayRef<int64_t> rhsShape =
1372       dotOp.rhs().getType().cast<ShapedType>().getShape();
1373   auto shapeMatches = [](int64_t a, int64_t b) {
1374     return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize ||
1375            a == b;
1376   };
1377   if (lhsShape.size() == 1 && rhsShape.size() == 1 &&
1378       shapeMatches(lhsShape[0], rhsShape[0])) {
1379     return DotOperationType::kVectorDot;
1380   }
1381   if (lhsShape.size() == 2 && rhsShape.size() == 1 &&
1382       shapeMatches(lhsShape[1], rhsShape[0])) {
1383     return DotOperationType::kMatrixVector;
1384   }
1385   if (lhsShape.size() == 1 && rhsShape.size() == 2 &&
1386       shapeMatches(lhsShape[0], rhsShape[0])) {
1387     return DotOperationType::kVectorMatrix;
1388   }
1389   if (lhsShape.size() == 2 && rhsShape.size() == 2 &&
1390       shapeMatches(lhsShape[1], rhsShape[0])) {
1391     return DotOperationType::kMatrixMatrix;
1392   }
1393   return DotOperationType::kUnsupported;
1394 }
1395 
getDotOpInitTensorDynSizes(OpBuilder & b,Location loc,Value lhs,Value rhs,DotOperationType type)1396 SmallVector<Value, 2> getDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
1397                                                  Value lhs, Value rhs,
1398                                                  DotOperationType type) {
1399   SmallVector<Value, 2> dynShape;
1400   switch (type) {
1401     case DotOperationType::kMatrixMatrix: {
1402       if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
1403         dynShape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
1404       if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
1405         dynShape.push_back(b.create<tensor::DimOp>(loc, rhs, 1));
1406       break;
1407     }
1408     case DotOperationType::kMatrixVector: {
1409       if (lhs.getType().cast<ShapedType>().isDynamicDim(0))
1410         dynShape.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
1411       break;
1412     }
1413     case DotOperationType::kVectorMatrix: {
1414       if (rhs.getType().cast<ShapedType>().isDynamicDim(1))
1415         dynShape.push_back(b.create<tensor::DimOp>(loc, rhs, 1));
1416       break;
1417     }
1418     case DotOperationType::kVectorDot:
1419     case DotOperationType::kUnsupported:
1420       break;
1421   }
1422   return dynShape;
1423 }
1424 
1425 template <DotOperationType op_type, typename LinalgOp>
1426 class DotOpConversion : public OpConversionPattern<mhlo::DotOp> {
1427  public:
1428   using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;
matchAndRewrite(mhlo::DotOp op,mhlo::DotOp::Adaptor adaptor,ConversionPatternRewriter & rewriter) const1429   LogicalResult matchAndRewrite(
1430       mhlo::DotOp op, mhlo::DotOp::Adaptor adaptor,
1431       ConversionPatternRewriter& rewriter) const final {
1432     if (!verifyHloOpBufferOrTensorSemantics(op)) {
1433       return failure();
1434     }
1435     if (getDotOperationType(op) != op_type) return failure();
1436 
1437     Location loc = op.getLoc();
1438     // Convert unsigned to signed. This works because signed and unsigned
1439     // integer matmul is the same operation in two's complement.
1440     auto outputType =
1441         typeConverter->convertType(op.getType()).cast<ShapedType>();
1442     SmallVector<Value, 2> dynShape = getDotOpInitTensorDynSizes(
1443         rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type);
1444     auto initTensor = getInitTensor(rewriter, loc, outputType, dynShape);
1445     Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
1446     rewriter.replaceOpWithNewOp<LinalgOp>(
1447         op, TypeRange{outputType}, ValueRange{adaptor.lhs(), adaptor.rhs()},
1448         ValueRange{zeroTensor}, pruneAttributeList(op));
1449     return success();
1450   }
1451 };
1452 
1453 class DotGeneralBatchMatMulOpConversion
1454     : public OpConversionPattern<mhlo::DotGeneralOp> {
1455  public:
1456   using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
matchAndRewrite(mhlo::DotGeneralOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1457   LogicalResult matchAndRewrite(
1458       mhlo::DotGeneralOp op, OpAdaptor adaptor,
1459       ConversionPatternRewriter& rewriter) const final {
1460     if (!verifyHloOpBufferOrTensorSemantics(op)) {
1461       return failure();
1462     }
1463     if (op.getType().cast<RankedTensorType>().getRank() != 3) {
1464       return rewriter.notifyMatchFailure(op, "expected a batch matmul");
1465     }
1466 
1467     mhlo::DotDimensionNumbersAttr dimNumbers = op.dot_dimension_numbers();
1468     auto lhsBatchingDims = dimNumbers.getLhsBatchingDimensions();
1469     auto rhsBatchingDims = dimNumbers.getRhsBatchingDimensions();
1470     auto lhsContractingDims = dimNumbers.getLhsContractingDimensions();
1471     auto rhsContractingDims = dimNumbers.getRhsContractingDimensions();
1472     if (lhsBatchingDims.size() != 1 || lhsBatchingDims[0] != 0) {
1473       return rewriter.notifyMatchFailure(
1474           op, "expected lhs batching dimensions exactly {0}");
1475     }
1476     if (rhsBatchingDims.size() != 1 || rhsBatchingDims[0] != 0) {
1477       return rewriter.notifyMatchFailure(
1478           op, "expected rhs batching dimensions exactly {0}");
1479     }
1480     if (lhsContractingDims.size() != 1 || lhsContractingDims[0] != 2) {
1481       return rewriter.notifyMatchFailure(
1482           op, "expected lhs contracting dimensions exactly {2}");
1483     }
1484     if (rhsContractingDims.size() != 1 || rhsContractingDims[0] != 1) {
1485       return rewriter.notifyMatchFailure(
1486           op, "expected rhs contracting dimensions exactly {1}");
1487     }
1488 
1489     Location loc = op.getLoc();
1490     // Convert unsigned to signed. This works because signed and unsigned
1491     // integer matmul is the same operation in two's complement.
1492     auto outputType =
1493         typeConverter->convertType(op.getType()).cast<ShapedType>();
1494     auto initTensor =
1495         getInitTensorFor(rewriter, loc, outputType, op, adaptor.getOperands());
1496     Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
1497     Operation* linalgOp = rewriter.create<linalg::BatchMatmulOp>(
1498         loc, /*resultTensorTypes=*/TypeRange{outputType},
1499         /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
1500         /*outputBuffers=*/ValueRange{zeroTensor}, pruneAttributeList(op));
1501 
1502     rewriter.replaceOp(op, linalgOp->getResults());
1503     return success();
1504   }
1505 };
1506 
1507 class MapOpConverter : public OpConversionPattern<mhlo::MapOp> {
1508  public:
1509   using OpConversionPattern::OpConversionPattern;
matchAndRewrite(mhlo::MapOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1510   LogicalResult matchAndRewrite(
1511       mhlo::MapOp op, OpAdaptor adaptor,
1512       ConversionPatternRewriter& rewriter) const final {
1513     if (!verifyHloOpBufferOrTensorSemantics(op)) return failure();
1514 
1515     auto resultType =
1516         typeConverter->convertType(op.getType()).cast<ShapedType>();
1517     assert(op.dimensions().size() == resultType.getRank() &&
1518            "Expected a pointwise map");
1519 
1520     Location loc = op.getLoc();
1521     Value output =
1522         getInitTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
1523     SmallVector<AffineMap> indexingMaps(
1524         op.getNumOperands() + 1,
1525         rewriter.getMultiDimIdentityMap(resultType.getRank()));
1526 
1527     auto linalgOp = rewriter.create<linalg::GenericOp>(
1528         loc, resultType, adaptor.getOperands(), output, indexingMaps,
1529         getNParallelLoopsAttrs(resultType.getRank()),
1530         /*bodyBuild=*/nullptr, pruneAttributeList(op));
1531 
1532     // Convert the signature of the body. We scalarize the operands and add a
1533     // scalar operand representing the output tensor.
1534     Region& region = linalgOp.region();
1535     rewriter.inlineRegionBefore(op.computation(), region, region.end());
1536     TypeConverter::SignatureConversion signatureConverter(op.getNumOperands() +
1537                                                           1);
1538 
1539     for (const auto& it : llvm::enumerate(op.getOperation()->getOperands())) {
1540       signatureConverter.addInputs(
1541           it.index(),
1542           typeConverter->convertType(
1543               it.value().getType().cast<ShapedType>().getElementType()));
1544     }
1545     signatureConverter.addInputs(resultType.getElementType());
1546 
1547     rewriter.applySignatureConversion(&region, signatureConverter,
1548                                       getTypeConverter());
1549     rewriter.replaceOp(op, linalgOp.getResults());
1550     return success();
1551   }
1552 };
1553 
isInBodyOfLinalgOps(Operation * op)1554 bool isInBodyOfLinalgOps(Operation* op) {
1555   auto* parentOp = op->getParentRegion()->getParentOp();
1556   return parentOp->getDialect() ==
1557          parentOp->getContext()->getLoadedDialect<linalg::LinalgDialect>();
1558 }
1559 
getReduceOpInitTensorDynSizes(OpBuilder & b,Location loc,Value arg,ShapedType resultType,ArrayRef<int64_t> reductionDims)1560 SmallVector<Value, 8> getReduceOpInitTensorDynSizes(
1561     OpBuilder& b, Location loc, Value arg, ShapedType resultType,
1562     ArrayRef<int64_t> reductionDims) {
1563   llvm::SmallSetVector<int, 4> s;
1564   for (auto dim : reductionDims) s.insert(dim);
1565 
1566   SmallVector<unsigned, 4> parallelDims;
1567   SmallVector<Value, 8> dynShape;
1568   int rank = arg.getType().cast<RankedTensorType>().getRank();
1569   for (int i = 0, j = 0; i < rank; ++i) {
1570     if (s.count(i)) continue;
1571     if (!resultType.isDynamicDim(j++)) continue;
1572     dynShape.push_back(b.create<tensor::DimOp>(loc, arg, i));
1573   }
1574 
1575   return dynShape;
1576 }
1577 
1578 class ReduceRegionReturnOpConversion
1579     : public OpConversionPattern<mhlo::ReturnOp> {
1580  public:
1581   using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern;
matchAndRewrite(mhlo::ReturnOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1582   LogicalResult matchAndRewrite(
1583       mhlo::ReturnOp op, OpAdaptor adaptor,
1584       ConversionPatternRewriter& rewriter) const final {
1585     if (!isInBodyOfLinalgOps(op)) {
1586       return failure();
1587     }
1588     SmallVector<Value, 4> operands(adaptor.getOperands());
1589     for (size_t i = 0; i < operands.size(); ++i) {
1590       if (operands[i].getType().isa<ShapedType>()) {
1591         auto loc = operands[i].getLoc();
1592         operands[i] = rewriter.create<tensor::ExtractOp>(loc, operands[i]);
1593       }
1594     }
1595     rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, operands);
1596     return success();
1597   }
1598 };
1599 
1600 class ReduceConversion : public OpConversionPattern<mhlo::ReduceOp> {
1601  public:
1602   using OpConversionPattern<mhlo::ReduceOp>::OpConversionPattern;
matchAndRewrite(mhlo::ReduceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1603   LogicalResult matchAndRewrite(
1604       mhlo::ReduceOp op, OpAdaptor adaptor,
1605       ConversionPatternRewriter& rewriter) const final {
1606     Location loc = op.getLoc();
1607 
1608     int numOperands = static_cast<int>(adaptor.operands().size());
1609 
1610     if (llvm::any_of(adaptor.operands(), [](Value v) {
1611           return !v.getType().cast<ShapedType>().getRank();
1612         })) {
1613       return rewriter.notifyMatchFailure(op, "expects known-rank args");
1614     }
1615     auto srcRank = adaptor.operands()[0].getType().cast<ShapedType>().getRank();
1616 
1617     SmallVector<int64_t, 4> reductionDims = extract1DVector(op.dimensions());
1618 
1619     SmallVector<Type> resultTypes;
1620     if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes)))
1621       return failure();
1622 
1623     SmallVector<Value> operands, outputs;
1624     SmallVector<AffineMap, 3> indexingMaps;
1625     for (auto values :
1626          llvm::zip(adaptor.operands(), adaptor.init_values(), resultTypes)) {
1627       // Check if init_value is constant. If so, inline the value into the
1628       // region.
1629       Value operand = std::get<0>(values);
1630       Value initValue = std::get<1>(values);
1631       Type resultType = std::get<2>(values);
1632       initValue = rewriter.createOrFold<tensor::ExtractOp>(loc, initValue);
1633 
1634       operands.push_back(operand);
1635       SmallVector<Value, 8> dynShape = getReduceOpInitTensorDynSizes(
1636           rewriter, loc, operand, resultType, reductionDims);
1637       auto initTensor = getInitTensor(rewriter, loc, resultType, dynShape);
1638       Value filledTensor =
1639           rewriter.create<linalg::FillOp>(loc, initValue, initTensor).result();
1640       outputs.push_back(filledTensor);
1641     }
1642 
1643     // Prepare indexing maps for linalg generic op. The elements are for src
1644     // and dst. Transpose `src` to make the reduction loops be the innermost,
1645     // because it's easier to fully utilize processors.
1646     indexingMaps.append(
1647         numOperands, getTransposeMapForReduction(rewriter.getContext(),
1648                                                  (int)srcRank, reductionDims));
1649 
1650     // The indexing map of `dst` should drop the reduction loops. Since the
1651     // reduction loops now are all in the innermost, drops
1652     // `reduction_dims.size()` dimensions. We don't need an inverse
1653     // permutation here because they are the same.
1654     SmallVector<AffineExpr, 4> exprs;
1655     for (int i = 0, e = srcRank - reductionDims.size(); i < e; ++i)
1656       exprs.push_back(rewriter.getAffineDimExpr(i));
1657     indexingMaps.append(numOperands,
1658                         AffineMap::get(srcRank, /*symbolCount=*/0, exprs,
1659                                        rewriter.getContext()));
1660 
1661     auto linalgOp = rewriter.create<linalg::GenericOp>(
1662         loc, /*resultTensorTypes=*/resultTypes, operands,
1663         /*outputBuffers=*/ValueRange{outputs}, indexingMaps,
1664         getParallelAndReductionIterators(srcRank, reductionDims.size()),
1665         /*bodyBuild=*/nullptr, pruneAttributeList(op));
1666 
1667     // Convert the signature of the body. The reduce op region apply function
1668     // has a signature (lhs, rhs) -> output, all of the same tensor type t.
1669     // This is converted to a function with the same signature but with
1670     // element types. E.g., "(tensor<f32>, tensor<f32>) -> tensor<f32>" will
1671     // be converted to "(f32, f32, f32)".
1672     Region& region = linalgOp.region();
1673     rewriter.inlineRegionBefore(op.body(), region, region.end());
1674     TypeConverter::SignatureConversion signatureConverter(numOperands * 2);
1675 
1676     // Reduce requires that the seed be used as a LHS operand inside the
1677     // region, and the seed is encoded in linalg in the intial out value, so
1678     // modify the signature of the block and the value mappings, so the output
1679     // args will correlate with the LHS and the inputs correlate with the RHS.
1680     for (const auto& [idx, val] : llvm::enumerate(op.init_values())) {
1681       signatureConverter.addInputs(
1682           idx + numOperands,
1683           typeConverter->convertType(
1684               val.getType().cast<ShapedType>().getElementType()));
1685     }
1686     for (const auto& [idx, val] : llvm::enumerate(op.operands())) {
1687       signatureConverter.addInputs(
1688           idx, typeConverter->convertType(
1689                    val.getType().cast<ShapedType>().getElementType()));
1690     }
1691 
1692     rewriter.applySignatureConversion(&region, signatureConverter,
1693                                       getTypeConverter());
1694     rewriter.replaceOp(op, linalgOp.getResults());
1695     return success();
1696   }
1697 };
1698 
1699 // Decomposes a pad with negative edge padding into a pad without negative edge
1700 // padding and a tensor.extract_slice.
1701 struct PadOpNegativePaddingConversion
1702     : public OpConversionPattern<mhlo::PadOp> {
1703   using OpConversionPattern::OpConversionPattern;
1704 
matchAndRewritemlir::mhlo::__anon071fe3ef0111::PadOpNegativePaddingConversion1705   LogicalResult matchAndRewrite(
1706       mhlo::PadOp op, OpAdaptor adaptor,
1707       ConversionPatternRewriter& rewriter) const override {
1708     SmallVector<int64_t, 4> padLow;
1709     SmallVector<int64_t, 4> padHigh;
1710     SmallVector<OpFoldResult, 4> sliceStarts;
1711 
1712     bool hasNegativePadding = false;
1713     for (int64_t low : op.edge_padding_low().getValues<int64_t>()) {
1714       if (low >= 0) {
1715         padLow.push_back(low);
1716         sliceStarts.push_back(rewriter.getIndexAttr(0));
1717       } else {
1718         padLow.push_back(0);
1719         sliceStarts.push_back(rewriter.getIndexAttr(-low));
1720         hasNegativePadding = true;
1721       }
1722     }
1723 
1724     for (int64_t high : op.edge_padding_high().getValues<int64_t>()) {
1725       if (high >= 0) {
1726         padHigh.push_back(high);
1727       } else {
1728         padHigh.push_back(-high);
1729         hasNegativePadding = true;
1730       }
1731     }
1732 
1733     // If there's no negative edge padding we're done.
1734     if (!hasNegativePadding) return failure();
1735 
1736     // Create a new pad op with the positive values.
1737     Value pad = rewriter.create<mhlo::PadOp>(
1738         op.getLoc(), adaptor.operand(), adaptor.padding_value(),
1739         rewriter.getI64TensorAttr(padLow), rewriter.getI64TensorAttr(padHigh),
1740         op.interior_padding());
1741 
1742     // Then slice according to the negative edge padding. Static shapes only for
1743     // now.
1744     if (!op.getType().hasStaticShape()) return failure();
1745     SmallVector<OpFoldResult, 4> sizes(llvm::map_range(
1746         op.getType().getShape(),
1747         [&](int64_t dim) { return rewriter.getIndexAttr(dim); }));
1748     SmallVector<OpFoldResult, 4> strides(sliceStarts.size(),
1749                                          rewriter.getIndexAttr(1));
1750     rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(op, pad, sliceStarts,
1751                                                         sizes, strides);
1752     return success();
1753   }
1754 };
1755 
1756 /// Converts mhlo.pad operation to tensor.pad or tensor.insert_slice.
1757 struct PadOpConversion : public OpConversionPattern<mhlo::PadOp> {
1758   using OpConversionPattern<mhlo::PadOp>::OpConversionPattern;
1759 
matchAndRewritemlir::mhlo::__anon071fe3ef0111::PadOpConversion1760   LogicalResult matchAndRewrite(
1761       mhlo::PadOp op, OpAdaptor adaptor,
1762       ConversionPatternRewriter& rewriter) const override {
1763     auto loc = op.getLoc();
1764     auto resultType = typeConverter->convertType(op.getResult().getType());
1765 
1766     // Negative edge padding is decomposed separately.
1767     auto isNegative = [](const APInt& intVal) { return intVal.isNegative(); };
1768     if (llvm::any_of(op.edge_padding_low().getValues<APInt>(), isNegative) ||
1769         llvm::any_of(op.edge_padding_high().getValues<APInt>(), isNegative))
1770       return failure();
1771 
1772     Value paddingVal =
1773         rewriter.createOrFold<tensor::ExtractOp>(loc, adaptor.padding_value());
1774 
1775     SmallVector<OpFoldResult, 4> low(
1776         op.edge_padding_low().getValues<IntegerAttr>());
1777 
1778     // If there is no interior padding lower to tensor.pad directly.
1779     if (llvm::all_of(op.interior_padding().getValues<APInt>(),
1780                      [](const APInt& intVal) { return intVal.isZero(); })) {
1781       SmallVector<OpFoldResult, 4> high(
1782           op.edge_padding_high().getValues<IntegerAttr>());
1783       auto padTensorOp = tensor::createPadScalarOp(
1784           resultType, adaptor.operand(), paddingVal, low, high,
1785           /*nofold=*/false, loc, rewriter);
1786       rewriter.replaceOp(op, padTensorOp.getResult());
1787       return success();
1788     }
1789 
1790     // We have interior padding, which can be lowered to tensor.insert_slice.
1791     // Start by filling a result-sized tensor with the pad value.
1792     auto initTensor =
1793         getInitTensorFor(rewriter, loc, resultType, op, adaptor.getOperands());
1794     auto fill =
1795         rewriter.create<linalg::FillOp>(loc, paddingVal, initTensor).result();
1796 
1797     // Get sizes of the original operand.
1798     auto operandType = adaptor.operand().getType().cast<ShapedType>();
1799     auto sizes = llvm::to_vector<4>(llvm::map_range(
1800         llvm::seq<int64_t>(0, operandType.getRank()),
1801         [&](int64_t dim) -> OpFoldResult {
1802           if (!operandType.isDynamicDim(dim))
1803             return rewriter.getIndexAttr(operandType.getDimSize(dim));
1804           return rewriter.create<tensor::DimOp>(loc, adaptor.operand(), dim)
1805               .getResult();
1806         }));
1807     // Map interior padding to strides.
1808     auto strides = llvm::to_vector<4>(
1809         llvm::map_range(op.interior_padding().getValues<IntegerAttr>(),
1810                         [&](IntegerAttr stride) -> OpFoldResult {
1811                           return rewriter.getIntegerAttr(stride.getType(),
1812                                                          stride.getValue() + 1);
1813                         }));
1814 
1815     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
1816         op, adaptor.operand(), fill, low, sizes, strides);
1817     return success();
1818   }
1819 };
1820 
1821 // Apply dilation and padding to the input of a convolution.
applyConvolutionPadding(Location loc,Value input,DenseIntElementsAttr padding,DenseIntElementsAttr lhsDilation,llvm::ArrayRef<int64_t> dimMappings,OpBuilder & rewriter)1822 Value applyConvolutionPadding(Location loc, Value input,
1823                               DenseIntElementsAttr padding,
1824                               DenseIntElementsAttr lhsDilation,
1825                               llvm::ArrayRef<int64_t> dimMappings,
1826                               OpBuilder& rewriter) {
1827   if ((!padding || isSplatValue(padding, 0)) &&
1828       (!lhsDilation || isSplatValue(lhsDilation, 1)))
1829     return input;
1830 
1831   auto inputType = input.getType().cast<ShapedType>();
1832   auto rank = inputType.getRank();
1833 
1834   // Translate window padding into low/high padding.
1835   SmallVector<int64_t, 8> padLow(rank, 0);
1836   SmallVector<int64_t, 8> padHigh(rank, 0);
1837   if (padding) {
1838     // The padding attribute contains two values per dimension, but excludes the
1839     // batch and feature dimensions.
1840     assert(rank * 2 == padding.size() + 4 &&
1841            "There should be 2 padding values per dimension, i.e low and high.");
1842     for (auto i : llvm::seq<int64_t>(0, padding.size() / 2)) {
1843       auto dim = dimMappings[i];
1844       padLow[dim] = padding.getValues<int64_t>()[i * 2];
1845       padHigh[dim] = padding.getValues<int64_t>()[i * 2 + 1];
1846     }
1847   }
1848 
1849   // Translate input dilation into interior padding.
1850   SmallVector<int64_t, 8> padInterior(rank, 0);
1851   if (lhsDilation) {
1852     assert(rank == lhsDilation.size() + 2);
1853     for (auto i : llvm::seq<int64_t>(0, lhsDilation.size())) {
1854       auto dim = dimMappings[i];
1855       padInterior[dim] = lhsDilation.getValues<int64_t>()[i] - 1;
1856     }
1857   }
1858 
1859   auto indexType = rewriter.getIntegerType(64);
1860   auto attrType = RankedTensorType::get({rank}, indexType);
1861   Value zero = rewriter.create<arith::ConstantOp>(
1862       loc, rewriter.getZeroAttr(
1863                RankedTensorType::get({}, inputType.getElementType())));
1864   return rewriter.create<mhlo::PadOp>(
1865       loc, input, zero, DenseIntElementsAttr::get(attrType, padLow),
1866       DenseIntElementsAttr::get(attrType, padHigh),
1867       DenseIntElementsAttr::get(attrType, padInterior));
1868 }
1869 
1870 /// Converts mhlo.conv operation to linalg named op. This only covers normal
1871 /// convolution cases. The op must have canonical dimension numbers. Depthwise
1872 /// convolution and pointwise convolution are not handled in the conversion.
1873 struct NormalConvolutionOpConversion
1874     : public OpConversionPattern<mhlo::ConvolutionOp> {
1875   using OpConversionPattern<mhlo::ConvolutionOp>::OpConversionPattern;
1876 
matchAndRewritemlir::mhlo::__anon071fe3ef0111::NormalConvolutionOpConversion1877   LogicalResult matchAndRewrite(
1878       mhlo::ConvolutionOp op, OpAdaptor adaptor,
1879       ConversionPatternRewriter& rewriter) const override {
1880     if (!hasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
1881     if (op.feature_group_count() != 1u) return failure();
1882     if (op.batch_group_count() != 1u) return failure();
1883 
1884     Location loc = op.getLoc();
1885     Value input = adaptor.lhs();
1886     Value filter = adaptor.rhs();
1887     auto resultType =
1888         typeConverter->convertType(op.getResult().getType()).cast<ShapedType>();
1889     int64_t rank = resultType.getRank();
1890 
1891     // The output shape is N spatial_dims F.
1892     SmallVector<Value, 8> dynSizes;
1893     if (resultType.isDynamicDim(0)) {
1894       dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
1895     }
1896     for (int64_t i = 1, e = rank - 1; i < e; ++i) {
1897       if (resultType.isDynamicDim(i)) {
1898         return rewriter.notifyMatchFailure(
1899             op, "expected output spatial dims to be static shapes");
1900       }
1901     }
1902     if (resultType.isDynamicDim(rank - 1)) {
1903       dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, filter, rank - 1));
1904     }
1905     Value initTensor = rewriter.create<linalg::InitTensorOp>(
1906         loc, dynSizes, resultType.getShape(), resultType.getElementType());
1907     Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
1908     linalg::LinalgOp res;
1909     Attribute strides = op.window_stridesAttr();
1910     Attribute dilations = op.rhs_dilationAttr();
1911 
1912     // Apply padding and input dilation.
1913     llvm::SmallVector<int64_t> spatialDimMapping(rank - 2);
1914     std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1);
1915     input = applyConvolutionPadding(loc, input, op.paddingAttr(),
1916                                     op.lhs_dilationAttr(), spatialDimMapping,
1917                                     rewriter);
1918 
1919     switch (rank) {
1920       case 2: {
1921         res = rewriter.create<linalg::MatmulOp>(
1922             loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
1923             pruneAttributeList(op));
1924         break;
1925       }
1926       case 3: {
1927         res = rewriter.create<linalg::Conv1DNwcWcfOp>(
1928             loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
1929             strides, dilations, pruneAttributeList(op));
1930         break;
1931       }
1932       case 4: {
1933         res = rewriter.create<linalg::Conv2DNhwcHwcfOp>(
1934             loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
1935             strides, dilations, pruneAttributeList(op));
1936         break;
1937       }
1938       case 5: {
1939         res = rewriter.create<linalg::Conv3DNdhwcDhwcfOp>(
1940             loc, resultType, ValueRange{input, filter}, ValueRange{zeroTensor},
1941             strides, dilations, pruneAttributeList(op));
1942         break;
1943       }
1944       default:
1945         return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
1946     }
1947     rewriter.replaceOp(op, res.getOperation()->getResults());
1948     return success();
1949   }
1950 };
1951 
1952 /// Handles all possible inputs for the mhlo::ConvolutionOp
1953 struct ConvolutionOpGeneralConversion
1954     : public OpConversionPattern<mhlo::ConvolutionOp> {
1955   using OpConversionPattern<mhlo::ConvolutionOp>::OpConversionPattern;
1956 
1957   /// This lowering proceeds with the following steps:
1958   /// 1. Handle padding and dilation of the input
1959   /// 2. Handle padding and dilation of the window
1960   /// 3. Handle reversal of the window
1961   /// 4. If feature_group_count != 1:
1962   ///    - Reshape the input feature dimension, kernel output feature dimension,
1963   ///      and output feature dimension.
1964   ///    - Create the AffineExpr for the new dimension
1965   ///    - Conceptually, this splits the input feature and both output feature
1966   ///      dimensions and computes sets of convolutions with these partial views
1967   ///      of the values as if they were multiple convolutions combined in a
1968   ///      batch.
1969   /// 5: If batch_group_count != 1:
1970   ///    - Reshape the input batch dimension, kernel output feature dimension,
1971   ///      and output feature dimension.
1972   ///    - Create the AffineExpr for the new dimension
1973   ///    - Conceptually, this splits the input batch and both output feature
1974   ///      dimensions and computes sets of convolutions with these partial views
1975   ///      of the values as if they were multiple convolutions combined in a
1976   ///      batch.
1977   /// 6. For all dimensions not newly created by a reshape, create the
1978   ///    appropriate parallel and reduction dimensions to create a convolution.
1979   /// 7. Create the linalg.generic that computes the multiply-add
1980   /// 8. Reshape the output to the original shape if it was reshaped by the
1981   ///    feature or group count attributes.
matchAndRewritemlir::mhlo::__anon071fe3ef0111::ConvolutionOpGeneralConversion1982   LogicalResult matchAndRewrite(
1983       mhlo::ConvolutionOp op, OpAdaptor adaptor,
1984       ConversionPatternRewriter& rewriter) const override {
1985     auto loc = op.getLoc();
1986     auto* ctx = op.getContext();
1987 
1988     auto resultType =
1989         typeConverter->convertType(op.getResult().getType()).cast<ShapedType>();
1990     auto reshapedResultShape = resultType.getShape().vec();
1991     if (!resultType.hasStaticShape()) return failure();
1992 
1993     auto dimensionNumbers = op.dimension_numbers();
1994     auto inputBatchDimension = dimensionNumbers.getInputBatchDimension();
1995     auto inputFeatureDimension = dimensionNumbers.getInputFeatureDimension();
1996     auto inputSpatialDimensions = dimensionNumbers.getInputSpatialDimensions();
1997 
1998     auto kernelInputFeatureDimension =
1999         dimensionNumbers.getKernelInputFeatureDimension();
2000     auto kernelOutputFeatureDimension =
2001         dimensionNumbers.getKernelOutputFeatureDimension();
2002     auto kernelSpatialDimensions =
2003         dimensionNumbers.getKernelSpatialDimensions();
2004 
2005     auto outputFeatureDimension = dimensionNumbers.getOutputFeatureDimension();
2006     auto outputSpatialDimensions =
2007         dimensionNumbers.getOutputSpatialDimensions();
2008 
2009     auto featureGroupCount = op.feature_group_count();
2010     auto batchGroupCount = op.batch_group_count();
2011 
2012     if (op.feature_group_count() != 1 && op.batch_group_count() != 1) {
2013       return rewriter.notifyMatchFailure(
2014           op, "only one of feature and batch group counts can be non-one");
2015     }
2016 
2017     // Decompose the convolution into an initial padding
2018     Value modifiedLhs = applyConvolutionPadding(
2019         op.getLoc(), adaptor.lhs(), adaptor.paddingAttr(),
2020         adaptor.lhs_dilationAttr(),
2021         op.dimension_numbers().getInputSpatialDimensions(), rewriter);
2022     Value modifiedRhs = applyConvolutionPadding(
2023         op.getLoc(), adaptor.rhs(), nullptr, adaptor.rhs_dilationAttr(),
2024         op.dimension_numbers().getKernelSpatialDimensions(), rewriter);
2025 
2026     // Decompose the reversal dims into its own step
2027     auto reversals = op.window_reversal();
2028     if (reversals.value()) {
2029       llvm::SmallVector<int64_t> reversedDims;
2030       for (auto& idxAndBool :
2031            llvm::enumerate(reversals.value().getValues<bool>()))
2032         if (idxAndBool.value())
2033           reversedDims.push_back(
2034               op.dimension_numbers()
2035                   .getKernelSpatialDimensions()[idxAndBool.index()]);
2036 
2037       modifiedRhs = rewriter.create<mhlo::ReverseOp>(
2038           loc, modifiedRhs,
2039           mlir::DenseIntElementsAttr::get(
2040               RankedTensorType::get(reversedDims.size(),
2041                                     rewriter.getIntegerType(64)),
2042               reversedDims));
2043     }
2044 
2045     // Non-one values for feature or batch group counts will result in reshaped
2046     // inputs and outputs. These mappings are used to keep track of the the new
2047     // index after reshaping has possibly inserted new dimensions.
2048     auto paddedLhsType = modifiedLhs.getType().cast<ShapedType>();
2049     auto paddedRhsType = modifiedRhs.getType().cast<ShapedType>();
2050     SmallVector<int64_t> lhsIndexMapping(paddedLhsType.getRank());
2051     std::iota(lhsIndexMapping.begin(), lhsIndexMapping.end(), 0);
2052     SmallVector<int64_t> rhsIndexMapping(paddedRhsType.getRank());
2053     std::iota(rhsIndexMapping.begin(), rhsIndexMapping.end(), 0);
2054     SmallVector<int64_t> resultIndexMapping(resultType.getRank());
2055     std::iota(resultIndexMapping.begin(), resultIndexMapping.end(), 0);
2056     auto updateDimMappingFromOffset =
2057         [](llvm::SmallVectorImpl<int64_t>& mapping, int64_t offset) {
2058           for (auto i = offset; i < mapping.size(); ++i) {
2059             mapping[i] += 1;
2060           }
2061         };
2062 
2063     // The rest of this code prepares the inputs and a single linalg::GenericOp
2064     // to execute the convolution. The final linalg::GenericOp will be iterated
2065     // through based on the following eventual maps.
2066     SmallVector<AffineExpr, 2> srcExprs(paddedLhsType.getRank());
2067     SmallVector<AffineExpr, 2> windowExprs(paddedRhsType.getRank());
2068     SmallVector<AffineExpr, 2> dstExprs(reshapedResultShape.size());
2069     int64_t nextDim = 0;
2070     int64_t rank = resultType.getRank();
2071 
2072     auto reshapeShapeVector = [](llvm::ArrayRef<int64_t> oldShape,
2073                                  llvm::SmallVectorImpl<int64_t>& newShape,
2074                                  int64_t reshapedDim, int64_t factor) {
2075       newShape.reserve(oldShape.size() + 1);
2076       for (int i = 0; i < oldShape.size(); ++i) {
2077         if (i == reshapedDim) {
2078           newShape.push_back(factor);
2079           newShape.push_back(oldShape[reshapedDim] / factor);
2080         } else {
2081           newShape.push_back(oldShape[i]);
2082         }
2083       }
2084     };
2085 
2086     // If batch or feature count groupings exist, represent this through
2087     // reshaping the input to have an additional dimension that these groupings
2088     // exist along, and reduce in that dimension
2089     SmallVector<StringRef, 3> iterationLoops;
2090     if (featureGroupCount != 1) {
2091       auto parallelDim = mlir::getAffineDimExpr(nextDim++, ctx);
2092       iterationLoops.push_back(getParallelIteratorTypeName());
2093       // Reshape LHS
2094       {
2095         srcExprs.insert(srcExprs.begin() + inputFeatureDimension, parallelDim);
2096         auto prevDimsRef = paddedLhsType.getShape();
2097         llvm::SmallVector<int64_t> newShape;
2098         reshapeShapeVector(prevDimsRef, newShape, inputFeatureDimension,
2099                            featureGroupCount);
2100         updateDimMappingFromOffset(lhsIndexMapping, inputFeatureDimension);
2101         modifiedLhs = rewriter.create<mhlo::ReshapeOp>(
2102             op.getLoc(),
2103             RankedTensorType::get(newShape, paddedLhsType.getElementType()),
2104             modifiedLhs);
2105       }
2106 
2107       // Reshape RHS
2108       {
2109         windowExprs.insert(windowExprs.begin() + kernelOutputFeatureDimension,
2110                            parallelDim);
2111         auto prevDimsRef = paddedRhsType.getShape();
2112         llvm::SmallVector<int64_t> newShape;
2113         reshapeShapeVector(prevDimsRef, newShape, kernelOutputFeatureDimension,
2114                            featureGroupCount);
2115         updateDimMappingFromOffset(rhsIndexMapping,
2116                                    kernelOutputFeatureDimension);
2117         modifiedRhs = rewriter.create<mhlo::ReshapeOp>(
2118             op.getLoc(),
2119             RankedTensorType::get(newShape, paddedRhsType.getElementType()),
2120             modifiedRhs);
2121       }
2122       // Prepare reshaped output shape
2123       {
2124         dstExprs.insert(dstExprs.begin() + outputFeatureDimension, parallelDim);
2125         updateDimMappingFromOffset(resultIndexMapping, outputFeatureDimension);
2126         reshapedResultShape.insert(
2127             reshapedResultShape.begin() + outputFeatureDimension,
2128             featureGroupCount);
2129         reshapedResultShape[outputFeatureDimension + 1] /= featureGroupCount;
2130       }
2131     }
2132 
2133     if (batchGroupCount != 1) {
2134       iterationLoops.push_back(getParallelIteratorTypeName());
2135       auto parallelDim = mlir::getAffineDimExpr(nextDim++, ctx);
2136       // Reshape LHS
2137       {
2138         srcExprs.insert(srcExprs.begin() + inputBatchDimension, parallelDim);
2139         auto prevDimsRef = paddedLhsType.getShape();
2140         llvm::SmallVector<int64_t> newShape;
2141         reshapeShapeVector(prevDimsRef, newShape, inputBatchDimension,
2142                            batchGroupCount);
2143         updateDimMappingFromOffset(lhsIndexMapping, inputBatchDimension);
2144         modifiedLhs = rewriter.create<mhlo::ReshapeOp>(
2145             op.getLoc(),
2146             RankedTensorType::get(newShape, paddedLhsType.getElementType()),
2147             modifiedLhs);
2148       }
2149 
2150       // Reshape RHS
2151       {
2152         windowExprs.insert(windowExprs.begin() + kernelOutputFeatureDimension,
2153                            parallelDim);
2154         auto prevDimsRef = paddedRhsType.getShape();
2155         llvm::SmallVector<int64_t> newShape;
2156         reshapeShapeVector(prevDimsRef, newShape, kernelOutputFeatureDimension,
2157                            batchGroupCount);
2158         updateDimMappingFromOffset(rhsIndexMapping,
2159                                    kernelOutputFeatureDimension);
2160         modifiedRhs = rewriter.create<mhlo::ReshapeOp>(
2161             op.getLoc(),
2162             RankedTensorType::get(newShape, paddedRhsType.getElementType()),
2163             modifiedRhs);
2164       }
2165       // Prepare reshaped output shape
2166       {
2167         auto outputFeatureDim = resultIndexMapping[outputFeatureDimension];
2168         dstExprs.insert(dstExprs.begin() + outputFeatureDim, parallelDim);
2169         updateDimMappingFromOffset(resultIndexMapping, outputFeatureDimension);
2170         reshapedResultShape.insert(
2171             reshapedResultShape.begin() + outputFeatureDim, batchGroupCount);
2172         reshapedResultShape[outputFeatureDim + 1] /= batchGroupCount;
2173       }
2174     }
2175 
2176     // Handle input feature dimension
2177     {
2178       iterationLoops.push_back(getReductionIteratorTypeName());
2179       auto inputFeatureDim = mlir::getAffineDimExpr(nextDim++, ctx);
2180       srcExprs[lhsIndexMapping[inputFeatureDimension]] = inputFeatureDim;
2181       windowExprs[rhsIndexMapping[kernelInputFeatureDimension]] =
2182           inputFeatureDim;
2183     }
2184 
2185     // Handle output feature dimension
2186     {
2187       iterationLoops.push_back(getParallelIteratorTypeName());
2188       auto outputFeatureDim = mlir::getAffineDimExpr(nextDim++, ctx);
2189       dstExprs[resultIndexMapping[outputFeatureDimension]] = outputFeatureDim;
2190       windowExprs[rhsIndexMapping[kernelOutputFeatureDimension]] =
2191           outputFeatureDim;
2192     }
2193 
2194     // Handle spatial Dimensions
2195     int64_t numSpatialDims = rank - 2;
2196     for (int64_t i = 0; i < numSpatialDims; i++) {
2197       iterationLoops.push_back(getParallelIteratorTypeName());
2198       iterationLoops.push_back(getReductionIteratorTypeName());
2199       auto dim0 = mlir::getAffineDimExpr(nextDim++, ctx);
2200       auto dim1 = mlir::getAffineDimExpr(nextDim++, ctx);
2201 
2202       auto stride = dim0;
2203       if (op.window_strides().value())
2204         stride = stride * op.window_strides().value().getValues<int64_t>()[i];
2205       AffineExpr srcExpr = stride + dim1;
2206 
2207       srcExprs[lhsIndexMapping[inputSpatialDimensions[i]]] = srcExpr;
2208       dstExprs[resultIndexMapping[outputSpatialDimensions[i]]] = dim0;
2209       windowExprs[rhsIndexMapping[kernelSpatialDimensions[i]]] = dim1;
2210     }
2211 
2212     // Handle batch dimension
2213     {
2214       iterationLoops.push_back(getParallelIteratorTypeName());
2215       auto batchDim = mlir::getAffineDimExpr(nextDim++, ctx);
2216 
2217       srcExprs[lhsIndexMapping[inputBatchDimension]] = batchDim;
2218       dstExprs[resultIndexMapping[inputBatchDimension]] = batchDim;
2219     }
2220 
2221     // Finally, create the computation
2222     auto inferredMaps =
2223         AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs});
2224 
2225     Value initTensor = rewriter.create<linalg::InitTensorOp>(
2226         loc, reshapedResultShape, resultType.getElementType());
2227     Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
2228 
2229     Value convolved =
2230         rewriter
2231             .create<linalg::GenericOp>(
2232                 loc,
2233                 /*resultTensors=*/
2234                 llvm::makeArrayRef<Type>(zeroTensor.getType()),
2235                 /*inputs=*/
2236                 llvm::makeArrayRef<Value>({modifiedLhs, modifiedRhs}),
2237                 /*outputs=*/llvm::makeArrayRef<Value>(zeroTensor), inferredMaps,
2238                 iterationLoops,
2239                 /*bodyBuild=*/
2240                 [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange) {
2241                   ImplicitLocOpBuilder builder(nestedLoc, nestedBuilder);
2242                   linalg::Conv2DOp::regionBuilder(
2243                       builder, *builder.getInsertionBlock(), {});
2244                 },
2245                 pruneAttributeList(op))
2246             .getResult(0);
2247     rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(op, resultType, convolved);
2248 
2249     return success();
2250   }
2251 };
2252 
2253 /// Converts mhlo.convolution operation to
2254 /// linalg.depthwise_conv_2d_input_nhwc_filter_hwcf op or
2255 /// depthwise_conv_2d_input_nhwc_filter_hwc op.
2256 struct DepthwiseConvolutionOpConversion
2257     : public OpConversionPattern<mhlo::ConvolutionOp> {
2258   using OpConversionPattern<mhlo::ConvolutionOp>::OpConversionPattern;
2259 
matchAndRewritemlir::mhlo::__anon071fe3ef0111::DepthwiseConvolutionOpConversion2260   LogicalResult matchAndRewrite(
2261       mhlo::ConvolutionOp op, OpAdaptor adaptor,
2262       ConversionPatternRewriter& rewriter) const override {
2263     if (op.batch_group_count() != 1) return failure();
2264     // Fall into the normal convolution cases.
2265     if (op.feature_group_count() == 1) return failure();
2266 
2267     const mhlo::ConvDimensionNumbersAttr& dimensionNumbers =
2268         op.dimension_numbers();
2269     const auto spatialRank =
2270         llvm::size(dimensionNumbers.getInputSpatialDimensions());
2271     if (spatialRank == 0 || spatialRank > 3) {
2272       return rewriter.notifyMatchFailure(op, "only support up to 3D for now");
2273     }
2274 
2275     // Make sure that this is depthwise convolution.
2276     int64_t inputFeatureDim = dimensionNumbers.getInputFeatureDimension();
2277     int64_t inputFeatureCount =
2278         op.lhs().getType().cast<ShapedType>().getDimSize(inputFeatureDim);
2279     if (op.feature_group_count() != inputFeatureCount) {
2280       return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
2281     }
2282 
2283     // Make sure that this convolution has a canonical form.
2284     if (!hasCanonicalDimensionNumbers(dimensionNumbers)) {
2285       return rewriter.notifyMatchFailure(op, "does not have canonical form");
2286     }
2287 
2288     Attribute windowStrides;
2289     if (op.window_strides()) {
2290       windowStrides = op.window_strides().value();
2291     } else {
2292       windowStrides = SplatElementsAttr::get(
2293           VectorType::get({spatialRank}, rewriter.getI64Type()),
2294           rewriter.getI64IntegerAttr(1));
2295     }
2296 
2297     Attribute rhsDilation;
2298     if (op.rhs_dilation()) {
2299       rhsDilation = op.rhs_dilation().value();
2300     } else {
2301       rhsDilation = SplatElementsAttr::get(
2302           VectorType::get({spatialRank}, rewriter.getI64Type()),
2303           rewriter.getI64IntegerAttr(1));
2304     }
2305 
2306     Location loc = op.getLoc();
2307     Value input = adaptor.lhs();
2308     Value filter = adaptor.rhs();
2309     auto resultType = typeConverter->convertType(op.getResult().getType())
2310                           .cast<RankedTensorType>();
2311     if (!resultType.hasStaticShape()) {
2312       return rewriter.notifyMatchFailure(op,
2313                                          "expected output has static shapes");
2314     }
2315 
2316     // Apply padding and input dilation.
2317     llvm::SmallVector<int64_t> spatialDimMapping(spatialRank);
2318     std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1);
2319     input = applyConvolutionPadding(loc, input, op.paddingAttr(),
2320                                     op.lhs_dilationAttr(), spatialDimMapping,
2321                                     rewriter);
2322 
2323     auto filterDims =
2324         llvm::to_vector<4>(op.rhs().getType().cast<ShapedType>().getShape());
2325 
2326     auto getReassociationIndicesToCollapseLastTwoDims = [](Value v) {
2327       SmallVector<ReassociationIndices> reassociations;
2328       int64_t rank = v.getType().cast<ShapedType>().getRank();
2329       for (int64_t i = 0; i < rank - 1; ++i) reassociations.emplace_back(1, i);
2330       reassociations.back().push_back(rank - 1);
2331       return reassociations;
2332     };
2333 
2334     int64_t kernelInputFeatureDimension =
2335         dimensionNumbers.getKernelInputFeatureDimension();
2336     int64_t kernelOutputFeatureDimension =
2337         dimensionNumbers.getKernelOutputFeatureDimension();
2338     if (filterDims[kernelInputFeatureDimension] *
2339             filterDims[kernelOutputFeatureDimension] !=
2340         op.feature_group_count()) {
2341       // For cases where channel multiplier != 1
2342 
2343       // Reshaping filter shape
2344       //   [filter_height, filter_width, 1, kernel-output-feature].
2345       // to
2346       //   [filter_height, filter_width, feature_group_count,
2347       //      kernel-output-feature/feature_group_count ]
2348       SmallVector<int64_t> reshapedFilterDims;
2349       reshapedFilterDims.assign(filterDims.begin(), filterDims.end());
2350       auto reshapedFilter = filter;
2351       if (filterDims[kernelInputFeatureDimension] == 1) {
2352         reshapedFilterDims[kernelInputFeatureDimension] =
2353             op.feature_group_count();
2354         reshapedFilterDims[kernelOutputFeatureDimension] /=
2355             op.feature_group_count();
2356         auto reshapedFilterType = RankedTensorType::get(
2357             reshapedFilterDims,
2358             op.rhs().getType().cast<RankedTensorType>().getElementType());
2359 
2360         reshapedFilter =
2361             rewriter.create<mhlo::ReshapeOp>(loc, reshapedFilterType, filter);
2362       }
2363 
2364       auto outputDims = resultType.getShape();
2365       auto channelMultiplier = reshapedFilterDims.back();
2366       SmallVector<int64_t> reshapedOutputDims;
2367       reshapedOutputDims.assign(outputDims.begin(), outputDims.end());
2368       reshapedOutputDims.push_back(channelMultiplier);
2369       reshapedOutputDims[reshapedOutputDims.size() - 2] /= channelMultiplier;
2370 
2371       Value initTensor = rewriter.create<linalg::InitTensorOp>(
2372           loc, reshapedOutputDims, resultType.getElementType());
2373       Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
2374 
2375       auto reshapedOutputType = RankedTensorType::get(
2376           reshapedOutputDims, resultType.getElementType());
2377       Value conv;
2378       switch (spatialRank) {
2379         case 1:
2380           conv =
2381               rewriter
2382                   .create<linalg::DepthwiseConv1DNwcWcmOp>(
2383                       loc, reshapedOutputType,
2384                       ValueRange{input, reshapedFilter}, ValueRange{zeroTensor},
2385                       windowStrides, rhsDilation, pruneAttributeList(op))
2386                   .getResult(0);
2387           break;
2388         case 2:
2389           conv =
2390               rewriter
2391                   .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
2392                       loc, reshapedOutputType,
2393                       ValueRange{input, reshapedFilter}, ValueRange{zeroTensor},
2394                       windowStrides, rhsDilation, pruneAttributeList(op))
2395                   .getResult(0);
2396           break;
2397         case 3:
2398           conv =
2399               rewriter
2400                   .create<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
2401                       loc, reshapedOutputType,
2402                       ValueRange{input, reshapedFilter}, ValueRange{zeroTensor},
2403                       windowStrides, rhsDilation, pruneAttributeList(op))
2404                   .getResult(0);
2405           break;
2406       }
2407 
2408       // Create a Linalg reshape op that converts the output from 5 dimensions
2409       // into 4 dimensions (by collapsing the last two dimensions). This is
2410       // needed because linalg.depthwise_conv_2d_input_nhwc_filter_hwcf returns
2411       // 5 dimensions for the output.
2412       rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
2413           op, resultType, conv,
2414           getReassociationIndicesToCollapseLastTwoDims(conv));
2415     } else {
2416       // For cases where channel multiplier == 1
2417       Value initTensor = rewriter.create<linalg::InitTensorOp>(
2418           loc, resultType.getShape(), resultType.getElementType());
2419       Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
2420 
2421       // Create a Linalg reshape op that converts the filter from 4 dimensions
2422       // into 3 dimensions (by droping the unit dimension). This is needed
2423       // because linalg.depthwise_conv_2d_input_nhwc_filter_hwc expects 3
2424       // dimensions for the filter.
2425 
2426       filterDims[filterDims.size() - 2] =
2427           static_cast<int64_t>(op.feature_group_count());
2428       filterDims.pop_back();
2429 
2430       RankedTensorType filterShape =
2431           RankedTensorType::get(filterDims, op.getType().getElementType());
2432 
2433       Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
2434           loc, filterShape, filter,
2435           getReassociationIndicesToCollapseLastTwoDims(filter));
2436 
2437       switch (spatialRank) {
2438         case 1:
2439           rewriter.replaceOpWithNewOp<linalg::DepthwiseConv1DNwcWcOp>(
2440               op, resultType, ValueRange{input, reshapedFilter},
2441               ValueRange{zeroTensor}, windowStrides, rhsDilation,
2442               pruneAttributeList(op));
2443           break;
2444         case 2:
2445           rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcOp>(
2446               op, resultType, ValueRange{input, reshapedFilter},
2447               ValueRange{zeroTensor}, windowStrides, rhsDilation,
2448               pruneAttributeList(op));
2449           break;
2450         case 3:
2451           rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(
2452               op, resultType, ValueRange{input, reshapedFilter},
2453               ValueRange{zeroTensor}, windowStrides, rhsDilation,
2454               pruneAttributeList(op));
2455           break;
2456       }
2457     }
2458 
2459     return success();
2460   }
2461 };
2462 
2463 struct ReduceWindowOpOnTensorsGenericConversion
2464     : public OpConversionPattern<mhlo::ReduceWindowOp> {
2465   using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;
matchAndRewritemlir::mhlo::__anon071fe3ef0111::ReduceWindowOpOnTensorsGenericConversion2466   LogicalResult matchAndRewrite(
2467       mhlo::ReduceWindowOp op, OpAdaptor adaptor,
2468       ConversionPatternRewriter& rewriter) const override {
2469     MLIRContext* ctx = op->getContext();
2470     Location loc = op.getLoc();
2471     llvm::SmallVector<Value> initValues = adaptor.init_values();
2472     llvm::SmallVector<Type> resultTypes = llvm::to_vector(op.getResultTypes());
2473     auto numOperands = initValues.size();
2474 
2475     llvm::SmallVector<int64_t> windowDimensions =
2476         extract1DVector(op.window_dimensions());
2477 
2478     llvm::SmallVector<int64_t> padding;
2479     if (op.padding()) {
2480       padding = extract1DVector(*op.padding());
2481     }
2482 
2483     llvm::SmallVector<int64_t> baseDilations;
2484     if (op.base_dilations()) {
2485       baseDilations = extract1DVector(*op.base_dilations());
2486     }
2487 
2488     llvm::SmallVector<int64_t> windowStrides(windowDimensions.size(), 1);
2489     if (op.window_strides()) {
2490       windowStrides = extract1DVector(*op.window_strides());
2491     }
2492 
2493     llvm::SmallVector<int64_t> windowDilations(windowDimensions.size(), 1);
2494     if (op.window_dilations()) {
2495       windowDilations = extract1DVector(*op.window_dilations());
2496     }
2497 
2498     auto rank = static_cast<int64_t>(windowDimensions.size());
2499     SmallVector<AffineExpr, 2> srcExprs;
2500     SmallVector<AffineExpr, 2> windowExprs;
2501     SmallVector<AffineExpr, 2> dstExprs;
2502     SmallVector<int64_t> filteredWindowDims;
2503 
2504     int windowDim = 0;
2505     for (int64_t i = 0; i < rank; i++) {
2506       AffineExpr srcExpr = mlir::getAffineDimExpr(i, ctx);
2507 
2508       if (windowStrides[i] != 1) srcExpr = srcExpr * windowStrides[i];
2509 
2510       if (windowDimensions[i] != 1) {
2511         filteredWindowDims.push_back(windowDimensions[i]);
2512         AffineExpr windowExpr = mlir::getAffineDimExpr(rank + windowDim, ctx);
2513         windowExprs.push_back(windowExpr);
2514 
2515         if (windowDilations[i] != 1)
2516           windowExpr = windowExpr * windowDilations[i];
2517 
2518         srcExpr = srcExpr + windowExpr;
2519         windowDim++;
2520       }
2521 
2522       srcExprs.push_back(srcExpr);
2523       dstExprs.push_back(mlir::getAffineDimExpr(i, ctx));
2524     }
2525 
2526     SmallVector<AffineMap, 4> inferredMaps(3, AffineMap::get(ctx));
2527     if (rank > 0)
2528       inferredMaps =
2529           AffineMap::inferFromExprList({srcExprs, windowExprs, dstExprs});
2530 
2531     SmallVector<AffineMap, 4> indexingMaps;
2532 
2533     indexingMaps.append(numOperands, inferredMaps[0]);
2534     indexingMaps.append(1, inferredMaps[1]);
2535     indexingMaps.append(numOperands, inferredMaps[2]);
2536 
2537     // Setup the initial values.
2538     llvm::SmallVector<Value> broadcastValues;
2539     for (uint64_t i = 0, s = initValues.size(); i < s; i++) {
2540       Value initValue = initValues[i];
2541       auto resultTy = resultTypes[i].cast<ShapedType>();
2542       if (!resultTy.hasStaticShape()) return failure();
2543 
2544       auto broadcastSizes = rewriter.getI64TensorAttr(resultTy.getShape());
2545       broadcastValues.push_back(rewriter.create<mhlo::BroadcastOp>(
2546           loc, resultTy, initValue, broadcastSizes));
2547     }
2548 
2549     llvm::SmallVector<Value> inputs = llvm::to_vector(adaptor.operands());
2550 
2551     // Pad as necessary.
2552     if (llvm::any_of(padding, [](int64_t v) { return v != 0; }) ||
2553         llvm::any_of(baseDilations, [](int64_t v) { return v != 1; })) {
2554       llvm::SmallVector<int64_t> staticLows(rank, 0);
2555       llvm::SmallVector<int64_t> staticHighs(rank, 0);
2556       for (int i = 0; i < padding.size(); i += 2) {
2557         staticLows[i / 2] = padding[i];
2558         staticHighs[i / 2] = padding[i + 1];
2559       }
2560       // Translate base dilation into interior padding.
2561       llvm::SmallVector<int64_t> staticInteriors(rank, 0);
2562       for (const auto& dilation : llvm::enumerate(baseDilations)) {
2563         staticInteriors[dilation.index()] = dilation.value() - 1;
2564       }
2565 
2566       auto padAttrType =
2567           RankedTensorType::get({rank}, rewriter.getIntegerType(64));
2568       auto padLows = DenseIntElementsAttr::get(padAttrType, staticLows);
2569       auto padHighs = DenseIntElementsAttr::get(padAttrType, staticHighs);
2570       auto padInteriors =
2571           DenseIntElementsAttr::get(padAttrType, staticInteriors);
2572 
2573       for (auto values : llvm::zip(inputs, initValues)) {
2574         auto& input = std::get<0>(values);
2575         auto& initValue = std::get<1>(values);
2576         input = rewriter.create<mhlo::PadOp>(loc, input, initValue, padLows,
2577                                              padHighs, padInteriors);
2578       }
2579     }
2580 
2581     // Add the extra input for the reduction dimension.
2582     inputs.push_back(rewriter.create<linalg::InitTensorOp>(
2583         loc, filteredWindowDims, rewriter.getF32Type()));
2584 
2585     auto linalgOp = rewriter.create<linalg::GenericOp>(
2586         loc, /*resultTensors=*/resultTypes,
2587         /*inputs=*/inputs,
2588         /*outputs=*/broadcastValues, indexingMaps,
2589         getParallelAndReductionIterators(rank + filteredWindowDims.size(),
2590                                          filteredWindowDims.size()),
2591         /*bodyBuild=*/nullptr, pruneAttributeList(op));
2592 
2593     // Convert the signature of the body. This includes converting scalar
2594     // tensors to their scalar values and inserting an additional block arg for
2595     // the window arg.
2596     Region& region = linalgOp.region();
2597     rewriter.cloneRegionBefore(op.body(), region, region.end());
2598 
2599     TypeConverter::SignatureConversion signatureConverter(
2600         inputs.size() + op->getNumResults() - 1);
2601 
2602     // ReduceWindow requires that the seed be used as a LHS operand inside the
2603     // region, and the seed is encoded in linalg in the intial out value, so
2604     // modify the signature of the block and the value mappings, so the output
2605     // args will correlate with the LHS and the inputs correlate with the RHS.
2606     for (const auto& [i, type] : llvm::enumerate(resultTypes)) {
2607       auto idx = inputs.size() + i - 1;
2608       signatureConverter.addInputs(idx,
2609                                    type.cast<ShapedType>().getElementType());
2610     }
2611 
2612     signatureConverter.addInputs(
2613         inputs.back().getType().cast<ShapedType>().getElementType());
2614 
2615     for (const auto& [i, input] :
2616          llvm::enumerate(ArrayRef<Value>(inputs).drop_back())) {
2617       signatureConverter.addInputs(
2618           i, input.getType().cast<ShapedType>().getElementType());
2619     }
2620 
2621     rewriter.applySignatureConversion(&region, signatureConverter,
2622                                       getTypeConverter());
2623     rewriter.replaceOp(op, linalgOp.getResults());
2624     return success();
2625   }
2626 };
2627 
2628 struct ReduceWindowOpConversion
2629     : public OpConversionPattern<mhlo::ReduceWindowOp> {
2630   using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;
2631 
2632   /// mhlo.reduce_window is mapped to a linalg.pooling operation. The type of
2633   /// the pooling is determined based on the body of the reduce window
2634   /// operation. This class enumerates the different variants.
2635   enum class PoolingType {
2636     kInvalid,
2637     k2DMin,
2638     k3DMin,
2639     k2DMax,
2640     k3DMax,
2641     k2DAdd,
2642     k3DAdd,
2643   };
2644 
getPoolingTypemlir::mhlo::__anon071fe3ef0111::ReduceWindowOpConversion2645   static PoolingType getPoolingType(mhlo::ReduceWindowOp reduceOp,
2646                                     int resultIndex) {
2647     auto rank =
2648         reduceOp.getResultTypes()[resultIndex].cast<ShapedType>().getRank();
2649     if (Operation* op = reduceOp.getReductionOp(resultIndex)) {
2650       if (isa<mhlo::MinOp>(*op) && rank == 4) return PoolingType::k2DMin;
2651       if (isa<mhlo::MinOp>(*op) && rank == 5) return PoolingType::k3DMin;
2652       if (isa<mhlo::MaxOp>(*op) && rank == 4) return PoolingType::k2DMax;
2653       if (isa<mhlo::MaxOp>(*op) && rank == 5) return PoolingType::k3DMax;
2654       if (isa<mhlo::AddOp>(*op) && rank == 4) return PoolingType::k2DAdd;
2655       if (isa<mhlo::AddOp>(*op) && rank == 5) return PoolingType::k3DAdd;
2656     }
2657     return PoolingType::kInvalid;
2658   }
2659 
matchAndRewritemlir::mhlo::__anon071fe3ef0111::ReduceWindowOpConversion2660   LogicalResult matchAndRewrite(
2661       mhlo::ReduceWindowOp op, OpAdaptor adaptor,
2662       ConversionPatternRewriter& rewriter) const override {
2663     auto loc = op.getLoc();
2664     int rank = op.getResultTypes()[0].cast<ShapedType>().getRank();
2665     if (rank != 4 && rank != 5) {
2666       return rewriter.notifyMatchFailure(
2667           op, "expected NHWC/NDHWC pooling-based op");
2668     }
2669 
2670     if (op.padding() && !isSplatValue(*op.padding(), 0)) {
2671       return rewriter.notifyMatchFailure(op, "require paddings are all zero");
2672     }
2673 
2674     if (op.base_dilations() && !isSplatValue(*op.base_dilations(), 1)) {
2675       return rewriter.notifyMatchFailure(op, "expected undilated base");
2676     }
2677 
2678     int lastDim = rank - 1;
2679     SmallVector<int64_t, 2> fakeWindowShapes;
2680     for (int i = 1; i < lastDim; ++i) {
2681       fakeWindowShapes.push_back(
2682           op.window_dimensions().getValues<int64_t>()[i]);
2683     }
2684 
2685     if (op.window_strides() &&
2686         (op.window_strides().value().getValues<int64_t>()[0] != 1 ||
2687          op.window_strides().value().getValues<int64_t>()[lastDim] != 1)) {
2688       return rewriter.notifyMatchFailure(
2689           op, "expected window_strides to be [1,x,y,(z),1]");
2690     }
2691     if (op.window_dimensions() &&
2692         (op.window_dimensions().getValues<int64_t>()[0] != 1 ||
2693          op.window_dimensions().getValues<int64_t>()[lastDim] != 1)) {
2694       return rewriter.notifyMatchFailure(
2695           op, "expected window_dimensions to be [1,x,y,(z),1]");
2696     }
2697 
2698     Attribute strides;
2699     SmallVector<int64_t> vec;
2700     if (op.window_stridesAttr()) {
2701       for (int i = 1; i < lastDim; ++i) {
2702         vec.push_back(op.window_strides().value().getValues<int64_t>()[i]);
2703       }
2704     } else {
2705       vec.assign(rank - 2, 1);
2706     }
2707     strides = rewriter.getI64VectorAttr(vec);
2708 
2709     Attribute dilations;
2710     vec.clear();
2711     if (op.window_dilations()) {
2712       for (int i = 1; i < lastDim; ++i) {
2713         vec.push_back(op.window_dilations().value().getValues<int64_t>()[i]);
2714       }
2715     } else {
2716       vec.assign(rank - 2, 1);
2717     }
2718     dilations = rewriter.getI64VectorAttr(vec);
2719 
2720     SmallVector<Value> poolingOps;
2721 
2722     ValueRange operands = adaptor.operands();
2723     ValueRange initValues = adaptor.init_values();
2724     for (auto it : llvm::zip(op.getResults(), operands, initValues)) {
2725       OpResult result = std::get<0>(it);
2726       Value input = std::get<1>(it);
2727       Value initValue = std::get<2>(it);
2728       auto resultType = result.getType().cast<ShapedType>();
2729       if (!input.getType().cast<ShapedType>().getElementType().isF32()) {
2730         return rewriter.notifyMatchFailure(op,
2731                                            "expected element type to be f32");
2732       }
2733 
2734       // Create a fake window dimension.
2735       auto fakeWindowDims = rewriter.create<linalg::InitTensorOp>(
2736           loc, fakeWindowShapes, resultType.getElementType());
2737 
2738       SmallVector<Value> resultDynamicDims;
2739       for (auto& en : llvm::enumerate(resultType.getShape())) {
2740         if (en.value() != ShapedType::kDynamicSize) continue;
2741         Value dimSize = rewriter.create<tensor::DimOp>(loc, input, en.index());
2742         if (en.index() == 0 || static_cast<int64_t>(en.index()) == rank - 1) {
2743           // batch dims and channel dims can be derived from input dims
2744           // directly.
2745           resultDynamicDims.push_back(dimSize);
2746         } else {
2747           auto i = en.index() - 1;
2748           auto stride =
2749               strides.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
2750           auto dilation =
2751               dilations.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
2752           // let j = i * stride
2753           // output[i] = reduce( input[j, j + window_size * dilation) )
2754           Value offset = rewriter.create<arith::ConstantIndexOp>(
2755               loc, fakeWindowShapes[i] * dilation);
2756           dimSize = rewriter.create<arith::SubIOp>(loc, dimSize, offset);
2757           dimSize = rewriter.create<arith::DivUIOp>(
2758               loc, dimSize,
2759               rewriter.create<arith::ConstantIndexOp>(loc, stride));
2760           dimSize = rewriter.create<arith::AddIOp>(
2761               loc, dimSize, rewriter.create<arith::ConstantIndexOp>(loc, 1));
2762           resultDynamicDims.push_back(dimSize);
2763         }
2764       }
2765       Value initTensor = rewriter.create<linalg::InitTensorOp>(
2766           loc, resultDynamicDims, resultType.getShape(),
2767           resultType.getElementType());
2768 
2769       initValue = rewriter.create<tensor::ExtractOp>(loc, initValue);
2770       Value filledInitTensor =
2771           rewriter.create<linalg::FillOp>(loc, initValue, initTensor)
2772               .getResult(0);
2773       auto createOp = [&](auto* typePtr) -> linalg::LinalgOp {
2774         return cast<linalg::LinalgOp>(
2775             rewriter
2776                 .create<std::remove_pointer_t<decltype(typePtr)>>(
2777                     loc, ArrayRef<Type>{resultType},
2778                     ValueRange{input, fakeWindowDims.getResult()},
2779                     filledInitTensor, strides, dilations,
2780                     pruneAttributeList(op))
2781                 .getOperation());
2782       };
2783       linalg::LinalgOp poolingOp;
2784       PoolingType poolingType = getPoolingType(op, result.getResultNumber());
2785       switch (poolingType) {
2786         case PoolingType::k2DMin: {
2787           poolingOp = createOp(static_cast<linalg::PoolingNhwcMinOp*>(nullptr));
2788           break;
2789         }
2790         case PoolingType::k3DMin: {
2791           poolingOp =
2792               createOp(static_cast<linalg::PoolingNdhwcMinOp*>(nullptr));
2793           break;
2794         }
2795         case PoolingType::k2DMax: {
2796           poolingOp = createOp(static_cast<linalg::PoolingNhwcMaxOp*>(nullptr));
2797           break;
2798         }
2799         case PoolingType::k3DMax: {
2800           poolingOp =
2801               createOp(static_cast<linalg::PoolingNdhwcMaxOp*>(nullptr));
2802           break;
2803         }
2804         case PoolingType::k2DAdd: {
2805           poolingOp = createOp(static_cast<linalg::PoolingNhwcSumOp*>(nullptr));
2806           break;
2807         }
2808         case PoolingType::k3DAdd: {
2809           poolingOp =
2810               createOp(static_cast<linalg::PoolingNdhwcSumOp*>(nullptr));
2811           break;
2812         }
2813         case PoolingType::kInvalid:
2814           return rewriter.notifyMatchFailure(op, "unknown reduction operation");
2815       }
2816       poolingOps.push_back(poolingOp->getResult(0));
2817     }
2818     rewriter.replaceOp(op, poolingOps);
2819     return success();
2820   }
2821 };
2822 
2823 /// Converts xla-hlo.torch_index_select op to a linalg.generic op.
2824 struct TorchIndexSelectOpConversion
2825     : public OpConversionPattern<mhlo::TorchIndexSelectOp> {
2826   using OpConversionPattern<mhlo::TorchIndexSelectOp>::OpConversionPattern;
2827 
matchAndRewritemlir::mhlo::__anon071fe3ef0111::TorchIndexSelectOpConversion2828   LogicalResult matchAndRewrite(
2829       mhlo::TorchIndexSelectOp op, OpAdaptor adaptor,
2830       ConversionPatternRewriter& rewriter) const final {
2831     int axis = static_cast<int>(op.dim());
2832     int batch = static_cast<int>(op.batch_dims());
2833     auto indexShapedType = adaptor.index().getType().cast<ShapedType>();
2834     int numIndices = static_cast<int>(indexShapedType.getRank());
2835     auto operandShapedType = adaptor.operand().getType().cast<ShapedType>();
2836     if (axis < 0) axis += static_cast<int>(operandShapedType.getRank());
2837     if (batch < 0) batch += numIndices;
2838 
2839     Location loc = op.getLoc();
2840     auto resultType = this->typeConverter->convertType(op.getResult().getType())
2841                           .cast<ShapedType>();
2842     int rank = static_cast<int>(resultType.getRank());
2843 
2844     // The output shape is
2845     //   `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
2846     SmallVector<Value, 4> dynSizes;
2847     for (int i = 0; i < rank; ++i) {
2848       if (!resultType.isDynamicDim(i)) continue;
2849       if (i < axis) {
2850         dynSizes.push_back(
2851             rewriter.create<tensor::DimOp>(loc, adaptor.operand(), i));
2852       } else if (i < (axis + numIndices - batch)) {
2853         int idx = i - axis + batch;
2854         dynSizes.push_back(
2855             rewriter.create<tensor::DimOp>(loc, adaptor.index(), idx));
2856       } else {
2857         int idx = i - (axis + numIndices - batch) + axis + 1;
2858         dynSizes.push_back(
2859             rewriter.create<tensor::DimOp>(loc, adaptor.operand(), idx));
2860       }
2861     }
2862 
2863     // Generate dummy tensor to preserve slice shape information.
2864     SmallVector<int64_t> sliceShape;
2865     SmallVector<Value, 4> dynSliceSizes;
2866     SmallVector<AffineExpr, 4> sliceExprs;
2867     auto resultShape = resultType.getShape();
2868     for (int i = 0; i < axis; ++i) {
2869       sliceExprs.push_back(rewriter.getAffineDimExpr(i));
2870       sliceShape.push_back(resultShape[i]);
2871       if (!resultType.isDynamicDim(i)) continue;
2872       dynSliceSizes.push_back(
2873           rewriter.create<tensor::DimOp>(loc, adaptor.operand(), i));
2874     }
2875     for (int i = axis + numIndices - batch; i < rank; ++i) {
2876       sliceExprs.push_back(rewriter.getAffineDimExpr(i));
2877       sliceShape.push_back(resultShape[i]);
2878       if (!resultType.isDynamicDim(i)) continue;
2879       int idx = i - (axis + numIndices - batch) + axis + 1;
2880       dynSliceSizes.push_back(
2881           rewriter.create<tensor::DimOp>(loc, adaptor.operand(), idx));
2882     }
2883 
2884     // Setup AffineMap for operand tensor.
2885     SmallVector<AffineExpr, 4> exprs;
2886     for (int i = 0; i < batch; ++i) {
2887       exprs.push_back(rewriter.getAffineDimExpr(i));
2888     }
2889     for (int i = 0, e = numIndices - batch; i < e; ++i) {
2890       exprs.push_back(rewriter.getAffineDimExpr(axis + i));
2891     }
2892 
2893     SmallVector<AffineMap, 2> indexingMaps;
2894     indexingMaps.emplace_back(
2895         AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
2896     indexingMaps.emplace_back(AffineMap::get(
2897         rank, /*symbolCount=*/0, sliceExprs, rewriter.getContext()));
2898     indexingMaps.emplace_back(rewriter.getMultiDimIdentityMap(rank));
2899 
2900     Value sliceOp = rewriter.create<linalg::InitTensorOp>(
2901         loc, dynSliceSizes, sliceShape, resultType.getElementType());
2902 
2903     Value initOp = rewriter.create<linalg::InitTensorOp>(
2904         loc, dynSizes, resultType.getShape(), resultType.getElementType());
2905     auto linalgOp = rewriter.create<linalg::GenericOp>(
2906         loc, /*resultTensors=*/ArrayRef<Type>{resultType},
2907         /*inputs=*/ValueRange{adaptor.index(), sliceOp},
2908         /*outputs=*/initOp, indexingMaps, getNParallelLoopsAttrs(rank),
2909         /*bodyBuild=*/nullptr, pruneAttributeList(op));
2910 
2911     SmallVector<Type, 4> bodyArgTypes;
2912     SmallVector<Value, 2> linalgOpArgs = {adaptor.index(), sliceOp};
2913     // Add a block to the region.
2914     auto* region = &linalgOp.region();
2915     auto* block = rewriter.createBlock(region, region->end());
2916     for (auto blockArgs : linalgOpArgs) {
2917       bodyArgTypes.push_back(
2918           blockArgs.getType().cast<ShapedType>().getElementType());
2919     }
2920     block->addArguments(bodyArgTypes,
2921                         SmallVector<Location>(bodyArgTypes.size(), loc));
2922     block->addArguments(resultType.getElementType(), loc);
2923     OpBuilder::InsertionGuard guard(rewriter);
2924     rewriter.setInsertionPointToEnd(block);
2925 
2926     Value castedValue = rewriter.create<arith::IndexCastOp>(
2927         loc, rewriter.getIndexType(), block->getArgument(0));
2928 
2929     SmallVector<Value, 4> indices;
2930     for (int i = 0; i < axis; ++i) {
2931       indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
2932     }
2933     indices.push_back(castedValue);
2934     for (int i = axis + numIndices - batch; i < rank; ++i) {
2935       indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
2936     }
2937     Value res =
2938         rewriter.create<tensor::ExtractOp>(loc, adaptor.operand(), indices);
2939     rewriter.create<linalg::YieldOp>(loc, res);
2940 
2941     rewriter.replaceOp(op, linalgOp.getResults());
2942     return success();
2943   }
2944 };
2945 
2946 /// This lowering encompasses the full range of the Gather operation and
2947 /// therefore is very general and just loops over the output and calculate the
2948 /// corresponding input index. It follows the explanation at
2949 /// https://www.tensorflow.org/xla/operation_semantics#gather. The compiler
2950 /// should be able to optimize that a bit, but in order to get efficient
2951 /// lowerings, special-cases of gather should be extracted in separate
2952 /// lowerings, and ideally encapsulated as separate ops or canonicalization
2953 /// patterns.
2954 struct GatherConversion : public OpConversionPattern<mhlo::GatherOp> {
2955   using OpConversionPattern<mhlo::GatherOp>::OpConversionPattern;
2956 
matchAndRewritemlir::mhlo::__anon071fe3ef0111::GatherConversion2957   LogicalResult matchAndRewrite(
2958       mhlo::GatherOp gatherOp, OpAdaptor adaptor,
2959       ConversionPatternRewriter& rewriter) const final {
2960     Location loc = gatherOp.getLoc();
2961 
2962     Value startIndices = adaptor.start_indices();
2963     Value operand = adaptor.operand();
2964 
2965     auto resultType = typeConverter->convertType(gatherOp.getType())
2966                           .dyn_cast<RankedTensorType>();
2967     RankedTensorType startIndicesType =
2968         startIndices.getType().dyn_cast<RankedTensorType>();
2969     // We could actually deal with an unranked result by inferring the result
2970     // rank, but the current reifyReturnTypes doesn't support unranked either.
2971     if (!resultType || !startIndicesType)
2972       return rewriter.notifyMatchFailure(gatherOp,
2973                                          "unranked start indices or result");
2974 
2975     int resultRank = resultType.getRank();
2976     // slice_sizes has to have the same size as operand.rank, and doing it this
2977     // way permits an unranked operand.
2978     int operandRank = gatherOp.slice_sizes().getNumElements();
2979 
2980     int64_t indexVectorDim = gatherOp.dimension_numbers().getIndexVectorDim();
2981 
2982     ArrayRef<int64_t> offsetDims = gatherOp.dimension_numbers().getOffsetDims();
2983     ArrayRef<int64_t> collapsedSliceDims =
2984         gatherOp.dimension_numbers().getCollapsedSliceDims();
2985     ArrayRef<int64_t> startIndexMap =
2986         gatherOp.dimension_numbers().getStartIndexMap();
2987 
2988     auto extractAsIndex = [&](Value input, ArrayRef<Value> index) -> Value {
2989       return rewriter.create<arith::IndexCastOp>(
2990           loc, rewriter.getIndexType(),
2991           rewriter.create<tensor::ExtractOp>(loc, input, index));
2992     };
2993 
2994     // We'll need these later and creating them on demand we end up with
2995     // duplicates, which also makes lit tests really hard to write.
2996     SmallVector<Value> constants;
2997     for (int i = 0; i < std::max({resultRank, operandRank, 2}); ++i) {
2998       constants.push_back(
2999           rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i)));
3000     }
3001 
3002     // Create ops to calculate the dynamic dimensions of the return shape, which
3003     // are needed for the init tensor.
3004     SmallVector<Value> dynDimSizes;
3005     if (!resultType.hasStaticShape()) {
3006       SmallVector<Value> returnShapes;
3007       if (failed(gatherOp.reifyReturnTypeShapes(rewriter, adaptor.getOperands(),
3008                                                 returnShapes)))
3009         return rewriter.notifyMatchFailure(gatherOp,
3010                                            "could not reify return shape");
3011       assert(returnShapes.size() == 1);
3012       Value returnShape = returnShapes[0];
3013 
3014       for (int i = 0; i < resultRank; ++i)
3015         if (resultType.isDynamicDim(i))
3016           dynDimSizes.push_back(extractAsIndex(returnShape, constants[i]));
3017     }
3018 
3019     Value initOp = rewriter.create<linalg::InitTensorOp>(
3020         loc, dynDimSizes, resultType.getShape(), resultType.getElementType());
3021 
3022     ValueRange ins;
3023     SmallVector<AffineMap, 1> indexingMaps(
3024         {rewriter.getMultiDimIdentityMap(resultRank)});
3025     auto linalgOp = rewriter.create<linalg::GenericOp>(
3026         loc, /*resultTensorTypes=*/resultType,
3027         /*inputs=*/ins,
3028         /*outputs=*/initOp, indexingMaps, getNParallelLoopsAttrs(resultRank),
3029         /*bodyBuild=*/nullptr, pruneAttributeList(gatherOp));
3030 
3031     // Now populate the linalg generic region
3032     auto* region = &linalgOp.region();
3033     auto* block = rewriter.createBlock(region, region->end());
3034     block->addArguments(resultType.getElementType(), loc);
3035     OpBuilder::InsertionGuard guard(rewriter);
3036     rewriter.setInsertionPointToEnd(block);
3037 
3038     // Dimensions in the result that aren't offset dimensions are called batch.
3039     SmallVector<int64_t> batchDims;
3040     for (int dim = 0; dim < resultRank; ++dim)
3041       if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim);
3042 
3043     // Same as with the constants. Creating these all up front is easier than
3044     // potentially getting duplicates later.
3045     SmallVector<Value> linalgIndices;
3046     for (int i = 0; i < resultRank; ++i)
3047       linalgIndices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
3048 
3049     // Now the complicated part. For a given output dimension we build up an
3050     // index into the input. It's composed of two parts: the index coming from
3051     // start_indices, and the offset from that index along the offset
3052     // dimensions. Everything includes dimension shuffling and remapping as well
3053     // because of the way gather is defined to allow for any-layout input by
3054     // adding more attributes.
3055 
3056     // The base gather index (`G` in the documentation) points to a place in
3057     // start_indices along the batch dimensions.
3058     SmallVector<Value> gatherIndex;
3059     for (auto dim : batchDims) gatherIndex.push_back(linalgIndices[dim]);
3060 
3061     SmallVector<Value> indexFromStartIndices;
3062     for (unsigned i = 0; i < startIndexMap.size(); ++i) {
3063       // The index along the index_vector dimension of start_indices varies.
3064       // Basically indexFromStartIndices indexes into a "row" along
3065       // index_vector_dim, where the row is selected by the current output
3066       // index.
3067       // But if index_vector_dim is equal to start_indices.rank, then
3068       // start_indices gets a trailing 1 dimension added. So the row we're
3069       // extracting always has length 1 and the index into it is always 0, so we
3070       // just use the gather index directly
3071       SmallVector<Value> gCombine(gatherIndex);
3072       if (indexVectorDim != startIndicesType.getRank()) {
3073         assert(indexVectorDim <= static_cast<int64_t>(gCombine.size()));
3074         gCombine.insert(gCombine.begin() + indexVectorDim, constants[i]);
3075       }
3076 
3077       indexFromStartIndices.push_back(extractAsIndex(startIndices, gCombine));
3078     }
3079 
3080     // But then start indices are shuffled by the start index map. To make a
3081     // full index into the operand, all missing indices are zeroes.
3082     SmallVector<Value> remappedIndexFromIndices(operandRank, constants[0]);
3083     for (auto& it : llvm::enumerate(startIndexMap))
3084       remappedIndexFromIndices[it.value()] = indexFromStartIndices[it.index()];
3085 
3086     // Now we construct the index based on the offset. First we need to remap
3087     // the offset dimensions by dropping the collapsed indices.
3088     SmallVector<unsigned> remappedOffsetDims;
3089     for (int i = 0; i < operandRank; ++i)
3090       if (!llvm::is_contained(collapsedSliceDims, i))
3091         remappedOffsetDims.push_back(i);
3092 
3093     assert(remappedOffsetDims.size() == offsetDims.size());
3094 
3095     // Clamp out of bounds indices.
3096     for (int i = 0, operandIndexDim = 0; i < operandRank; ++i) {
3097       // Compute the size of the output shape dimension corresponding to this
3098       // index dimension. If it's collapsed set it to 1.
3099       Value outputDimSize = constants[1];
3100       if (!llvm::is_contained(collapsedSliceDims, i)) {
3101         outputDimSize = rewriter.createOrFold<tensor::DimOp>(
3102             loc, initOp, offsetDims[operandIndexDim++]);
3103       }
3104 
3105       // If this is a skipped dimension, we're done and don't have to clamp.
3106       if (remappedIndexFromIndices[i] == constants[0]) continue;
3107 
3108       Value operandDimSize =
3109           rewriter.createOrFold<tensor::DimOp>(loc, operand, i);
3110       Value largestValidIndex = rewriter.createOrFold<arith::SubIOp>(
3111           loc, operandDimSize, outputDimSize);
3112 
3113       // Clamp indices to [0, i, operand_dim-output_dim].
3114       Value clamp = rewriter.create<arith::MinSIOp>(
3115           loc,
3116           rewriter.create<arith::MaxSIOp>(loc, constants[0],
3117                                           remappedIndexFromIndices[i]),
3118           largestValidIndex);
3119       remappedIndexFromIndices[i] = clamp;
3120     }
3121 
3122     // For the (remapped) offset dimensions, the index is the current index in
3123     // the output. As before this is expanded to a full index into the operand
3124     // by using zeroe for the missing indices.
3125     SmallVector<Value> indexFromOffset(operandRank, constants[0]);
3126     for (unsigned k = 0; k < offsetDims.size(); ++k)
3127       indexFromOffset[remappedOffsetDims[k]] = linalgIndices[offsetDims[k]];
3128 
3129     // Now we add together our two indices to get the final index into the
3130     // operand.
3131     SmallVector<Value> combinedIndex;
3132     for (int i = 0; i < operandRank; ++i)
3133       combinedIndex.push_back(rewriter.createOrFold<arith::AddIOp>(
3134           loc, rewriter.getIndexType(), remappedIndexFromIndices[i],
3135           indexFromOffset[i]));
3136 
3137     Value element =
3138         rewriter.create<tensor::ExtractOp>(loc, operand, combinedIndex);
3139     rewriter.create<linalg::YieldOp>(loc, element);
3140 
3141     rewriter.replaceOp(gatherOp, linalgOp.getResults());
3142 
3143     return success();
3144   }
3145 };
3146 
3147 class DotGeneralOpConversion : public OpConversionPattern<mhlo::DotGeneralOp> {
3148  public:
3149   using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;
matchAndRewrite(mhlo::DotGeneralOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const3150   LogicalResult matchAndRewrite(
3151       mhlo::DotGeneralOp op, OpAdaptor adaptor,
3152       ConversionPatternRewriter& rewriter) const final {
3153     if (!verifyHloOpBufferOrTensorSemantics(op)) {
3154       return failure();
3155     }
3156 
3157     // Get various dimension iterator information
3158     mhlo::DotDimensionNumbersAttr dimNumbers = op.dot_dimension_numbers();
3159     auto lhsBatchingDims = dimNumbers.getLhsBatchingDimensions();
3160     auto rhsBatchingDims = dimNumbers.getRhsBatchingDimensions();
3161     auto lhsContractingDims = dimNumbers.getLhsContractingDimensions();
3162     auto rhsContractingDims = dimNumbers.getRhsContractingDimensions();
3163 
3164     // Get shape information and initialize output
3165     assert(lhsContractingDims.size() == rhsContractingDims.size() &&
3166            "number of contracting dims must be equal");
3167     auto numContracting = lhsContractingDims.size();
3168     // Convert unsigned to signed. This works because signed and unsigned
3169     // integer matmul is the same operation in two's complement.
3170     auto outputType =
3171         typeConverter->convertType(op.getType()).cast<ShapedType>();
3172     auto targetRank = outputType.getRank();
3173     auto totalLoopCount = numContracting + targetRank;
3174 
3175     auto lhsRank = adaptor.lhs().getType().cast<ShapedType>().getRank();
3176     auto lhsExtraDims =
3177         lhsRank - lhsBatchingDims.size() - lhsContractingDims.size();
3178     auto rhsRank = adaptor.rhs().getType().cast<ShapedType>().getRank();
3179 
3180     Location loc = op.getLoc();
3181     auto initTensor =
3182         getInitTensorFor(rewriter, loc, outputType, op, adaptor.getOperands());
3183     Value zeroTensor = fillTensorWithZeros(rewriter, loc, initTensor);
3184     SmallVector<AffineMap, 3> indexingMaps;
3185 
3186     auto getMap = [&](int64_t rank, ArrayRef<int64_t> batchingDims,
3187                       ArrayRef<int64_t> contractingDims, size_t extraDims) {
3188       llvm::SmallVector<AffineExpr> indices(rank);
3189       for (const auto& i : llvm::enumerate(batchingDims)) {
3190         indices[i.value()] = rewriter.getAffineDimExpr(i.index());
3191       }
3192       for (const auto& i : llvm::enumerate(contractingDims)) {
3193         indices[i.value()] = rewriter.getAffineDimExpr(i.index() + targetRank);
3194       }
3195       for (int i = 0; i < rank; ++i) {
3196         if (!indices[i]) {
3197           indices[i] = rewriter.getAffineDimExpr(extraDims++);
3198         }
3199       }
3200       indexingMaps.push_back(AffineMap::get(/*dimCount=*/totalLoopCount,
3201                                             /*symbolCount=*/0, indices,
3202                                             op->getContext()));
3203     };
3204     getMap(lhsRank, lhsBatchingDims, lhsContractingDims,
3205            lhsBatchingDims.size());
3206     getMap(rhsRank, rhsBatchingDims, rhsContractingDims,
3207            rhsBatchingDims.size() + lhsExtraDims);
3208 
3209     {
3210       SmallVector<AffineExpr, 4> dimExprs;
3211       dimExprs.reserve(targetRank);
3212       for (unsigned i = 0; i < targetRank; ++i)
3213         dimExprs.push_back(rewriter.getAffineDimExpr(i));
3214       indexingMaps.push_back(AffineMap::get(/*dimCount=*/totalLoopCount,
3215                                             /*symbolCount=*/0, dimExprs,
3216                                             op.getContext()));
3217     }
3218 
3219     Operation* linalgOp = rewriter.create<linalg::GenericOp>(
3220         loc, /*resultTensorTypes=*/TypeRange{outputType},
3221         /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
3222         /*outputBuffers=*/ValueRange{zeroTensor}, indexingMaps,
3223         getParallelAndReductionIterators(
3224             /*nLoops=*/totalLoopCount,
3225             /*nReduction=*/numContracting),
3226         [](OpBuilder& b, Location loc, ValueRange) {
3227           ImplicitLocOpBuilder builder(loc, b);
3228           linalg::MatmulOp::regionBuilder(builder, *b.getInsertionBlock(), {});
3229         },
3230         pruneAttributeList(op));
3231 
3232     rewriter.replaceOp(op, linalgOp->getResults());
3233     return success();
3234   }
3235 };
3236 
3237 struct HloLegalizeToLinalgPass
3238     : public mhlo::HloLegalizeToLinalgPassBase<HloLegalizeToLinalgPass> {
getDependentDialectsmlir::mhlo::__anon071fe3ef0111::HloLegalizeToLinalgPass3239   void getDependentDialects(DialectRegistry& registry) const override {
3240     registry.insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
3241                     scf::SCFDialect, complex::ComplexDialect, math::MathDialect,
3242                     memref::MemRefDialect, shape::ShapeDialect>();
3243   }
3244 
runOnOperationmlir::mhlo::__anon071fe3ef0111::HloLegalizeToLinalgPass3245   void runOnOperation() override {
3246     MLIRContext& ctx = getContext();
3247     RewritePatternSet patterns(&ctx);
3248     ConversionTarget target(ctx);
3249     target.addLegalDialect<
3250         bufferization::BufferizationDialect, arith::ArithmeticDialect,
3251         complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect,
3252         tensor::TensorDialect, sparse_tensor::SparseTensorDialect,
3253         scf::SCFDialect, shape::ShapeDialect>();
3254 
3255     target.addLegalOp<UnrealizedConversionCastOp>();
3256 
3257     auto typeConverter = createHloToLinalgTypeConverter();
3258     auto func = getOperation();
3259     mhlo::populateHloToLinalgConversionPattern(&ctx, *typeConverter, &patterns);
3260     if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
3261       signalPassFailure();
3262     }
3263   }
3264 };
3265 
3266 }  // namespace
3267 
populateHloToLinalgConversionPattern(MLIRContext * context,TypeConverter & typeConverter,RewritePatternSet * patterns)3268 void populateHloToLinalgConversionPattern(MLIRContext* context,
3269                                           TypeConverter& typeConverter,
3270                                           RewritePatternSet* patterns) {
3271   // clang-format off
3272   patterns->add<
3273       BroadcastConverter<mhlo::BroadcastOp>, ConcatenateConverter,
3274       ConstConverterTensor, HloDynamicBroadcastInDimConverter,
3275       HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp>,
3276       EinsumToLinalgConverter,
3277       IotaConverter<mhlo::DynamicIotaOp>,
3278       MapOpConverter,
3279       PointwiseToLinalgConverter<mhlo::AbsOp>,
3280       PointwiseToLinalgConverter<mhlo::AddOp>,
3281       PointwiseToLinalgConverter<mhlo::AndOp>,
3282       PointwiseToLinalgConverter<mhlo::Atan2Op>,
3283       PointwiseToLinalgConverter<mhlo::BitcastConvertOp>,
3284       PointwiseToLinalgConverter<mhlo::CbrtOp>,
3285       PointwiseToLinalgConverter<mhlo::CeilOp>,
3286       PointwiseToLinalgConverter<mhlo::ClampOp>,
3287       PointwiseToLinalgConverter<mhlo::ClzOp>,
3288       PointwiseToLinalgConverter<mhlo::CompareOp>,
3289       PointwiseToLinalgConverter<mhlo::ComplexOp>,
3290       PointwiseToLinalgConverter<mhlo::ConvertOp>,
3291       PointwiseToLinalgConverter<mhlo::CopyOp>,
3292       PointwiseToLinalgConverter<mhlo::CosineOp>,
3293       PointwiseToLinalgConverter<mhlo::DivOp>,
3294       PointwiseToLinalgConverter<mhlo::ExpOp>,
3295       PointwiseToLinalgConverter<mhlo::Expm1Op>,
3296       PointwiseToLinalgConverter<mhlo::FloorOp>,
3297       PointwiseToLinalgConverter<mhlo::ImagOp>,
3298       PointwiseToLinalgConverter<mhlo::IsFiniteOp>,
3299       PointwiseToLinalgConverter<mhlo::LogOp>,
3300       PointwiseToLinalgConverter<mhlo::LogisticOp>,
3301       PointwiseToLinalgConverter<mhlo::Log1pOp>,
3302       PointwiseToLinalgConverter<mhlo::MaxOp>,
3303       PointwiseToLinalgConverter<mhlo::MinOp>,
3304       PointwiseToLinalgConverter<mhlo::MulOp>,
3305       PointwiseToLinalgConverter<mhlo::NegOp>,
3306       PointwiseToLinalgConverter<mhlo::NotOp>,
3307       PointwiseToLinalgConverter<mhlo::OrOp>,
3308       PointwiseToLinalgConverter<mhlo::PopulationCountOp>,
3309       PointwiseToLinalgConverter<mhlo::PowOp>,
3310       PointwiseToLinalgConverter<mhlo::RealOp>,
3311       PointwiseToLinalgConverter<mhlo::RemOp>,
3312       PointwiseToLinalgConverter<mhlo::RoundOp>,
3313       PointwiseToLinalgConverter<mhlo::RsqrtOp>,
3314       PointwiseToLinalgConverter<mhlo::SelectOp>,
3315       PointwiseToLinalgConverter<mhlo::ShiftLeftOp>,
3316       PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp>,
3317       PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp>,
3318       PointwiseToLinalgConverter<mhlo::SignOp>,
3319       PointwiseToLinalgConverter<mhlo::SineOp>,
3320       PointwiseToLinalgConverter<mhlo::SqrtOp>,
3321       PointwiseToLinalgConverter<mhlo::SubtractOp>,
3322       PointwiseToLinalgConverter<mhlo::TanhOp>,
3323       PointwiseToLinalgConverter<mhlo::XorOp>,
3324       PointwiseToLinalgConverter<mhlo::ReducePrecisionOp>,
3325       RealDynamicSliceConverter,
3326       ReshapeOpConverter,
3327       ReverseConverter,
3328       SliceConverter,
3329       DynamicSliceConverter,
3330       DynamicUpdateSliceConverter,
3331       TransposeConverter<mhlo::TransposeOp>,
3332       GatherConversion,
3333       PadOpConversion,
3334       PadOpNegativePaddingConversion,
3335       ReduceConversion,
3336       ReduceWindowOpOnTensorsGenericConversion,
3337       ReduceWindowOpConversion,
3338       RngUniformConversion,
3339       TorchIndexSelectOpConversion,
3340       ReduceRegionReturnOpConversion>(typeConverter, context);
3341   // Ensure specialized patterns are higher priority than their generic
3342   // versions.
3343   patterns->add<
3344       NormalConvolutionOpConversion,
3345       DepthwiseConvolutionOpConversion,
3346       DotOpConversion<DotOperationType::kMatrixMatrix, linalg::MatmulOp>,
3347       DotOpConversion<DotOperationType::kMatrixVector, linalg::MatvecOp>,
3348       DotOpConversion<DotOperationType::kVectorMatrix, linalg::VecmatOp>,
3349       DotOpConversion<DotOperationType::kVectorDot, linalg::DotOp>,
3350       DotGeneralBatchMatMulOpConversion>(typeConverter, context,
3351                                          PatternBenefit(2));
3352   patterns->add<
3353       ConvolutionOpGeneralConversion,
3354       DotGeneralOpConversion>(typeConverter, context, PatternBenefit(1));
3355   // clang-format on
3356 }
3357 
createLegalizeHloToLinalgPass()3358 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeHloToLinalgPass() {
3359   return std::make_unique<HloLegalizeToLinalgPass>();
3360 }
3361 
createHloToLinalgTypeConverter()3362 std::unique_ptr<TypeConverter> createHloToLinalgTypeConverter() {
3363   return std::make_unique<LinalgTypeConverter>();
3364 }
3365 
3366 }  // namespace mhlo
3367 }  // namespace mlir
3368