1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H
17 #define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H
18 
19 #include <functional>
20 #include <memory>
21 
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 
26 namespace mlir {
// Forward declaration only: the declarations in this header take the
// converter by pointer, so the complete class type is not required here.
namespace bufferization { class BufferizeTypeConverter; }  // namespace bufferization
30 namespace mhlo {
31 
32 // Collection of rewrite patterns for lowering a general dot product.
33 void populateGeneralDotOpLoweringPatterns(RewritePatternSet *patterns,
34                                           MLIRContext *ctx);
35 
36 // Collection of rewrite patterns for lowering complex operations to equivalent
37 // float operations.
38 void populateComplexLoweringPatterns(MLIRContext *context,
39                                      RewritePatternSet *patterns);
40 
41 void populateOptimizeMhloPatterns(MLIRContext *context,
42                                   RewritePatternSet *patterns);
43 
44 // Rewrite patterns for einsum to equivalent dot_general legalization.
45 void populateEinsumToDotGeneralPatterns(mlir::MLIRContext *context,
46                                         RewritePatternSet *patterns);
47 
48 // Rewrite patterns for gather to equivalent torch index select legalization.
49 void populateGatherToTorchIndexSelectPatterns(mlir::MLIRContext *context,
50                                               RewritePatternSet *patterns);
51 
52 void populateMhloToStdPatterns(RewritePatternSet *patterns, MLIRContext *ctx);
53 
54 // Collection of rewrite patterns for lowering all mhlo ops to their
55 // lmhlo counterparts.
56 void populateDynamicHloToLhloConversionPattern(
57     MLIRContext *context, bufferization::BufferizeTypeConverter *converter,
58     RewritePatternSet *patterns);
59 
60 // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
61 void populateHloToLhloConversionPattern(
62     MLIRContext *context, bufferization::BufferizeTypeConverter *converter,
63     RewritePatternSet *patterns);
64 
65 // Collection of rewrite patterns for lowering of HLO to arithmetic dialect.
66 void populateHloToArithmeticConversionPatterns(RewritePatternSet *patterns);
67 
68 // Collection of rewrite patterns for lowering of shape operations from the HLO
69 // dialect to the standard dialect.
70 void populateHloShapeOpsToStandardConversionPattern(
71     MLIRContext *context, TypeConverter &typeConverter,
72     RewritePatternSet *patterns);
73 
74 // Collection of rewrite patterns for lowering of HLO to Linalg dialect.
75 void populateHloToLinalgConversionPattern(MLIRContext *context,
76                                           TypeConverter &typeConverter,
77                                           RewritePatternSet *patterns);
78 
79 // Collection of rewrite patterns for lowering of HLO dim operations.
80 void populateShapeComputationPatterns(MLIRContext *context,
81                                       RewritePatternSet *patterns);
82 
83 // Converter to signless intergers to be used with linalg conversion patterns.
84 std::unique_ptr<TypeConverter> createHloToLinalgTypeConverter();
85 
86 // Sets up legality definitions for materializing broadcasts.
87 void setupMaterializeBroadcastsLegality(MLIRContext *context,
88                                         ConversionTarget *conversionTarget);
89 
90 // Populates a collection of rewrite patterns for materializing broadcast
91 // attributes to equivalent sequences of ops.
92 void populateMaterializeBroadcastsPatterns(MLIRContext *context,
93                                            RewritePatternSet *patterns);
94 
95 // Populates a collection of rewrite patterns to realize element-wise operations
96 // on ranked tensors where possible.
97 void populateTransformUnrankedHloPatterns(MLIRContext *context,
98                                           RewritePatternSet *patterns);
99 
100 void populateDynamicShapeFusionPatterns(MLIRContext *context,
101                                         RewritePatternSet *patterns);
102 
103 // Populate a collection of conversion patterns for un-fusing
104 // batch_norm_inference into constituent HLO ops.
105 void populateUnfuseBatchNormInferencePattern(MLIRContext *context,
106                                              RewritePatternSet *patterns);
107 
108 // Populate a collection of conversion patterns for un-fusing
109 // batch_norm_training into constituent HLO ops.
110 void populateUnfuseBatchNormTrainingPattern(MLIRContext *context,
111                                             RewritePatternSet *patterns);
112 
113 // Populate a collection of conversion patterns for un-fusing
114 // // batch_norm_inference and batch_norm_training into constituent HLO ops.
populateUnfuseBatchNormPatterns(MLIRContext * context,RewritePatternSet * patterns)115 inline void populateUnfuseBatchNormPatterns(MLIRContext *context,
116                                             RewritePatternSet *patterns) {
117   populateUnfuseBatchNormInferencePattern(context, patterns);
118   populateUnfuseBatchNormTrainingPattern(context, patterns);
119 }
120 
121 // Populates patterns that translate the trigonometric operations from the
122 // standard dialect to approximations that do not use intrinsics.
123 void populateTrigonometricToApproximationPatterns(MLIRContext *context,
124                                                   RewritePatternSet *patterns);
125 
126 // Populate patterns to prepare moving dynamic broadcasts up over element-wise
127 // operations and broadcast the operands rather than the result. This will
128 // eventually allow for larger fusions.
129 void populateMergeAssumingOpsPatterns(MLIRContext *context,
130                                       RewritePatternSet *patterns);
131 
132 // Populate patterns for iterative shape reification.
133 void populateShapeReificationPatterns(MLIRContext *, RewritePatternSet *);
134 
135 // Populate patterns to group reduction and parallel dimensions of reduction
136 // operations and realize them through equivalent 1D or 2D reductions.
137 void populateGroupReductionDimensionsPatterns(MLIRContext *context,
138                                               RewritePatternSet *patterns,
139                                               bool preferColumnsReductions);
140 
141 /// Populate rank specialization clustering and lowering patterns.
142 void populateRankSpecializationClusterPatterns(MLIRContext *context,
143                                                RewritePatternSet *patterns);
144 void populateRankSpecializationToSCFPatterns(MLIRContext *context,
145                                              RewritePatternSet *patterns,
146                                              int64_t maxTargetRank);
147 
148 /// Populate sparse tensor specific rewriting patterns.
149 void populateSparseRewritingPatterns(RewritePatternSet *patterns,
150                                      MLIRContext *ctx);
151 
152 /// Populates sparse ops in CHLO to linalg rewriting patterns.
153 void populateLegalizeSparseChloToLinalgPatterns(MLIRContext *context,
154                                                 TypeConverter &typeConverter,
155                                                 RewritePatternSet *patterns);
156 
157 }  // namespace mhlo
158 
159 namespace chlo {
160 
161 // Populates a collection of conversion patterns for legalizing broadcasting
162 // client-HLO to their non-broadcasting counterparts.
163 void populateChloBroadcastingPatterns(MLIRContext *context,
164                                       RewritePatternSet *patterns);
165 
166 // Populates a collection of conversion patterns for legalizing client-HLO to
167 // HLO by decomposing client-operations to corresponding sequences of more
168 // primitive operations. This does not include the
169 // PopulateChloBroadcastingPatterns above.
170 void populateDecomposeChloPatterns(MLIRContext *context,
171                                    RewritePatternSet *patterns);
172 
173 }  // namespace chlo
174 
175 }  // namespace mlir
176 
177 #endif  // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H
178