/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Transforms/PassDetail.h"
#include "mlir-hlo/Transforms/passes.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {

namespace {

// Replaces flattened memref arguments (base, aligned, offset, sizes, strides)
// with base and constants if the corresponding launch_func op's argument has
// a static shape. Removes all arguments but base.
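//
// For example (illustrative), a kernel taking a statically shaped
// memref<4x8xf32> argument is lowered to:
//   llvm.func @kernel(%base: !llvm.ptr<f32>, %aligned: !llvm.ptr<f32>,
//                     %offset: i64, %size0: i64, %size1: i64,
//                     %stride0: i64, %stride1: i64)
// and this pattern rewrites it to:
//   llvm.func @kernel(%base: !llvm.ptr<f32>)
// replacing uses of the dropped arguments with %base and the constants
// 0 (offset), 4 and 8 (sizes), and 8 and 1 (strides).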
class PropagateStaticShapesPattern : public OpRewritePattern<LLVM::LLVMFuncOp> {
 public:
  explicit PropagateStaticShapesPattern(MLIRContext* ctx,
                                        SymbolTable& symbolTable,
                                        Type pointerType)
      : OpRewritePattern<LLVM::LLVMFuncOp>(ctx),
        symbolTable(symbolTable),
        pointerType(pointerType) {}

 private:
  LogicalResult matchAndRewrite(LLVM::LLVMFuncOp funcOp,
                                PatternRewriter& rewriter) const final;

  SymbolTable& symbolTable;
  Type pointerType;
};

class PropagateStaticShapesToKernelPass
    : public PropagateStaticShapesToKernelPassBase<
          PropagateStaticShapesToKernelPass> {
 public:
  explicit PropagateStaticShapesToKernelPass(Type pointerType)
      : pointerType(pointerType) {}

 private:
  void runOnOperation() override;

  Type pointerType;
};

}  // namespace

// Replaces 'arguments' (containing 'base', 'align', 'offset', 'sizes[rank]',
// 'strides[rank]') corresponding to a statically shaped 'memref' with the base
// pointer and constants. The base pointer is changed to 'pointerType' if
// provided.
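// For a rank-'r' memref, 'arguments' therefore holds '3 + 2 * r' entries, in
// the order produced by MLIR's MemRef-to-LLVM descriptor lowering.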
static void replaceStaticMemRefArguments(ArrayRef<BlockArgument> arguments,
                                         MemRefType memref, Type pointerType,
                                         PatternRewriter& rewriter) {
  assert(arguments.size() >= 3 && "expected at least 3 arguments");
  Value base = arguments[0];
  if (pointerType) {
    // Change 'base' to the given type and bitcast it back to the original
    // type for the remaining uses.
    Type type = base.getType();
    base.setType(pointerType);
    auto cast = rewriter.create<LLVM::BitcastOp>(base.getLoc(), type, base);
    base.replaceAllUsesExcept(/*newValue=*/cast, /*exceptedUser=*/cast);
    base = cast.getResult();
  }

  // Replace uses of 'aligned' with 'base'.
  arguments[1].replaceAllUsesWith(base);
  // Replace uses of 'offset' with the constant 0.
  arguments[2].replaceAllUsesWith(rewriter.create<LLVM::ConstantOp>(
      arguments[2].getLoc(), arguments[2].getType(),
      rewriter.getIntegerAttr(arguments[2].getType(), 0)));
  auto replace = [&](ArrayRef<int64_t> values,
                     ArrayRef<BlockArgument> arguments) {
    for (auto valAndArg : llvm::zip_first(values, arguments)) {
      auto argument = std::get<1>(valAndArg);
      argument.replaceAllUsesWith(rewriter.create<LLVM::ConstantOp>(
          argument.getLoc(), argument.getType(),
          rewriter.getIntegerAttr(argument.getType(), std::get<0>(valAndArg))));
    }
  };
  // Replace 'sizes' and 'strides' with constants.
  replace(memref.getShape(), arguments.drop_front(3));
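  // Row-major strides are the reversed partial products of the shape, e.g.
  // shape [2, 3, 4] yields strides [12, 4, 1].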
  auto strides = llvm::to_vector<4>(memref.getShape());
  std::partial_sum(strides.rbegin(), strides.rend(), strides.rbegin(),
                   std::multiplies<int64_t>());
  strides.push_back(1);
  replace(llvm::makeArrayRef(strides).drop_front(),
          arguments.drop_front(3 + memref.getRank()));
}

LogicalResult PropagateStaticShapesPattern::matchAndRewrite(
    LLVM::LLVMFuncOp funcOp, PatternRewriter& rewriter) const {
  if (funcOp.isExternal())
    return rewriter.notifyMatchFailure(funcOp, "external");
  if (!funcOp->getAttrOfType<UnitAttr>(
          gpu::GPUDialect::getKernelFuncAttrName())) {
    return rewriter.notifyMatchFailure(funcOp, "missing gpu.kernel");
  }

  // Collect the gpu.launch_func ops which launch the 'funcOp' kernel.
  Optional<SymbolTable::UseRange> symUses =
      symbolTable.getSymbolUses(funcOp, symbolTable.getOp());
  if (!symUses)
    return rewriter.notifyMatchFailure(funcOp, "failed to find symbol uses");
  auto mapper = [](SymbolTable::SymbolUse symUse) {
    return dyn_cast<gpu::LaunchFuncOp>(symUse.getUser());
  };
  auto filter = [](gpu::LaunchFuncOp op) -> bool { return op; };
  auto launchOps = llvm::to_vector(
      llvm::make_filter_range(llvm::map_range(*symUses, mapper), filter));
  if (launchOps.empty())
    return rewriter.notifyMatchFailure(funcOp, "no gpu.launch_func uses");
  OperandRange operands = launchOps.begin()->operands();
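  // All launch sites must pass identically typed operands; otherwise the
  // shapes propagated below could differ between call sites.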
  if (llvm::any_of(launchOps, [&](gpu::LaunchFuncOp op) {
        return op.operands().getTypes() != operands.getTypes();
      })) {
    return rewriter.notifyMatchFailure(funcOp, "operand types mismatch");
  }

  rewriter.setInsertionPointToStart(&funcOp.front());
  BitVector argsToDrop(funcOp.getNumArguments());
  // Loop over the launch_op's 'operands' containing scalars and memrefs and
  // the func_op's 'arguments' containing scalars and flattened memrefs. When
  // an operand is a statically shaped memref, replace the range of arguments
  // corresponding to the flattened memref with just the 'base' pointer.
  for (auto arguments = funcOp.getArguments(); !arguments.empty();
       operands = operands.drop_front()) {
    auto memref = operands.getTypes().front().dyn_cast<MemRefType>();
    if (!memref) {
      // Scalar argument, advance by one.
      arguments = arguments.drop_front();
      continue;
    }
    if (!memref.hasRank()) break;  // Bail out if unranked.
    // A memref is flattened to base, align, offset, sizes, and strides.
    int64_t numArgs = 3 + memref.getRank() * 2;
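    // E.g. a rank-2 memref expands to 3 + 2 * 2 = 7 arguments.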
    auto isPtr = [](BlockArgument arg) {
      return arg.getType().isa<LLVM::LLVMPointerType>();
    };
    auto isInt = [](BlockArgument arg) {
      return arg.getType().isa<IntegerType>();
    };
    // Bail out if the next 'numArgs' arguments are not of the expected types.
    if (static_cast<int64_t>(arguments.size()) < numArgs) break;
    ArrayRef<BlockArgument> memrefArgs = arguments.take_front(numArgs);
    if (!llvm::all_of(memrefArgs.take_front(2), isPtr)) break;
    if (!llvm::all_of(memrefArgs.drop_front(2), isInt)) break;
    // Replace 'memrefArgs' with just 'memrefArgs[0]' if the memref has a
    // static shape and an identity layout.
    if (memref.hasStaticShape() && memref.getLayout().isIdentity()) {
      replaceStaticMemRefArguments(memrefArgs, memref, pointerType, rewriter);
      unsigned argNumber = arguments.front().getArgNumber();
      // Drop all but 'base' from the flattened memref arguments.
      argsToDrop.set(argNumber + 1, argNumber + numArgs);
    }
    arguments = arguments.drop_front(numArgs);
  }
  if (argsToDrop.none()) {
    return rewriter.notifyMatchFailure(funcOp, "no static shapes");
  }
  rewriter.updateRootInPlace(funcOp, [&] {
    funcOp.eraseArguments(argsToDrop);
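    // Rebuild the function type to match the remaining arguments.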
    auto argTypes =
        llvm::to_vector(TypeRange{ValueRange{funcOp.getArguments()}});
    funcOp.setType(LLVM::LLVMFunctionType::get(
        funcOp.getFunctionType().getReturnType(), argTypes));
  });
  return success();
}

void PropagateStaticShapesToKernelPass::runOnOperation() {
  MLIRContext* ctx = getOperation().getContext();
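  // The 'convert_pointer_args' pass option (held in 'ptr_type_opt'), e.g.
  // "!llvm.ptr<i8>", takes precedence over the constructor's pointer type.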
  auto pointerType = [&]() -> FailureOr<Type> {
    if (ptr_type_opt.empty()) return this->pointerType;
    Type type = parseType(ptr_type_opt, ctx);
    if (!type)
      return emitError(UnknownLoc::get(ctx), "invalid convert_pointer_args");
    return type;
  }();
  if (failed(pointerType)) return signalPassFailure();
  SymbolTable symbolTable(getOperation());
  RewritePatternSet patterns(ctx);
  patterns.add<PropagateStaticShapesPattern>(ctx, symbolTable, *pointerType);
  FrozenRewritePatternSet frozen(std::move(patterns));
  auto callback = [&](gpu::GPUModuleOp gpuModule) -> WalkResult {
    return applyPatternsAndFoldGreedily(gpuModule, frozen);
  };
  if (getOperation()->walk(callback).wasInterrupted())
    return signalPassFailure();
}

std::unique_ptr<OperationPass<ModuleOp>>
createPropagateStaticShapesToKernelPass(Type pointerType) {
  return std::make_unique<PropagateStaticShapesToKernelPass>(pointerType);
}
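
// Example (illustrative): adding this pass to a pipeline, assuming a
// PassManager 'pm' over the enclosing ModuleOp and an MLIRContext* 'ctx':
//   pm.addPass(createPropagateStaticShapesToKernelPass(
//       LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8))));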

}  // namespace mlir