1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H
17 #define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H
18
#include <cstdint>
#include <functional>
#include <memory>
21
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/DialectConversion.h"
25
26 namespace mlir {
27 namespace bufferization {
28 class BufferizeTypeConverter;
29 } // namespace bufferization
30 namespace mhlo {
31
32 // Collection of rewrite patterns for lowering a general dot product.
33 void populateGeneralDotOpLoweringPatterns(RewritePatternSet *patterns,
34 MLIRContext *ctx);
35
36 // Collection of rewrite patterns for lowering complex operations to equivalent
37 // float operations.
38 void populateComplexLoweringPatterns(MLIRContext *context,
39 RewritePatternSet *patterns);
40
41 void populateOptimizeMhloPatterns(MLIRContext *context,
42 RewritePatternSet *patterns);
43
44 // Rewrite patterns for einsum to equivalent dot_general legalization.
45 void populateEinsumToDotGeneralPatterns(mlir::MLIRContext *context,
46 RewritePatternSet *patterns);
47
48 // Rewrite patterns for gather to equivalent torch index select legalization.
49 void populateGatherToTorchIndexSelectPatterns(mlir::MLIRContext *context,
50 RewritePatternSet *patterns);
51
52 void populateMhloToStdPatterns(RewritePatternSet *patterns, MLIRContext *ctx);
53
54 // Collection of rewrite patterns for lowering all mhlo ops to their
55 // lmhlo counterparts.
56 void populateDynamicHloToLhloConversionPattern(
57 MLIRContext *context, bufferization::BufferizeTypeConverter *converter,
58 RewritePatternSet *patterns);
59
60 // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
61 void populateHloToLhloConversionPattern(
62 MLIRContext *context, bufferization::BufferizeTypeConverter *converter,
63 RewritePatternSet *patterns);
64
65 // Collection of rewrite patterns for lowering of HLO to arithmetic dialect.
66 void populateHloToArithmeticConversionPatterns(RewritePatternSet *patterns);
67
68 // Collection of rewrite patterns for lowering of shape operations from the HLO
69 // dialect to the standard dialect.
70 void populateHloShapeOpsToStandardConversionPattern(
71 MLIRContext *context, TypeConverter &typeConverter,
72 RewritePatternSet *patterns);
73
74 // Collection of rewrite patterns for lowering of HLO to Linalg dialect.
75 void populateHloToLinalgConversionPattern(MLIRContext *context,
76 TypeConverter &typeConverter,
77 RewritePatternSet *patterns);
78
79 // Collection of rewrite patterns for lowering of HLO dim operations.
80 void populateShapeComputationPatterns(MLIRContext *context,
81 RewritePatternSet *patterns);
82
83 // Converter to signless intergers to be used with linalg conversion patterns.
84 std::unique_ptr<TypeConverter> createHloToLinalgTypeConverter();
85
86 // Sets up legality definitions for materializing broadcasts.
87 void setupMaterializeBroadcastsLegality(MLIRContext *context,
88 ConversionTarget *conversionTarget);
89
90 // Populates a collection of rewrite patterns for materializing broadcast
91 // attributes to equivalent sequences of ops.
92 void populateMaterializeBroadcastsPatterns(MLIRContext *context,
93 RewritePatternSet *patterns);
94
95 // Populates a collection of rewrite patterns to realize element-wise operations
96 // on ranked tensors where possible.
97 void populateTransformUnrankedHloPatterns(MLIRContext *context,
98 RewritePatternSet *patterns);
99
100 void populateDynamicShapeFusionPatterns(MLIRContext *context,
101 RewritePatternSet *patterns);
102
103 // Populate a collection of conversion patterns for un-fusing
104 // batch_norm_inference into constituent HLO ops.
105 void populateUnfuseBatchNormInferencePattern(MLIRContext *context,
106 RewritePatternSet *patterns);
107
108 // Populate a collection of conversion patterns for un-fusing
109 // batch_norm_training into constituent HLO ops.
110 void populateUnfuseBatchNormTrainingPattern(MLIRContext *context,
111 RewritePatternSet *patterns);
112
113 // Populate a collection of conversion patterns for un-fusing
114 // // batch_norm_inference and batch_norm_training into constituent HLO ops.
populateUnfuseBatchNormPatterns(MLIRContext * context,RewritePatternSet * patterns)115 inline void populateUnfuseBatchNormPatterns(MLIRContext *context,
116 RewritePatternSet *patterns) {
117 populateUnfuseBatchNormInferencePattern(context, patterns);
118 populateUnfuseBatchNormTrainingPattern(context, patterns);
119 }
120
121 // Populates patterns that translate the trigonometric operations from the
122 // standard dialect to approximations that do not use intrinsics.
123 void populateTrigonometricToApproximationPatterns(MLIRContext *context,
124 RewritePatternSet *patterns);
125
126 // Populate patterns to prepare moving dynamic broadcasts up over element-wise
127 // operations and broadcast the operands rather than the result. This will
128 // eventually allow for larger fusions.
129 void populateMergeAssumingOpsPatterns(MLIRContext *context,
130 RewritePatternSet *patterns);
131
132 // Populate patterns for iterative shape reification.
133 void populateShapeReificationPatterns(MLIRContext *, RewritePatternSet *);
134
135 // Populate patterns to group reduction and parallel dimensions of reduction
136 // operations and realize them through equivalent 1D or 2D reductions.
137 void populateGroupReductionDimensionsPatterns(MLIRContext *context,
138 RewritePatternSet *patterns,
139 bool preferColumnsReductions);
140
141 /// Populate rank specialization clustering and lowering patterns.
142 void populateRankSpecializationClusterPatterns(MLIRContext *context,
143 RewritePatternSet *patterns);
144 void populateRankSpecializationToSCFPatterns(MLIRContext *context,
145 RewritePatternSet *patterns,
146 int64_t maxTargetRank);
147
148 /// Populate sparse tensor specific rewriting patterns.
149 void populateSparseRewritingPatterns(RewritePatternSet *patterns,
150 MLIRContext *ctx);
151
152 /// Populates sparse ops in CHLO to linalg rewriting patterns.
153 void populateLegalizeSparseChloToLinalgPatterns(MLIRContext *context,
154 TypeConverter &typeConverter,
155 RewritePatternSet *patterns);
156
157 } // namespace mhlo
158
159 namespace chlo {
160
161 // Populates a collection of conversion patterns for legalizing broadcasting
162 // client-HLO to their non-broadcasting counterparts.
163 void populateChloBroadcastingPatterns(MLIRContext *context,
164 RewritePatternSet *patterns);
165
166 // Populates a collection of conversion patterns for legalizing client-HLO to
167 // HLO by decomposing client-operations to corresponding sequences of more
168 // primitive operations. This does not include the
169 // PopulateChloBroadcastingPatterns above.
170 void populateDecomposeChloPatterns(MLIRContext *context,
171 RewritePatternSet *patterns);
172
173 } // namespace chlo
174
175 } // namespace mlir
176
177 #endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H
178