1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H 17 #define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H 18 19 #include <memory> 20 #include <string> 21 22 #include "llvm/ADT/ArrayRef.h" 23 24 namespace mlir { 25 26 class ModuleOp; 27 class Operation; 28 template <typename T> 29 class OperationPass; 30 class Pass; 31 namespace func { 32 class FuncOp; 33 } // namespace func 34 namespace lmhlo { 35 class FusionOp; 36 } // namespace lmhlo 37 38 namespace mhlo { 39 40 /// Lowers HLO control flow ops to SCF. 41 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeControlFlowPass(); 42 43 /// Lowers sort to SCF & arith. 44 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeSortPass(); 45 46 /// Lowers from HLO dialect to Standard dialect. 47 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeToStdPass(); 48 49 /// Lowers from the CHLO dialect to the HLO dialect. 50 std::unique_ptr<OperationPass<func::FuncOp>> createChloLegalizeToHloPass( 51 bool legalizeBroadcasts = true, bool expandCompositions = true); 52 53 // Lowers from sparse ops in CHLO dialect to Linalg dialect. 54 std::unique_ptr<OperationPass<func::FuncOp>> 55 createLegalizeSparseChloToLinalgPass(); 56 57 // canonicalize reduction ops to be suitable for codegen. 58 std::unique_ptr<OperationPass<func::FuncOp>> 59 createHloCanonicalizeReductionPass(); 60 61 /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary 62 /// buffers if necessary. 63 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(); 64 65 /// Lowers from HLO dialect to Memref dialect allocating/deallocating temporary 66 /// buffers if necessary. 67 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToMemrefPass(); 68 69 /// Lowers from HLO dialect to Arithmetic dialect. 70 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToArithmeticPass(); 71 72 // Lowers shape operations from HLO dialect to Standard dialect. 73 std::unique_ptr<OperationPass<func::FuncOp>> 74 createLegalizeHloShapeOpsToStandardPass(); 75 76 /// Lowers from MHLO dialect to THLO dialect. 77 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeMHLOToTHLOPass(); 78 79 /// Lowers from HLO dialect to Linalg dialect. 80 std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeHloToLinalgPass(); 81 82 /// Lowers from HLO dialects dim operations. 83 std::unique_ptr<OperationPass<func::FuncOp>> 84 createLegalizeShapeComputationsPass(); 85 86 // Sinks constants implicitly captured in control flow regions. This is 87 // necessary to export to XLA. 88 std::unique_ptr<OperationPass<func::FuncOp>> 89 createSinkConstantsToControlFlowPass(); 90 91 /// Lowers trigonometric operations from the standard dialect to approximations 92 /// that do not use intrinsics. 93 std::unique_ptr<OperationPass<func::FuncOp>> 94 createLegalizeTrigonometricToApproximationPass(); 95 96 // Move dynamic broadcasts up over element-wise operations and broadcast the 97 // operands rather than the result. This will eventually allow for larger 98 // fusions. 99 std::unique_ptr<OperationPass<func::FuncOp>> createBroadcastPropagationPass(); 100 101 // Transformations that helps in restricting maximum rank among tensors in the 102 // pass. 103 std::unique_ptr<OperationPass<func::FuncOp>> createRestrictMaxRankPass(); 104 105 // Prepare moving dynamic broadcasts up over element-wise operations and 106 // broadcast the operands rather than the result. This will eventually allow for 107 // larger fusions. 108 std::unique_ptr<OperationPass<func::FuncOp>> createMergeAssumingOpsPass(); 109 110 // Iteratively reifies all shape computations in the function. 111 std::unique_ptr<OperationPass<func::FuncOp>> createShapeReificationPass(); 112 113 // Fuse shape constraints and merge all assuming regions. 114 std::unique_ptr<OperationPass<func::FuncOp>> createConstraintFusionPass(); 115 116 // Group reduction and parallel dimensions of reduction operations and realize 117 // them through equivalent 1D or 2D reductions. 118 std::unique_ptr<OperationPass<func::FuncOp>> createGroupReductionDimensionsPass( 119 bool preferColumnsReductions = true); 120 121 /// Rank specialization passes: 122 /// - Find compatible operations and group them together in one rank 123 /// specialization cluster. 124 /// - Lower rank specialization clusters to SCF and ranked operations. 125 std::unique_ptr<OperationPass<func::FuncOp>> 126 createRankSpecializationClusterPass(); 127 std::unique_ptr<OperationPass<func::FuncOp>> createRankSpecializationToSCFPass( 128 int64_t maxTargetRank = 5); 129 130 std::unique_ptr<OperationPass<func::FuncOp>> createOptimizeMhloPass(); 131 std::unique_ptr<OperationPass<func::FuncOp>> createLowerComplexPass(); 132 std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass(); 133 std::unique_ptr<OperationPass<func::FuncOp>> 134 createLegalizeEinsumToDotGeneralPass(); 135 std::unique_ptr<OperationPass<func::FuncOp>> 136 createLegalizeGatherToTorchIndexSelectPass(); 137 std::unique_ptr<OperationPass<func::FuncOp>> createFlattenTuplePass(); 138 139 // Creates a pass for expanding mhlo.tuple ops. 140 std::unique_ptr<OperationPass<ModuleOp>> createExpandHloTuplesPass( 141 const std::string& entryFunctionName = "main"); 142 143 // Creates a pass for collapsing the mhlo.map if the map only has elementwise 144 // op. 145 std::unique_ptr<OperationPass<func::FuncOp>> createCollapseElementwiseMapPass(); 146 147 // Pass to replace unsigned types with signless integers. 148 std::unique_ptr<OperationPass<ModuleOp>> createConvertToSignlessPass(); 149 150 /// Creates pass for rewriting sparse mhlo ops. 151 std::unique_ptr<OperationPass<func::FuncOp>> createSparseRewritingPass(); 152 153 } // namespace mhlo 154 } // namespace mlir 155 156 #endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H 157