/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering LHLO dialect to GPU dialect.

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/lhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace lmhlo {
namespace {

// A simple translation of LHLO reduce operations to a corresponding gpu
// launch operation. The transformation does no tiling and only supports 1d
// results.
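// For illustration (IR shown schematically, not exact syntax): a reduction of
// a statically shaped 2d input into a 1d output,
//   lmhlo.reduce(%input, %init, %out) ( <combiner region> ) {dimensions = [0]}
// is rewritten into a single gpu.launch with one block and one thread per
// output element; each thread writes the init value to its output element and
// then folds the reduced dimension sequentially with an scf.for loop.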
class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ReduceOp reduceOp, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = reduceOp.getLoc();
    // Only support 1d reductions for now.
    int64_t size = 0;
    for (auto result : reduceOp.getOut()) {
      auto shapedType = result.getType().dyn_cast<ShapedType>();
      if (!shapedType || shapedType.getRank() != 1) {
        return failure();
      }
      auto dimSize = shapedType.getDimSize(0);
      if (size && size != dimSize) {
        return failure();
      }
      size = dimSize;
    }
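    // `size` is now the common length of all 1d outputs; it becomes the
    // number of threads launched below (one thread per output element).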

    auto reducingDimension = *reduceOp.getDimensions().value_begin<APInt>();

    // Require all inputs to have the same shape.
    int64_t reduceDimSize = 0;
    for (auto input : reduceOp.getInputs()) {
      auto shapedType = input.getType().dyn_cast<ShapedType>();
      if (!shapedType || !shapedType.hasStaticShape()) {
        return failure();
      }
      reduceDimSize = shapedType.getDimSize(reducingDimension.getSExtValue());
    }
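    // `reduceDimSize` is the static extent of the reduced dimension (taken
    // from the last input); it bounds the sequential reduction loop below.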

    // Create a launch that is parallel in the result dimension.
    auto blockSizeX = rewriter.create<mlir::arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIntegerAttr(rewriter.getIndexType(), size));
    auto one = rewriter.create<mlir::arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
    auto launchOp = rewriter.create<mlir::gpu::LaunchOp>(loc, one, one, one,
                                                         blockSizeX, one, one);
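    // Populate the launch body; everything created in this scope runs once
    // per thread. The insertion guard restores the rewriter's insertion point
    // when the scope ends.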
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToEnd(&launchOp.body().front());
      auto index = launchOp.getThreadIds().x;
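      // The x thread id doubles as the index of the output element this
      // thread computes.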

      // Load the initial value and store it to the output.
      for (auto pair :
           llvm::zip(reduceOp.getInitValues(), reduceOp.getOut())) {
        auto initValue =
            rewriter.create<mlir::memref::LoadOp>(loc, std::get<0>(pair));
        rewriter.create<mlir::memref::StoreOp>(
            loc, initValue, std::get<1>(pair), ArrayRef<Value>{index});
      }

      // Insert a loop into the body to compute the reduction. The loop ranges
      // over [0, dim).
      auto zero = rewriter.create<mlir::arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
      // TODO(b/137624192) Use dimOp to make it shape independent.
      auto upper = rewriter.create<mlir::arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), reduceDimSize));
      auto step = rewriter.create<mlir::arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
      auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);

      rewriter.setInsertionPointToStart(loop.getBody());
      // Compute memrefs for the value to reduce. This makes it easier to just
      // inline the body.
      auto output = *reduceOp.getOut().begin();
      auto resType = MemRefType::get(
          llvm::None, getElementTypeOrSelf(output.getType()),
          makeStridedLinearLayoutMap(llvm::None,
                                     MemRefType::getDynamicStrideOrOffset(),
                                     rewriter.getContext()));
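      // `resType` is a rank-0 memref type with a dynamic offset; it is the
      // type of the single-element subviews created below.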
      OpFoldResult offset = launchOp.getThreadIds().x;
      auto oneAttr = rewriter.getI64IntegerAttr(1);
      OpFoldResult size = oneAttr;
      OpFoldResult stride = oneAttr;
      auto accumulator = rewriter.create<memref::SubViewOp>(
          loc, resType, output, offset, size, stride);
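      // `accumulator` views the single output element owned by this thread;
      // the cloned reduction body reads it and writes the combined value back
      // in place.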
      llvm::SmallVector<Value, 4> indexings;
      Value inputBuffer = reduceOp.getInputs().front();
      auto inputTypeRank = inputBuffer.getType().cast<MemRefType>().getRank();

      Value input = *reduceOp.operand_begin();
      SmallVector<OpFoldResult> offsets = llvm::to_vector<4>(llvm::map_range(
          llvm::seq<int>(0, inputTypeRank), [&](int dim) -> OpFoldResult {
            return dim == reducingDimension ? loop.getInductionVar()
                                            : launchOp.getThreadIds().x;
          }));
      SmallVector<OpFoldResult> sizes(inputTypeRank, oneAttr);
      SmallVector<OpFoldResult> strides(inputTypeRank, oneAttr);
      auto rhs = rewriter.create<memref::SubViewOp>(
          loc, accumulator.getType(), input, offsets, sizes, strides);
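      // `rhs` views the input element for the current iteration: the reduced
      // dimension is indexed by the loop induction variable, all other
      // dimensions by the thread id.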

      // Now copy over the actual body of the reduction, leaving out the
      // terminator.
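      // The body's three block arguments are mapped as (lhs, rhs, out); lhs
      // and out both alias the accumulator subview, so each iteration folds
      // the current input element into the output in place.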
      BlockAndValueMapping mapping;
      mapping.map(reduceOp.getBody().getArgument(0), accumulator);
      mapping.map(reduceOp.getBody().getArgument(1), rhs);
      mapping.map(reduceOp.getBody().getArgument(2), accumulator);
      for (auto& nested : reduceOp.getBody().front().without_terminator()) {
        auto* clone = rewriter.clone(nested, mapping);
        for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) {
          mapping.map(std::get<0>(pair), std::get<1>(pair));
        }
      }

      // Finally, insert the terminator for the launchOp.
      rewriter.setInsertionPointToEnd(&launchOp.body().front());
      rewriter.create<mlir::gpu::TerminatorOp>(loc);
    }

    rewriter.eraseOp(reduceOp);
    return success();
  }
};

struct LhloLegalizeToGpuPass
    : public LhloLegalizeToGpuPassBase<LhloLegalizeToGpuPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
                    memref::MemRefDialect, scf::SCFDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    ConversionTarget target(getContext());
    target.addLegalDialect<arith::ArithmeticDialect, linalg::LinalgDialect,
                           memref::MemRefDialect, func::FuncDialect,
                           gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
    target.addIllegalOp<ReduceOp>();
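    // Only ReduceOp is marked illegal; the partial conversion leaves every
    // other (still legal) lmhlo op untouched.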
    auto func = getOperation();
    patterns.add<LhloReduceToGPULaunchConverter>(func.getContext());
    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

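// Returns a fresh instance of the pass; it runs on individual func::FuncOp
// ops, as reflected in the OperationPass<func::FuncOp> return type.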
std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeToGpuPass() {
  return std::make_unique<LhloLegalizeToGpuPass>();
}

}  // namespace lmhlo
}  // namespace mlir