xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/passes.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef MLIR_HLO_TRANSFORMS_PASSES_H
17 #define MLIR_HLO_TRANSFORMS_PASSES_H
18 
19 #include <functional>
20 #include <memory>
21 
22 #include "mlir/Pass/Pass.h"
23 
24 namespace mlir {
25 class ModuleOp;  // Forward declarations for types named in the signatures below.
26 class MLIRContext;
27 class ConversionTarget;
28 class DialectRegistry;
29 class PassManager;  // NOTE(review): not referenced in the visible header.
30 // Forward declarations of types that live in nested namespaces.
31 namespace func {
32 class FuncOp;
33 }  // namespace func
34 namespace bufferization {
35 class BufferizeTypeConverter;
36 }  // namespace bufferization
37 // Callback types accepted by createFinalBufferizePass(alignment, dc, pc).
38 using BufferizeDialectsCallback = std::function<void(DialectRegistry&)>;
39 using BufferizePatternsCallback = std::function<void(
40     ConversionTarget&, MLIRContext*, bufferization::BufferizeTypeConverter*,
41     RewritePatternSet*)>;  // NOTE(review): presumably adds patterns -- confirm.
42 
43 //===----------------------------------------------------------------------===//
44 // Passes
45 //===----------------------------------------------------------------------===//
46 
47 /// Creates a func::FuncOp pass that reuses buffers which are already allocated.
48 std::unique_ptr<OperationPass<func::FuncOp>> createBufferReusePass();
49 
50 /// Creates a func::FuncOp pass that analyzes shapes and uses that information
51 /// for shape-related optimizations.
52 std::unique_ptr<OperationPass<func::FuncOp>>
53 createSymbolicShapeOptimizationPass();
54 
55 /// Creates a pass that merges smaller buffers into a bigger buffer to optimize
56 /// memory consumption. `windowSize` defaults to 5 when omitted.
57 std::unique_ptr<OperationPass<func::FuncOp>> createBufferPackingPass(
58     unsigned windowSize = 5);  // NOTE(review): window semantics not shown here.
59 
60 /// Creates a func::FuncOp pass that tests the useranges of the UserangeAnalysis.
61 std::unique_ptr<OperationPass<func::FuncOp>> createTestUserangePass();
62 
63 /// Creates a func::FuncOp pass that prints the ShapeComponentsAnalysis results.
64 std::unique_ptr<OperationPass<func::FuncOp>>
65 createTestShapeComponentAnalysisPass();
66 
67 /// Creates a func::FuncOp pass that removes redundant operations which
68 /// implement a CopyOpInterface.
69 std::unique_ptr<OperationPass<func::FuncOp>> createCopyRemovalPass();
70 
71 /// Creates a func::FuncOp pass that computes the allocated memory.
72 std::unique_ptr<OperationPass<func::FuncOp>> createMemoryCountPass();
73 
74 /// Creates a pass that lowers index casts on tensors to the tensor dialect.
75 std::unique_ptr<OperationPass<func::FuncOp>> createLowerIndexCastPass();
76 
77 /// Creates a func::FuncOp pass that simplifies shape ops.
78 std::unique_ptr<OperationPass<func::FuncOp>> createShapeSimplification();
79 
80 // Pass to transform computations (hlo and linalg) on values to their
81 // corresponding counterparts on buffers. Also bufferizes function signatures.
82 std::unique_ptr<OperationPass<ModuleOp>> createComputeOpAndFuncBufferizePass();
83 
84 // Pass to transform computations on values to their corresponding parts on
85 // buffers. This overload takes no arguments.
86 std::unique_ptr<OperationPass<ModuleOp>> createFinalBufferizePass();
87 // Overload taking an explicit `alignment`; `dc` and `pc` default to empty.
88 std::unique_ptr<OperationPass<ModuleOp>> createFinalBufferizePass(
89     uint64_t alignment, BufferizeDialectsCallback dc = {},
90     BufferizePatternsCallback pc = {});
91 
92 // Pass to propagate static shapes to kernel, reducing the kernel arguments
93 // from a flattened memref to a single pointer. The pointer is converted to
94 // `pointerType`, if provided.
95 std::unique_ptr<OperationPass<ModuleOp>>
96 createPropagateStaticShapesToKernelPass(Type pointerType = {});
97 
98 // Creates a pass that collapses multidimensional parallel loops into 1D loops.
99 std::unique_ptr<OperationPass<>> createCollapseParallelLoopsTo1DPass();  // OperationPass<>: no op type named.
100 
101 // Creates a TileLoopsPass with tile sizes provided through `tileSizes`
102 // and unroll factors provided through `unrollFactors`; both default to empty.
103 std::unique_ptr<OperationPass<func::FuncOp>> createTileLoopsPass(
104     ArrayRef<int64_t> tileSizes = {}, ArrayRef<int64_t> unrollFactors = {});
105 // ModuleOp passes. NOTE(review): by name, one-shot bufferization -- confirm.
106 namespace hlo {
107 std::unique_ptr<OperationPass<ModuleOp>> createOneShotBufferizePass();
108 // NOTE(review): by name, a generic host-to-LLVM lowering -- confirm.
109 std::unique_ptr<OperationPass<ModuleOp>> createGenericHostToLLVMPass();
110 }  // namespace hlo
111 }  // namespace mlir
112 
113 #endif  // MLIR_HLO_TRANSFORMS_PASSES_H
114