/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Transforms/PassDetail.h"
#include "mlir-hlo/Transforms/passes.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {

namespace {
// Replaces flattened memref arguments (base, aligned, offset, sizes, strides)
// with the base pointer and constants if the corresponding launch_func op's
// argument has a static shape. Removes all arguments but the base pointer.
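//
// For example (a sketch assuming the default MemRef-to-LLVM descriptor
// lowering, which expands a ranked memref to base, aligned, offset,
// sizes[rank], strides[rank]), a kernel argument of type memref<2x3xf32>
// arrives flattened as
//   (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i64, i64, i64, i64)
// and, when every launch site passes a statically shaped memref, is rewritten
// to a single !llvm.ptr<f32>, with offset=0, sizes=[2, 3], and strides=[3, 1]
// materialized as constants in the kernel body.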
class PropagateStaticShapesPattern : public OpRewritePattern<LLVM::LLVMFuncOp> {
 public:
  explicit PropagateStaticShapesPattern(MLIRContext* ctx,
                                        SymbolTable& symbolTable,
                                        Type pointerType)
      : OpRewritePattern<LLVM::LLVMFuncOp>(ctx),
        symbolTable(symbolTable),
        pointerType(pointerType) {}

 private:
  LogicalResult matchAndRewrite(LLVM::LLVMFuncOp funcOp,
                                PatternRewriter& rewriter) const final;

  SymbolTable& symbolTable;
  Type pointerType;
};

class PropagateStaticShapesToKernelPass
    : public PropagateStaticShapesToKernelPassBase<
          PropagateStaticShapesToKernelPass> {
 public:
  explicit PropagateStaticShapesToKernelPass(Type pointerType)
      : pointerType(pointerType) {}

 private:
  void runOnOperation() override;

  Type pointerType;
};

}  // namespace

// Replaces 'arguments' (containing 'base', 'align', 'offset', 'sizes[rank]',
// 'strides[rank]') corresponding to a statically shaped 'memref' with the base
// pointer and constants. The base pointer is changed to 'pointerType' if
// provided.
static void replaceStaticMemRefArguments(ArrayRef<BlockArgument> arguments,
                                         MemRefType memref, Type pointerType,
                                         PatternRewriter& rewriter) {
  assert(arguments.size() >= 3 && "expected at least 3 arguments");
  Value base = arguments[0];
  if (pointerType) {
    // Change 'base' to the given type, then bitcast back to the original type
    // so existing uses keep seeing the type they expect.
    Type type = base.getType();
    base.setType(pointerType);
    auto cast = rewriter.create<LLVM::BitcastOp>(base.getLoc(), type, base);
    base.replaceAllUsesExcept(/*newValue=*/cast, /*exceptedUser=*/cast);
    base = cast.getResult();
  }

  // Replace uses of 'aligned' with 'base'.
  arguments[1].replaceAllUsesWith(base);
  // Replace uses of 'offset' with the constant 0.
  arguments[2].replaceAllUsesWith(rewriter.create<LLVM::ConstantOp>(
      arguments[2].getLoc(), arguments[2].getType(),
      rewriter.getIntegerAttr(arguments[2].getType(), 0)));
  auto replace = [&](ArrayRef<int64_t> values,
                     ArrayRef<BlockArgument> arguments) {
    for (auto valAndArg : llvm::zip_first(values, arguments)) {
      auto argument = std::get<1>(valAndArg);
      argument.replaceAllUsesWith(rewriter.create<LLVM::ConstantOp>(
          argument.getLoc(), argument.getType(),
          rewriter.getIntegerAttr(argument.getType(), std::get<0>(valAndArg))));
    }
  };
  // Replace 'sizes' and 'strides' with constants.
  replace(memref.getShape(), arguments.drop_front(3));
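  // Compute row-major strides as suffix products of the shape: e.g. for shape
  // [2, 3, 4] the running products over the reversed shape are [4, 12, 24],
  // written back in reverse order as [24, 12, 4]; appending 1 and dropping the
  // leading element yields the strides [12, 4, 1].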
  auto strides = llvm::to_vector<4>(memref.getShape());
  std::partial_sum(strides.rbegin(), strides.rend(), strides.rbegin(),
                   std::multiplies<int64_t>());
  strides.push_back(1);
  replace(llvm::makeArrayRef(strides).drop_front(),
          arguments.drop_front(3 + memref.getRank()));
}

LogicalResult PropagateStaticShapesPattern::matchAndRewrite(
    LLVM::LLVMFuncOp funcOp, PatternRewriter& rewriter) const {
  if (funcOp.isExternal())
    return rewriter.notifyMatchFailure(funcOp, "external");
  if (!funcOp->getAttrOfType<UnitAttr>(
          gpu::GPUDialect::getKernelFuncAttrName())) {
    return rewriter.notifyMatchFailure(funcOp, "missing gpu.kernel");
  }

  // Collect the gpu.launch_func ops which launch the funcOp kernel.
  Optional<SymbolTable::UseRange> symUses =
      symbolTable.getSymbolUses(funcOp, symbolTable.getOp());
  if (!symUses)
    return rewriter.notifyMatchFailure(funcOp, "failed to find symbol uses");
  auto mapper = [](SymbolTable::SymbolUse symUse) {
    return dyn_cast<gpu::LaunchFuncOp>(symUse.getUser());
  };
  auto filter = [](gpu::LaunchFuncOp op) -> bool { return op; };
  auto launchOps = llvm::to_vector(
      llvm::make_filter_range(llvm::map_range(*symUses, mapper), filter));
  if (launchOps.empty())
    return rewriter.notifyMatchFailure(funcOp, "no gpu.launch_func uses");
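  // Require identical operand types at every launch site: the kernel is
  // rewritten once, so a single specialization must be valid for all callers.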
  OperandRange operands = launchOps.begin()->operands();
  if (llvm::any_of(launchOps, [&](gpu::LaunchFuncOp op) {
        return op.operands().getTypes() != operands.getTypes();
      })) {
    return rewriter.notifyMatchFailure(funcOp, "operand types mismatch");
  }

  rewriter.setInsertionPointToStart(&funcOp.front());
  BitVector argsToDrop(funcOp.getNumArguments());
  // Loop over the launch ops' 'operands' (scalars and memrefs) and the
  // funcOp's 'arguments' (scalars and flattened memrefs). When an operand is
  // a statically shaped memref, replace the range of arguments corresponding
  // to the flattened memref with just the 'base' pointer.
  for (auto arguments = funcOp.getArguments(); !arguments.empty();
       operands = operands.drop_front()) {
    auto memref = operands.getTypes().front().dyn_cast<MemRefType>();
    if (!memref) {
      // Scalar argument, advance by one.
      arguments = arguments.drop_front();
      continue;
    }
    if (!memref.hasRank()) break;  // Bail out if unranked.
    // The memref is flattened to base, align, offset, sizes, and strides, so
    // it spans 3 + 2 * rank arguments.
    int64_t numArgs = 3 + memref.getRank() * 2;
    auto isPtr = [](BlockArgument arg) {
      return arg.getType().isa<LLVM::LLVMPointerType>();
    };
    auto isInt = [](BlockArgument arg) {
      return arg.getType().isa<IntegerType>();
    };
    // Bail out if the next numArgs arguments are not of the expected types.
    if (static_cast<int64_t>(arguments.size()) < numArgs) break;
    ArrayRef<BlockArgument> memrefArgs = arguments.take_front(numArgs);
    if (!llvm::all_of(memrefArgs.take_front(2), isPtr)) break;
    if (!llvm::all_of(memrefArgs.drop_front(2), isInt)) break;
    // Replace memrefArgs with just the base pointer if the memref has a
    // static shape and an identity layout.
    if (memref.hasStaticShape() && memref.getLayout().isIdentity()) {
      replaceStaticMemRefArguments(memrefArgs, memref, pointerType, rewriter);
      unsigned argNumber = arguments.front().getArgNumber();
      // Drop all but 'base' from the flattened memref arguments.
      argsToDrop.set(argNumber + 1, argNumber + numArgs);
    }
    arguments = arguments.drop_front(numArgs);
  }
  if (argsToDrop.none()) {
    return rewriter.notifyMatchFailure(funcOp, "no static shapes");
  }
  rewriter.updateRootInPlace(funcOp, [&] {
    funcOp.eraseArguments(argsToDrop);
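    // Rebuild the function type from the remaining block arguments so the
    // signature matches the pruned entry block.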
    auto argTypes =
        llvm::to_vector(TypeRange{ValueRange{funcOp.getArguments()}});
    funcOp.setType(LLVM::LLVMFunctionType::get(
        funcOp.getFunctionType().getReturnType(), argTypes));
  });
  return success();
}

void PropagateStaticShapesToKernelPass::runOnOperation() {
  MLIRContext* ctx = getOperation().getContext();
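  // The textual pass option, when set, takes precedence over the pointer type
  // supplied to the constructor.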
  auto pointerType = [&]() -> FailureOr<Type> {
    if (ptr_type_opt.empty()) return this->pointerType;
    Type type = parseType(ptr_type_opt, ctx);
    if (!type)
      return emitError(UnknownLoc::get(ctx), "invalid convert_pointer_args");
    return type;
  }();
  if (failed(pointerType)) return signalPassFailure();
  SymbolTable symbolTable(getOperation());
  RewritePatternSet patterns(ctx);
  patterns.add<PropagateStaticShapesPattern>(ctx, symbolTable, *pointerType);
  FrozenRewritePatternSet frozen(std::move(patterns));
  auto callback = [&](gpu::GPUModuleOp gpuModule) -> WalkResult {
    return applyPatternsAndFoldGreedily(gpuModule, frozen);
  };
  if (getOperation()->walk(callback).wasInterrupted())
    return signalPassFailure();
}

std::unique_ptr<OperationPass<ModuleOp>>
createPropagateStaticShapesToKernelPass(Type pointerType) {
  return std::make_unique<PropagateStaticShapesToKernelPass>(pointerType);
}
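
// A minimal usage sketch (hypothetical pipeline setup; the i8 pointer type is
// an illustrative choice, not a requirement of the pass):
//   PassManager pm(ctx);
//   pm.addPass(createPropagateStaticShapesToKernelPass(
//       LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8))));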

}  // namespace mlir