1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file implements utilities for lowering the CHLO/HLO/LHLO dialects to the
// Linalg dialect.
18 
19 #include "mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h"
20 
21 #include <algorithm>
22 #include <numeric>
23 #include <string>
24 #include <utility>
25 
26 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
27 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
28 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
29 #include "mlir/Dialect/Tensor/IR/Tensor.h"
30 
31 namespace mlir {
32 namespace mhlo {
33 namespace {
34 
hasIntegralShapeType(Operation * op)35 bool hasIntegralShapeType(Operation* op) {
36   auto stp = op->getOperand(0).getType().dyn_cast<ShapedType>();
37   return stp && stp.getElementType().isIntOrIndex();
38 }
39 
getInitSparseTensor(OpBuilder & b,Location loc,ShapedType type,ArrayRef<Value> dynSizes)40 Value getInitSparseTensor(OpBuilder& b, Location loc, ShapedType type,
41                           ArrayRef<Value> dynSizes) {
42   return b.create<bufferization::AllocTensorOp>(loc, type, dynSizes,
43                                                 /*copy=*/Value(),
44                                                 /*memory_space=*/IntegerAttr());
45 }
46 
47 }  // namespace
48 
getParallelAndReductionIterators(unsigned nLoops,unsigned nReduction)49 SmallVector<StringRef, 3> getParallelAndReductionIterators(
50     unsigned nLoops, unsigned nReduction) {
51   SmallVector<StringRef, 3> res(nLoops - nReduction,
52                                 getParallelIteratorTypeName());
53   res.append(nReduction, getReductionIteratorTypeName());
54   return res;
55 }
56 
getNParallelLoopsAttrs(unsigned nParallelLoops)57 SmallVector<StringRef, 3> getNParallelLoopsAttrs(unsigned nParallelLoops) {
58   return getParallelAndReductionIterators(nParallelLoops, 0);
59 }
60 
getInitTensor(OpBuilder & b,Location loc,ShapedType type,ArrayRef<Value> dynSizes)61 Value getInitTensor(OpBuilder& b, Location loc, ShapedType type,
62                     ArrayRef<Value> dynSizes) {
63   return b.create<linalg::InitTensorOp>(loc, dynSizes, type.getShape(),
64                                         type.getElementType());
65 }
66 
getInitTensorFor(OpBuilder & b,Location loc,ShapedType resultType,Operation * op,ValueRange operands)67 Value getInitTensorFor(OpBuilder& b, Location loc, ShapedType resultType,
68                        Operation* op, ValueRange operands) {
69   bool isSparse = sparse_tensor::getSparseTensorEncoding(resultType) != nullptr;
70   // Collect the sizes for a ranked tensor to be passed as parameter to a
71   // new tensor initialization operation. This operation only needs the
72   // dynamic sizes.
73   SmallVector<Value> sizes;
74   if (resultType.hasRank() && !resultType.hasStaticShape()) {
75     // Ask the op for its output shape.
76     auto shapeSource = cast<InferShapedTypeOpInterface>(op);
77     SmallVector<Value, 1> reifiedShapes;
78     (void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes);
79     assert(reifiedShapes.size() == 1 && "Expected one reified result");
80     // Construct sizes for the required dimensions.
81     for (auto& en : llvm::enumerate(resultType.getShape())) {
82       if (en.value() != ShapedType::kDynamicSize) continue;
83       sizes.push_back(b.create<tensor::ExtractOp>(
84           loc, reifiedShapes[0],
85           ValueRange{b.create<arith::ConstantIndexOp>(loc, en.index())}));
86     }
87   }
88   return isSparse ? getInitSparseTensor(b, loc, resultType, sizes)
89                   : getInitTensor(b, loc, resultType, sizes);
90 }
91 
preSparsify(Operation * op,llvm::SmallVector<Value,2> & values,Type rtp,OpBuilder * b)92 Value preSparsify(Operation* op, llvm::SmallVector<Value, 2>& values, Type rtp,
93                   OpBuilder* b) {
94   // Apply for semi-ring operations that lower to elaborate code
95   // (any sign-op, any elt-wise conversion, or an integral abs-op).
96   if (isa<mhlo::SignOp>(op) || isa<mhlo::ConvertOp>(op) ||
97       (isa<mhlo::AbsOp>(op) && hasIntegralShapeType(op)) ||
98       isa<chlo::AsinOp>(op) || isa<chlo::AsinhOp>(op) ||
99       isa<chlo::AtanOp>(op) || isa<chlo::AtanhOp>(op) ||
100       isa<chlo::BesselI1eOp>(op) || isa<chlo::SinhOp>(op) ||
101       isa<chlo::TanOp>(op)) {
102     if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) &&
103         !sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType()))
104       return Value();
105     Location loc = op->getLoc();
106     auto semiring = b->create<sparse_tensor::UnaryOp>(loc, rtp, values[0]);
107     Type itp = values[0].getType();
108     Block* present = b->createBlock(&semiring.getPresentRegion(), {}, itp, loc);
109     b->setInsertionPointToStart(&semiring.getPresentRegion().front());
110     values[0] = present->getArgument(0);
111     return semiring;
112   }
113   return Value();
114 }
115 
postSparsify(Operation * op,Value semiring,Value result,OpBuilder * b)116 Value postSparsify(Operation* op, Value semiring, Value result, OpBuilder* b) {
117   if (semiring) {
118     b->create<sparse_tensor::YieldOp>(op->getLoc(), result);
119     b->setInsertionPointAfter(semiring.getDefiningOp());
120     return semiring;
121   }
122   return result;
123 }
124 
125 }  // namespace mhlo
126 
127 }  // namespace mlir
128