/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Transforms/GPUPassDetail.h"
#include "mlir-hlo/Transforms/gpu_passes.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"

namespace mlir {

namespace {
class GpuFusionRewritePass
    : public GpuFusionRewritePassBase<GpuFusionRewritePass> {
 public:
  explicit GpuFusionRewritePass() = default;
  using Pass::runPipeline;  // Give FusionRewritePattern access.

 private:
  void getDependentDialects(DialectRegistry& registry) const override;
  void runOnOperation() override;
};

// Rewrites `lmhlo.fusion` to `gpu.launch_func` for fusion regions that the
// HLO to GPU pipeline can handle.
class FusionRewritePattern : public OpRewritePattern<lmhlo::FusionOp> {
 public:
  explicit FusionRewritePattern(MLIRContext* ctx,
                                GpuFusionRewritePass& parentPass,
                                SymbolTable& symbolTable);

 private:
  LogicalResult matchAndRewrite(lmhlo::FusionOp fusionOp,
                                PatternRewriter& rewriter) const override;

  // Returns whether all ops in fusionOp's region are legal according to
  // rewritableTarget.
  bool isRewritable(lmhlo::FusionOp fusionOp) const;

  // Annotates gpu.launch_func with attribute specifying written operands.
  //
  //   gpu.launch_func ..., %memref, ...
  //   %tensor = bufferize.to_tensor %memref
  //   memref.tensor_store %tensor, %argN
  //
  // is replaced with:
  //
  //   gpu.launch_func ..., %argN, ... { written_operands = [..., unit, ...] }
  //
  // The 'written_operands' attribute is used later to retrieve which
  // gpu.launch_func arguments are written vs. just read.
  static void annotateLaunchFunc(func::FuncOp funcOp,
                                 PatternRewriter& rewriter);

  // Returns target where lowerable fusion ops are marked legal.
  static ConversionTarget getRewritableTarget(MLIRContext* ctx);

  GpuFusionRewritePass& parentPass;
  SymbolTable& symbolTable;
  ConversionTarget rewritableTarget = getRewritableTarget(getContext());
};
}  // namespace

// Name of the 'gpu.launch_func' attribute which specifies the written operands.
static constexpr llvm::StringLiteral kWrittenOperandsAttrName =
    "written_operands";

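// Registers the dialects produced by the HLO to GPU pipeline. The pipeline is
// constructed here only to query its dependent dialects; it is actually run
// from FusionRewritePattern::matchAndRewrite.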
void GpuFusionRewritePass::getDependentDialects(
    DialectRegistry& registry) const {
  OpPassManager passManager;
  createHloToGpuPipeline(passManager, /*tileSizes=*/{}, /*unrollFactors=*/{});
  passManager.getDependentDialects(registry);
}

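// Applies the FusionRewritePattern to every lmhlo.fusion op in the module and
// signals pass failure if any pattern application fails.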
void GpuFusionRewritePass::runOnOperation() {
  SymbolTable symbolTable(getOperation());
  auto pattern =
      std::make_unique<FusionRewritePattern>(&getContext(), *this, symbolTable);
  mlir::FrozenRewritePatternSet patterns({&getContext(), std::move(pattern)});
  auto callback = [&](lmhlo::FusionOp fusion) {
    if (failed(applyOpPatternsAndFold(fusion, patterns)))
      return WalkResult::interrupt();
    return WalkResult::advance();
  };
  if (getOperation().walk(callback).wasInterrupted())
    return signalPassFailure();
}

FusionRewritePattern::FusionRewritePattern(MLIRContext* ctx,
                                           GpuFusionRewritePass& parentPass,
                                           SymbolTable& symbolTable)
    : OpRewritePattern<lmhlo::FusionOp>::OpRewritePattern(ctx),
      parentPass(parentPass),
      symbolTable(symbolTable) {}

// Returns the number of elements each thread should handle for 'type'.
// The intention is that loads and stores are vectorized later on to this width
// to maximize memory throughput.
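// For example, a statically shaped tensor of 10000 f32 elements maps to
// 128 / 32 = 4 elements per thread.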
static int64_t getElementsPerThread(TensorType type) {
  // Don't vectorize if the number of elements cannot saturate the GPU.
  // Use a coarse heuristic because we don't know the target GPU here.
  const int64_t kNumFp32AlusOnV100 = 5376;
  if (type.getNumElements() < kNumFp32AlusOnV100) return 1;

  // Don't vectorize if element type is not int or float.
  if (!type.getElementType().isIntOrFloat()) return 1;

  // Vectorize so that loads and stores are 128 bits per thread.
  return 128 / type.getElementType().getIntOrFloatBitWidth();
}

// Returns the number of threads per block to use for 'type', given the number
// of elements each thread handles. The returned block size is in the [128, 384]
// range, preferably close to 256 and evenly dividing the number of threads
// required to handle all elements in 'type'.
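// For example, 4608 f32 elements at 4 elements per thread require 1152
// threads; 256 does not divide 1152 evenly, and the search below returns 128
// (the first candidate tried at i = 128 that divides 1152).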
static int64_t getThreadsPerBlock(TensorType type, int64_t elementsPerThread) {
  int64_t numThreads =
      llvm::divideCeil(type.getNumElements(), elementsPerThread);

  // Use a single block for small problems.
  if (numThreads < 256) return numThreads;

  // Use 256 if that block size evenly divides the problem.
  if (numThreads % 256 == 0) return 256;

  int64_t elementSizeBits = 32;
  if (type.getElementType().isIntOrFloat())
    elementSizeBits = type.getElementType().getIntOrFloatBitWidth();
  int64_t threadSizeBits = elementSizeBits * elementsPerThread;

  // Search block sizes in the [128, 384] range near 256 with decreasing
  // power-of-2 factor, down to a multiple of a cache line (assumed to be 1024
  // bits). Use the first one that evenly divides the problem, which allows the
  // loop tail to be optimized away.
  for (int i = 128; i * threadSizeBits >= 1024; i /= 2) {
    // 2 * i: earlier iterations already handled even multiples of i.
    for (int blockSize = 256 - i; blockSize >= 128; blockSize -= 2 * i)
      if (numThreads % blockSize == 0) return blockSize;
    for (int blockSize = 256 + i; blockSize <= 384; blockSize += 2 * i)
      if (numThreads % blockSize == 0) return blockSize;
  }

  // None of the checked block sizes evenly divides the number of required
  // threads. Use a default of 256 and accept the loop tail.
  return 256;
}

LogicalResult FusionRewritePattern::matchAndRewrite(
    lmhlo::FusionOp fusionOp, PatternRewriter& rewriter) const {
  // If fusionOp (including its region) is not legal according to
  // rewritableTarget, we expect lowering to GPU to fail or produce incorrect
  // results.
  if (!isRewritable(fusionOp))
    return rewriter.notifyMatchFailure(fusionOp, "not rewritable");

  // Collect values used in the fusion region that are defined above it.
  SetVector<Value> captures;
  getUsedValuesDefinedAbove(fusionOp->getRegions(), captures);

  // Converts statically shaped types to their 1D equivalent. This only works
  // for elementwise fusions and will have to become a more sophisticated pass
  // when e.g. broadcasts are involved.
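  // For example, memref<4x8xf32> with an identity layout becomes memref<32xf32>.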
  TypeConverter converter;
  converter.addConversion([](Type type) { return type; });
  converter.addConversion([](ShapedType type) {
    if (!type.hasStaticShape()) return type;
    return type.clone(type.getNumElements());
  });
  converter.addConversion([&](MemRefType type) {
    if (!type.hasStaticShape() || !type.getLayout().isIdentity()) return type;
    return MemRefType::get(type.getNumElements(), type.getElementType(),
                           MemRefLayoutAttrInterface(), type.getMemorySpace());
  });

  // Create a new module with a function, clone fusion region into it.
  Location loc = fusionOp.getLoc();
  auto moduleOp = rewriter.create<ModuleOp>(loc);
  rewriter.setInsertionPointToEnd(moduleOp.getBody());
  auto argTypes = llvm::to_vector(llvm::map_range(captures, [&](Value value) {
    return converter.convertType(value.getType());
  }));
  auto funcType = rewriter.getFunctionType(argTypes, llvm::None);
  auto funcOp = rewriter.create<func::FuncOp>(loc, "fusion", funcType);
  rewriter.setInsertionPointToEnd(funcOp.addEntryBlock());
  BlockAndValueMapping mapping;
  for (const auto& [from, to] :
       llvm::zip_first(captures, funcOp.getArguments())) {
    mapping.map(from, to);
  }
  rewriter.cloneRegionBefore(fusionOp.getRegion(), funcOp.getRegion(),
                             funcOp.end(), mapping);
  rewriter.mergeBlocks(&funcOp.back(), &funcOp.front());
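  // Rewrite all result types in the cloned region to their 1D equivalents.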
  funcOp->walk([&](Operation* op) {
    for (auto result : op->getResults())
      result.setType(converter.convertType(result.getType()));
  });

  // Create and run the HLO to GPU pass pipeline.
  auto resultType =
      fusionOp.getFusionResults().front().getType().cast<TensorType>();
  int64_t unrollFactor = getElementsPerThread(resultType);
  int64_t tileSize = getThreadsPerBlock(resultType, unrollFactor);
  // Note: passManager.enableIRPrinting() doesn't do anything on dynamic pass
  // pipelines. Printing needs to be enabled on the parent pass manager.
  PassManager passManager(getContext());
  createHloToGpuPipeline(passManager, {tileSize},
                         {&unrollFactor, unrollFactor > 1});
  if (failed(parentPass.runPipeline(passManager, moduleOp)))
    return rewriter.notifyMatchFailure(fusionOp, "failed to run pipeline");

  // Clone the (single) gpu module with the device function.
  rewriter.setInsertionPoint(fusionOp->getParentOfType<func::FuncOp>());
  for (auto gpuModuleOp : moduleOp.getBodyRegion().getOps<gpu::GPUModuleOp>()) {
    StringAttr symbol =
        symbolTable.insert(rewriter.clone(*gpuModuleOp.getOperation()));
    if (failed(symbolTable.replaceAllSymbolUses(gpuModuleOp, symbol, funcOp)))
      return rewriter.notifyMatchFailure(fusionOp, "failed to replace symbol");
  }
  // Add 'gpu.container_module' attribute to parent module.
  fusionOp->getParentOfType<ModuleOp>()->setAttr(
      gpu::GPUDialect::getContainerModuleAttrName(), rewriter.getUnitAttr());

  // Annotate gpu.launch_func loc and attribute specifying written operands.
  funcOp->walk([&](gpu::LaunchFuncOp op) { op->setLoc(loc); });
  annotateLaunchFunc(funcOp, rewriter);

  // Remove dead allocations that were only used by the tensor_store ops erased
  // in annotateLaunchFunc above.
  RewritePatternSet patterns(getContext());
  memref::AllocOp::getCanonicalizationPatterns(patterns, getContext());
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
    return rewriter.notifyMatchFailure(fusionOp, "failed to canonicalize");

  // Replace fusion op with host function region.
  rewriter.splitBlock(&funcOp.front(),
                      funcOp.front().getTerminator()->getIterator());
  rewriter.mergeBlockBefore(&funcOp.front(), fusionOp, captures.getArrayRef());

  rewriter.eraseOp(fusionOp);
  rewriter.eraseOp(moduleOp);

  return success();
}

bool FusionRewritePattern::isRewritable(lmhlo::FusionOp fusionOp) const {
  if (fusionOp.getFusionResults().size() != 1)
    return false;  // Only rewrite fusions with a single result.
  auto callback = [this](Operation* op) {
    if (rewritableTarget.isLegal(op)) return WalkResult::advance();
    return WalkResult::interrupt();
  };
  return !fusionOp.getRegion().walk(callback).wasInterrupted();
}

void FusionRewritePattern::annotateLaunchFunc(func::FuncOp funcOp,
                                              PatternRewriter& rewriter) {
  llvm::SmallDenseMap<Operation*, SmallVector<bool>> writtenOperands;
  funcOp.walk([&](memref::TensorStoreOp storeOp) {
    auto toTensor =
        storeOp.getTensor().getDefiningOp<bufferization::ToTensorOp>();
    assert(toTensor && "not defined by bufferization.to_tensor");
    for (auto& use : toTensor.getMemref().getUses()) {
      Operation* user = use.getOwner();
      if (isa<gpu::LaunchFuncOp>(user)) {
        writtenOperands.try_emplace(user, user->getNumOperands())
            .first->second[use.getOperandNumber()] = true;
        use.set(storeOp.getMemref());
      }
    }
    rewriter.eraseOp(storeOp);
    rewriter.eraseOp(toTensor);
  });
  for (const auto& [op, vec] : writtenOperands)
    op->setAttr(kWrittenOperandsAttrName, rewriter.getBoolArrayAttr(vec));
}

// Returns whether 'type' can be lowered by the FusionRewritePattern.
static bool isRewritableType(Type type) {
  auto shapedType = type.cast<ShapedType>();
  // Complex types are not yet supported.
  if (shapedType.getElementType().isa<ComplexType>()) return false;
  // Zero ranked shapes are not yet supported.
  if (shapedType.getRank() == 0) return false;
  // MemRef types need to have identity layout.
  if (auto memrefType = shapedType.dyn_cast<MemRefType>())
    return memrefType.getLayout().isIdentity();
  // Unsigned integers are not yet supported.
  if (auto intType = shapedType.getElementType().dyn_cast<IntegerType>())
    return !intType.isUnsigned();
  return true;
}

ConversionTarget FusionRewritePattern::getRewritableTarget(MLIRContext* ctx) {
  ConversionTarget target(*ctx);
  // Mark expected auxiliary ops as legal.
  target.addLegalOp<lmhlo::TerminatorOp>();
  target.addDynamicallyLegalOp<bufferization::ToTensorOp>(
      [&](bufferization::ToTensorOp op) {
        return isRewritableType(op.getMemref().getType()) &&
               isRewritableType(op.getType());
      });
  target.addDynamicallyLegalOp<memref::TensorStoreOp>(
      [&](memref::TensorStoreOp op) {
        return isRewritableType(op.getTensor().getType()) &&
               isRewritableType(op.getMemref().getType());
      });
  // For now, use an explicit allow-list of hlo ops inside the fusion. If any
  // other op is present, the fusion will not be rewritten.
  target.addLegalOp<
      mhlo::AddOp, mhlo::AbsOp, mhlo::CbrtOp, mhlo::CeilOp, mhlo::CosineOp,
      mhlo::DivOp, mhlo::ExpOp, mhlo::Expm1Op, mhlo::FloorOp, mhlo::LogOp,
      mhlo::Log1pOp, mhlo::LogisticOp, mhlo::MulOp, mhlo::NegOp, mhlo::RoundOp,
      /*unsupported: mhlo::RoundNearestEvenOp,*/ mhlo::RsqrtOp, mhlo::SignOp,
      mhlo::SineOp, mhlo::SqrtOp, mhlo::SubtractOp, mhlo::TanhOp>();
  return target;
}

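// Creates the pass that rewrites lmhlo.fusion ops into gpu.launch_func calls.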
std::unique_ptr<OperationPass<ModuleOp>> createGpuFusionRewritePass() {
  return std::make_unique<GpuFusionRewritePass>();
}

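// Returns the 'written_operands' attribute attached by annotateLaunchFunc, or
// null if the attribute is not present on 'op'.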
ArrayAttr getWrittenOperandsAttribute(Operation* op) {
  return op->getAttrOfType<ArrayAttr>(kWrittenOperandsAttrName);
}

}  // namespace mlir