/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Transforms/GPUPassDetail.h"
#include "mlir-hlo/Transforms/gpu_passes.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"

namespace mlir {

namespace {
class GpuFusionRewritePass
    : public GpuFusionRewritePassBase<GpuFusionRewritePass> {
 public:
  explicit GpuFusionRewritePass() = default;
  using Pass::runPipeline;  // Give FusionRewritePattern access.

 private:
  void getDependentDialects(DialectRegistry& registry) const override;
  void runOnOperation() override;
};

// Rewrites `lmhlo.fusion` to `gpu.launch_func` for fusion regions that the
// HLO to GPU pipeline can handle.
class FusionRewritePattern : public OpRewritePattern<lmhlo::FusionOp> {
 public:
  explicit FusionRewritePattern(MLIRContext* ctx,
                                GpuFusionRewritePass& parentPass,
                                SymbolTable& symbolTable);

 private:
  LogicalResult matchAndRewrite(lmhlo::FusionOp fusionOp,
                                PatternRewriter& rewriter) const override;

  // Returns whether all ops in fusionOp's region are legal according to
  // rewritableTarget.
  bool isRewritable(lmhlo::FusionOp fusionOp) const;

  // Annotates gpu.launch_func with an attribute specifying the written
  // operands.
  //
  //   gpu.launch_func ..., %memref, ...
  //   %tensor = bufferization.to_tensor %memref
  //   memref.tensor_store %tensor, %argN
  //
  // is replaced with:
  //
  //   gpu.launch_func ..., %argN, ... { written_operands = [..., true, ...] }
  //
  // The 'written_operands' attribute is used later to retrieve which
  // gpu.launch_func arguments are written vs. just read.
  static void annotateLaunchFunc(func::FuncOp funcOp,
                                 PatternRewriter& rewriter);

  // Returns target where lowerable fusion ops are marked legal.
  static ConversionTarget getRewritableTarget(MLIRContext* ctx);

  GpuFusionRewritePass& parentPass;
  SymbolTable& symbolTable;
  ConversionTarget rewritableTarget = getRewritableTarget(getContext());
};
}  // namespace

// Name of the 'gpu.launch_func' attribute which specifies the written operands.
static constexpr llvm::StringLiteral kWrittenOperandsAttrName =
    "written_operands";

void GpuFusionRewritePass::getDependentDialects(
    DialectRegistry& registry) const {
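  // Collect dependent dialects from the same pipeline that matchAndRewrite
  // runs dynamically below, so everything that pipeline may create is
  // registered up front.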
  OpPassManager passManager;
  createHloToGpuPipeline(passManager, /*tileSizes=*/{}, /*unrollFactors=*/{});
  passManager.getDependentDialects(registry);
}

void GpuFusionRewritePass::runOnOperation() {
  SymbolTable symbolTable(getOperation());
  auto pattern =
      std::make_unique<FusionRewritePattern>(&getContext(), *this, symbolTable);
  mlir::FrozenRewritePatternSet patterns({&getContext(), std::move(pattern)});
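  // Apply the pattern to each lmhlo.fusion op in the module individually.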
  auto callback = [&](lmhlo::FusionOp fusion) {
    if (failed(applyOpPatternsAndFold(fusion, patterns)))
      return WalkResult::interrupt();
    return WalkResult::advance();
  };
  if (getOperation().walk(callback).wasInterrupted())
    return signalPassFailure();
}

FusionRewritePattern::FusionRewritePattern(MLIRContext* ctx,
                                           GpuFusionRewritePass& parentPass,
                                           SymbolTable& symbolTable)
    : OpRewritePattern<lmhlo::FusionOp>::OpRewritePattern(ctx),
      parentPass(parentPass),
      symbolTable(symbolTable) {}

// Returns the number of elements each thread should handle for 'type'.
// The intention is that loads and stores are vectorized later on to this width
// to maximize memory throughput.
static int64_t getElementsPerThread(TensorType type) {
  // Don't vectorize if the number of elements cannot saturate the GPU.
  // Use a coarse heuristic because we don't know the target GPU here.
  const int64_t kNumFp32AlusOnV100 = 5376;
  if (type.getNumElements() < kNumFp32AlusOnV100) return 1;

  // Don't vectorize if element type is not int or float.
  if (!type.getElementType().isIntOrFloat()) return 1;

  // Vectorize so that loads and stores are 128 bits per thread.
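  // For example, f32 elements give 128 / 32 = 4 elements per thread and f16
  // elements give 8.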
  return 128 / type.getElementType().getIntOrFloatBitWidth();
}

// Returns the number of threads per block to use for 'type', given the number
// of elements each thread handles. The returned block size is in the [128, 384]
// range, preferably close to 256 and evenly dividing the number of threads
// required to handle all elements in 'type'.
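// For example, 1,048,576 f32 elements with 4 elements per thread require
// 262,144 threads; 256 divides that evenly, so 256 is returned.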
static int64_t getThreadsPerBlock(TensorType type, int64_t elementsPerThread) {
  int64_t numThreads =
      llvm::divideCeil(type.getNumElements(), elementsPerThread);

  // Use a single block for small problems.
  if (numThreads < 256) return numThreads;

  // Use 256 if that block size evenly divides the problem.
  if (numThreads % 256 == 0) return 256;

  int64_t elementSizeBits = 32;
  if (type.getElementType().isIntOrFloat())
    elementSizeBits = type.getElementType().getIntOrFloatBitWidth();
  int64_t threadSizeBits = elementSizeBits * elementsPerThread;

  // Search block sizes in the [128, 384] range near 256 with decreasing
  // power-of-2 factor, down to a multiple of a cache line (assumed to be 1024
  // bits). Use the first one that evenly divides the problem, which allows the
  // loop tail to be optimized away.
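  // Candidates are visited in the order 128, 384 (i = 128), then 192, 320
  // (i = 64), then 224, 160, 288, 352 (i = 32), and so on.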
  for (int i = 128; i * threadSizeBits >= 1024; i /= 2) {
    // 2 * i: earlier iterations already handled even multiples of i.
    for (int blockSize = 256 - i; blockSize >= 128; blockSize -= 2 * i)
      if (numThreads % blockSize == 0) return blockSize;
    for (int blockSize = 256 + i; blockSize <= 384; blockSize += 2 * i)
      if (numThreads % blockSize == 0) return blockSize;
  }

  // None of the checked block sizes evenly divides the number of required
  // threads. Use a default of 256 and accept the loop tail.
  return 256;
}

LogicalResult FusionRewritePattern::matchAndRewrite(
    lmhlo::FusionOp fusionOp, PatternRewriter& rewriter) const {
  // If fusionOp (including its region) is not legal according to
  // rewritableTarget, we expect lowering to GPU to fail or produce incorrect
  // results.
  if (!isRewritable(fusionOp))
    return rewriter.notifyMatchFailure(fusionOp, "not rewritable");

  // Collect values used in the fusion region but defined above it.
  SetVector<Value> captures;
  getUsedValuesDefinedAbove(fusionOp->getRegions(), captures);

  // Converts statically shaped types to their 1D equivalent. This only works
  // for element-wise fusions and will have to become a more sophisticated
  // pass when e.g. broadcasts are involved.
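  // For example, memref<2x4xf32> becomes memref<8xf32> and tensor<2x4xf32>
  // becomes tensor<8xf32>; dynamically shaped types are left untouched.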
  TypeConverter converter;
  converter.addConversion([](Type type) { return type; });
  converter.addConversion([](ShapedType type) {
    if (!type.hasStaticShape()) return type;
    return type.clone(type.getNumElements());
  });
  converter.addConversion([&](MemRefType type) {
    if (!type.hasStaticShape() || !type.getLayout().isIdentity()) return type;
    return MemRefType::get(type.getNumElements(), type.getElementType(),
                           MemRefLayoutAttrInterface(), type.getMemorySpace());
  });

  // Create a new module with a function, clone fusion region into it.
  Location loc = fusionOp.getLoc();
  auto moduleOp = rewriter.create<ModuleOp>(loc);
  rewriter.setInsertionPointToEnd(moduleOp.getBody());
  auto argTypes = llvm::to_vector(llvm::map_range(captures, [&](Value value) {
    return converter.convertType(value.getType());
  }));
  auto funcType = rewriter.getFunctionType(argTypes, llvm::None);
  auto funcOp = rewriter.create<func::FuncOp>(loc, "fusion", funcType);
  rewriter.setInsertionPointToEnd(funcOp.addEntryBlock());
  BlockAndValueMapping mapping;
  for (const auto& [from, to] :
       llvm::zip_first(captures, funcOp.getArguments())) {
    mapping.map(from, to);
  }
  rewriter.cloneRegionBefore(fusionOp.getRegion(), funcOp.getRegion(),
                             funcOp.end(), mapping);
  rewriter.mergeBlocks(&funcOp.back(), &funcOp.front());
  funcOp->walk([&](Operation* op) {
    for (auto result : op->getResults())
      result.setType(converter.convertType(result.getType()));
  });

  // Create and run the HLO to GPU pass pipeline.
  auto resultType =
      fusionOp.getFusionResults().front().getType().cast<TensorType>();
  int64_t unrollFactor = getElementsPerThread(resultType);
  int64_t tileSize = getThreadsPerBlock(resultType, unrollFactor);
  // Note: passManager.enableIRPrinting() doesn't do anything on dynamic pass
  // pipelines. Printing needs to be enabled on the parent pass manager.
  PassManager passManager(getContext());
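  // `{&unrollFactor, unrollFactor > 1}` builds an unroll-factor list that is
  // empty when the factor is 1 and contains the single factor otherwise.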
  createHloToGpuPipeline(passManager, {tileSize},
                         {&unrollFactor, unrollFactor > 1});
  if (failed(parentPass.runPipeline(passManager, moduleOp)))
    return rewriter.notifyMatchFailure(fusionOp, "failed to run pipeline");

  // Clone the (single) gpu module with the device function.
  rewriter.setInsertionPoint(fusionOp->getParentOfType<func::FuncOp>());
  for (auto gpuModuleOp : moduleOp.getBodyRegion().getOps<gpu::GPUModuleOp>()) {
    StringAttr symbol =
        symbolTable.insert(rewriter.clone(*gpuModuleOp.getOperation()));
    if (failed(symbolTable.replaceAllSymbolUses(gpuModuleOp, symbol, funcOp)))
      return rewriter.notifyMatchFailure(fusionOp, "failed to replace symbol");
  }
  // Add 'gpu.container_module' attribute to parent module.
  fusionOp->getParentOfType<ModuleOp>()->setAttr(
      gpu::GPUDialect::getContainerModuleAttrName(), rewriter.getUnitAttr());
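  // The attribute marks the parent module as a container of gpu.module ops,
  // which gpu.launch_func requires in order to verify.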

  // Set the location on gpu.launch_func ops and annotate which of their
  // operands are written.
  funcOp->walk([&](gpu::LaunchFuncOp op) { op->setLoc(loc); });
  annotateLaunchFunc(funcOp, rewriter);

  // Remove dead allocations that were only used by the tensor_store ops erased
  // above.
  RewritePatternSet patterns(getContext());
  memref::AllocOp::getCanonicalizationPatterns(patterns, getContext());
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
    return rewriter.notifyMatchFailure(fusionOp, "failed to canonicalize");

  // Replace the fusion op with the host function's region.
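  // The entry block's arguments are replaced by the values originally captured
  // from above the fusion op.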
  rewriter.splitBlock(&funcOp.front(),
                      funcOp.front().getTerminator()->getIterator());
  rewriter.mergeBlockBefore(&funcOp.front(), fusionOp, captures.getArrayRef());

  rewriter.eraseOp(fusionOp);
  rewriter.eraseOp(moduleOp);

  return success();
}

bool FusionRewritePattern::isRewritable(lmhlo::FusionOp fusionOp) const {
  if (fusionOp.getFusionResults().size() != 1)
    return false;  // Only rewrite fusions with a single result.
  auto callback = [this](Operation* op) {
    if (rewritableTarget.isLegal(op)) return WalkResult::advance();
    return WalkResult::interrupt();
  };
  return !fusionOp.getRegion().walk(callback).wasInterrupted();
}

void FusionRewritePattern::annotateLaunchFunc(func::FuncOp funcOp,
                                              PatternRewriter& rewriter) {
  llvm::SmallDenseMap<Operation*, SmallVector<bool>> writtenOperands;
  funcOp.walk([&](memref::TensorStoreOp storeOp) {
    auto toTensor =
        storeOp.getTensor().getDefiningOp<bufferization::ToTensorOp>();
    assert(toTensor && "not defined by bufferization.to_tensor");
    for (auto& use : toTensor.getMemref().getUses()) {
      Operation* user = use.getOwner();
      if (isa<gpu::LaunchFuncOp>(user)) {
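        // Mark this launch_func operand as written and redirect it to the
        // destination buffer of the tensor_store.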
        writtenOperands.try_emplace(user, user->getNumOperands())
            .first->second[use.getOperandNumber()] = true;
        use.set(storeOp.getMemref());
      }
    }
    rewriter.eraseOp(storeOp);
    rewriter.eraseOp(toTensor);
  });
  for (const auto& [op, vec] : writtenOperands)
    op->setAttr(kWrittenOperandsAttrName, rewriter.getBoolArrayAttr(vec));
}

// Returns whether 'type' can be lowered by the FusionRewritePattern.
static bool isRewritableType(Type type) {
  auto shapedType = type.cast<ShapedType>();
  // Complex types are not yet supported.
  if (shapedType.getElementType().isa<ComplexType>()) return false;
  // Zero-ranked shapes are not yet supported.
  if (shapedType.getRank() == 0) return false;
  // MemRef types need to have identity layout.
  if (auto memrefType = shapedType.dyn_cast<MemRefType>())
    return memrefType.getLayout().isIdentity();
  // Unsigned integers are not yet supported.
  if (auto intType = shapedType.getElementType().dyn_cast<IntegerType>())
    return !intType.isUnsigned();
  return true;
}

ConversionTarget FusionRewritePattern::getRewritableTarget(MLIRContext* ctx) {
  ConversionTarget target(*ctx);
  // Mark expected auxiliary ops as legal.
  target.addLegalOp<lmhlo::TerminatorOp>();
  target.addDynamicallyLegalOp<bufferization::ToTensorOp>(
      [&](bufferization::ToTensorOp op) {
        return isRewritableType(op.getMemref().getType()) &&
               isRewritableType(op.getType());
      });
  target.addDynamicallyLegalOp<memref::TensorStoreOp>(
      [&](memref::TensorStoreOp op) {
        return isRewritableType(op.getTensor().getType()) &&
               isRewritableType(op.getMemref().getType());
      });
  // For now, use an explicit allow-list of hlo ops inside the fusion. If any
  // other op is present, the fusion will not be rewritten.
  target.addLegalOp<
      mhlo::AddOp, mhlo::AbsOp, mhlo::CbrtOp, mhlo::CeilOp, mhlo::CosineOp,
      mhlo::DivOp, mhlo::ExpOp, mhlo::Expm1Op, mhlo::FloorOp, mhlo::LogOp,
      mhlo::Log1pOp, mhlo::LogisticOp, mhlo::MulOp, mhlo::NegOp, mhlo::RoundOp,
      /*unsupported: mhlo::RoundNearestEvenOp,*/ mhlo::RsqrtOp, mhlo::SignOp,
      mhlo::SineOp, mhlo::SqrtOp, mhlo::SubtractOp, mhlo::TanhOp>();
  return target;
}

std::unique_ptr<OperationPass<ModuleOp>> createGpuFusionRewritePass() {
  return std::make_unique<GpuFusionRewritePass>();
}

ArrayAttr getWrittenOperandsAttribute(Operation* op) {
  return op->getAttrOfType<ArrayAttr>(kWrittenOperandsAttrName);
}

}  // namespace mlir