1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements utilities for lowering CHLO/HLO/LHLO dialect to Linalg
17 // dialect.
18
19 #include "mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h"
20
21 #include <algorithm>
22 #include <numeric>
23 #include <string>
24 #include <utility>
25
26 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
27 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
28 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
29 #include "mlir/Dialect/Tensor/IR/Tensor.h"
30
31 namespace mlir {
32 namespace mhlo {
33 namespace {
34
hasIntegralShapeType(Operation * op)35 bool hasIntegralShapeType(Operation* op) {
36 auto stp = op->getOperand(0).getType().dyn_cast<ShapedType>();
37 return stp && stp.getElementType().isIntOrIndex();
38 }
39
getInitSparseTensor(OpBuilder & b,Location loc,ShapedType type,ArrayRef<Value> dynSizes)40 Value getInitSparseTensor(OpBuilder& b, Location loc, ShapedType type,
41 ArrayRef<Value> dynSizes) {
42 return b.create<bufferization::AllocTensorOp>(loc, type, dynSizes,
43 /*copy=*/Value(),
44 /*memory_space=*/IntegerAttr());
45 }
46
47 } // namespace
48
getParallelAndReductionIterators(unsigned nLoops,unsigned nReduction)49 SmallVector<StringRef, 3> getParallelAndReductionIterators(
50 unsigned nLoops, unsigned nReduction) {
51 SmallVector<StringRef, 3> res(nLoops - nReduction,
52 getParallelIteratorTypeName());
53 res.append(nReduction, getReductionIteratorTypeName());
54 return res;
55 }
56
// Returns the iterator-type list for `nParallelLoops` loops that are all
// parallel (i.e. zero reduction dimensions).
SmallVector<StringRef, 3> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return getParallelAndReductionIterators(nParallelLoops, 0);
}
60
getInitTensor(OpBuilder & b,Location loc,ShapedType type,ArrayRef<Value> dynSizes)61 Value getInitTensor(OpBuilder& b, Location loc, ShapedType type,
62 ArrayRef<Value> dynSizes) {
63 return b.create<linalg::InitTensorOp>(loc, dynSizes, type.getShape(),
64 type.getElementType());
65 }
66
getInitTensorFor(OpBuilder & b,Location loc,ShapedType resultType,Operation * op,ValueRange operands)67 Value getInitTensorFor(OpBuilder& b, Location loc, ShapedType resultType,
68 Operation* op, ValueRange operands) {
69 bool isSparse = sparse_tensor::getSparseTensorEncoding(resultType) != nullptr;
70 // Collect the sizes for a ranked tensor to be passed as parameter to a
71 // new tensor initialization operation. This operation only needs the
72 // dynamic sizes.
73 SmallVector<Value> sizes;
74 if (resultType.hasRank() && !resultType.hasStaticShape()) {
75 // Ask the op for its output shape.
76 auto shapeSource = cast<InferShapedTypeOpInterface>(op);
77 SmallVector<Value, 1> reifiedShapes;
78 (void)shapeSource.reifyReturnTypeShapes(b, operands, reifiedShapes);
79 assert(reifiedShapes.size() == 1 && "Expected one reified result");
80 // Construct sizes for the required dimensions.
81 for (auto& en : llvm::enumerate(resultType.getShape())) {
82 if (en.value() != ShapedType::kDynamicSize) continue;
83 sizes.push_back(b.create<tensor::ExtractOp>(
84 loc, reifiedShapes[0],
85 ValueRange{b.create<arith::ConstantIndexOp>(loc, en.index())}));
86 }
87 }
88 return isSparse ? getInitSparseTensor(b, loc, resultType, sizes)
89 : getInitTensor(b, loc, resultType, sizes);
90 }
91
preSparsify(Operation * op,llvm::SmallVector<Value,2> & values,Type rtp,OpBuilder * b)92 Value preSparsify(Operation* op, llvm::SmallVector<Value, 2>& values, Type rtp,
93 OpBuilder* b) {
94 // Apply for semi-ring operations that lower to elaborate code
95 // (any sign-op, any elt-wise conversion, or an integral abs-op).
96 if (isa<mhlo::SignOp>(op) || isa<mhlo::ConvertOp>(op) ||
97 (isa<mhlo::AbsOp>(op) && hasIntegralShapeType(op)) ||
98 isa<chlo::AsinOp>(op) || isa<chlo::AsinhOp>(op) ||
99 isa<chlo::AtanOp>(op) || isa<chlo::AtanhOp>(op) ||
100 isa<chlo::BesselI1eOp>(op) || isa<chlo::SinhOp>(op) ||
101 isa<chlo::TanOp>(op)) {
102 if (!sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType()) &&
103 !sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType()))
104 return Value();
105 Location loc = op->getLoc();
106 auto semiring = b->create<sparse_tensor::UnaryOp>(loc, rtp, values[0]);
107 Type itp = values[0].getType();
108 Block* present = b->createBlock(&semiring.getPresentRegion(), {}, itp, loc);
109 b->setInsertionPointToStart(&semiring.getPresentRegion().front());
110 values[0] = present->getArgument(0);
111 return semiring;
112 }
113 return Value();
114 }
115
postSparsify(Operation * op,Value semiring,Value result,OpBuilder * b)116 Value postSparsify(Operation* op, Value semiring, Value result, OpBuilder* b) {
117 if (semiring) {
118 b->create<sparse_tensor::YieldOp>(op->getLoc(), result);
119 b->setInsertionPointAfter(semiring.getDefiningOp());
120 return semiring;
121 }
122 return result;
123 }
124
125 } // namespace mhlo
126
127 } // namespace mlir
128