/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering the LHLO dialect to the GPU dialect.

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/lhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace lmhlo {
namespace {

// A simple translation of LHLO reduce operations to a corresponding
// gpu.launch operation. The transformation performs no tiling and supports
// only 1d results.
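//
// For orientation, a rough sketch of the rewrite (operand names and exact
// syntax are illustrative, not verbatim IR): a reduction like
//
//   "lmhlo.reduce"(%in, %init, %out) ({
//   ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
//     "lmhlo.add"(%lhs, %rhs, %res)
//         : (memref<f32>, memref<f32>, memref<f32>) -> ()
//     "lmhlo.terminator"() : () -> ()
//   }) {dimensions = dense<0> : tensor<1xi64>}
//      : (memref<MxNxf32>, memref<f32>, memref<Nxf32>) -> ()
//
// becomes a single gpu.launch with N threads along x. Each thread stores the
// init value to out[tid], then runs a sequential scf.for over the reduced
// dimension, applying the reduction body to rank-0 memref.subview slices.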
class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ReduceOp reduceOp, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = reduceOp.getLoc();
    // Only support 1d reductions for now.
    int64_t size = 0;
    for (auto result : reduceOp.getOut()) {
      auto shapedType = result.getType().dyn_cast<ShapedType>();
      if (!shapedType || shapedType.getRank() != 1) {
        return failure();
      }
      auto dimSize = shapedType.getDimSize(0);
      if (size && size != dimSize) {
        return failure();
      }
      size = dimSize;
    }

    auto reducingDimension = *reduceOp.getDimensions().value_begin<APInt>();

    // Require all inputs to have a static shape; remember the size of the
    // dimension being reduced.
    int64_t reduceDimSize = 0;
    for (auto input : reduceOp.getInputs()) {
      auto shapedType = input.getType().dyn_cast<ShapedType>();
      if (!shapedType || !shapedType.hasStaticShape()) {
        return failure();
      }
      reduceDimSize = shapedType.getDimSize(reducingDimension.getSExtValue());
    }

    // Create a launch that is parallel in the result dimension.
    auto blockSizeX = rewriter.create<mlir::arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIntegerAttr(rewriter.getIndexType(), size));
    auto one = rewriter.create<mlir::arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
    auto launchOp = rewriter.create<mlir::gpu::LaunchOp>(loc, one, one, one,
                                                         blockSizeX, one, one);
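    // The launch created above uses a 1x1x1 grid with a `size`x1x1 block,
    // i.e. roughly (sketch, not verbatim IR):
    //   gpu.launch blocks(...) in (1, 1, 1) threads(...) in (size, 1, 1)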
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToEnd(&launchOp.body().front());
      auto index = launchOp.getThreadIds().x;

      // Load the initial value and store it to the output.
      for (auto pair : llvm::zip(reduceOp.getInitValues(), reduceOp.getOut())) {
        auto initValue =
            rewriter.create<mlir::memref::LoadOp>(loc, std::get<0>(pair));
        rewriter.create<mlir::memref::StoreOp>(
            loc, initValue, std::get<1>(pair), ArrayRef<Value>{index});
      }
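      // Per init/output pair, the generated IR is roughly (names
      // illustrative):
      //   %init = memref.load %init_buf[] : memref<f32>
      //   memref.store %init, %out[%tid] : memref<Nxf32>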

      // Insert a loop into the body to compute the reduction. The loop ranges
      // over [0, dim).
      auto zero = rewriter.create<mlir::arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
      // TODO(b/137624192) Use dimOp to make it shape independent.
      auto upper = rewriter.create<mlir::arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), reduceDimSize));
      auto step = rewriter.create<mlir::arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
      auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);

      rewriter.setInsertionPointToStart(loop.getBody());
      // Compute rank-0 memref views of the values to reduce. This makes it
      // easy to inline the reduction body by remapping its block arguments.
      auto output = *reduceOp.getOut().begin();
      auto resType = MemRefType::get(
          llvm::None, getElementTypeOrSelf(output.getType()),
          makeStridedLinearLayoutMap(llvm::None,
                                     MemRefType::getDynamicStrideOrOffset(),
                                     rewriter.getContext()));
      OpFoldResult offset = launchOp.getThreadIds().x;
      auto oneAttr = rewriter.getI64IntegerAttr(1);
      OpFoldResult size = oneAttr;
      OpFoldResult stride = oneAttr;
      auto accumulator = rewriter.create<memref::SubViewOp>(
          loc, resType, output, offset, size, stride);
      Value input = reduceOp.getInputs().front();
      auto inputTypeRank = input.getType().cast<MemRefType>().getRank();

      // Index the reduced dimension with the loop induction variable and all
      // other dimensions with the thread id.
      SmallVector<OpFoldResult> offsets = llvm::to_vector<4>(llvm::map_range(
          llvm::seq<int>(0, inputTypeRank), [&](int dim) -> OpFoldResult {
            return dim == reducingDimension ? loop.getInductionVar()
                                            : launchOp.getThreadIds().x;
          }));
      SmallVector<OpFoldResult> sizes(inputTypeRank, oneAttr);
      SmallVector<OpFoldResult> strides(inputTypeRank, oneAttr);
      auto rhs = rewriter.create<memref::SubViewOp>(
          loc, accumulator.getType(), input, offsets, sizes, strides);
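      // Schematically (illustrative names, abbreviated layout maps), the two
      // rank-0 views created above are:
      //   %acc = memref.subview %out[%tid] [1] [1]
      //       : memref<Nxf32> to memref<f32, #map_with_dynamic_offset>
      //   %rhs = memref.subview %in[%iv, %tid] [1, 1] [1, 1]
      //       : memref<MxNxf32> to memref<f32, #map_with_dynamic_offset>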

      // Now copy over the actual body of the reduction, leaving out the
      // terminator.
      BlockAndValueMapping mapping;
      mapping.map(reduceOp.getBody().getArgument(0), accumulator);
      mapping.map(reduceOp.getBody().getArgument(1), rhs);
      mapping.map(reduceOp.getBody().getArgument(2), accumulator);
      for (auto& nested : reduceOp.getBody().front().without_terminator()) {
        auto* clone = rewriter.clone(nested, mapping);
        for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) {
          mapping.map(std::get<0>(pair), std::get<1>(pair));
        }
      }
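      // With an add reduction, for example, the cloned body amounts to
      // (sketch): "lmhlo.add"(%acc, %rhs, %acc), i.e. the nested ops now read
      // from and write to the subviews instead of the original block
      // arguments.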

      // Finally, insert the terminator for the launchOp.
      rewriter.setInsertionPointToEnd(&launchOp.body().front());
      rewriter.create<mlir::gpu::TerminatorOp>(loc);
    }

    rewriter.eraseOp(reduceOp);
    return success();
  }
};

struct LhloLegalizeToGpuPass
    : public LhloLegalizeToGpuPassBase<LhloLegalizeToGpuPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
                    memref::MemRefDialect, scf::SCFDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    ConversionTarget target(getContext());
    target.addLegalDialect<arith::ArithmeticDialect, linalg::LinalgDialect,
                           memref::MemRefDialect, func::FuncDialect,
                           gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
    target.addIllegalOp<ReduceOp>();
    auto func = getOperation();
    patterns.add<LhloReduceToGPULaunchConverter>(func.getContext());
    if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeToGpuPass() {
  return std::make_unique<LhloLegalizeToGpuPass>();
}
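
// Usage sketch (the exact command-line flag comes from the pass's tablegen
// definition and may differ): with mlir-hlo-opt this pass is typically
// exposed as something like `-lhlo-legalize-to-gpu`, or added
// programmatically via
//   pm.addNestedPass<func::FuncOp>(lmhlo::createLegalizeToGpuPass());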

}  // namespace lmhlo
}  // namespace mlir