#include <torch/csrc/jit/passes/graph_rewrite_helper.h>

#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace torch::jit::graph_rewrite_helper {

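// Returns the unqualified name of the function referenced by a Value of
// FunctionType: everything after the last '.' in its qualified name, e.g.
// "foo.bar.baz" yields "baz". A name without any '.' is returned unchanged.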
std::string getFuncName(Value* func_value) {
  auto func = func_value->type()->expectRef<FunctionType>().function();
  const auto& qname = func->qualname();
  const auto& name = qname.qualifiedName();
  auto rdot_idx = name.rfind('.');
  if (rdot_idx != std::string::npos) {
    return name.substr(rdot_idx + 1, name.length());
  } else {
    return name;
  }
}

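// Resolves a pattern-variable name to the Value it matched in the target
// graph: vmap maps variable names to Values in the pattern graph, and
// match.values_map maps those pattern Values to the matched ones.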
Value* getValue(
    const std::string& name,
    const std::unordered_map<const Value*, Value*>& match_vmap,
    const std::unordered_map<std::string, Value*>& vmap) {
  return match_vmap.at(vmap.at(name));
}

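// Like getValue, but additionally folds the matched Value into a constant
// IValue; yields std::nullopt when the matched Value is not a constant.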
std::optional<IValue> getIValue(
    const std::string& name,
    const std::unordered_map<const Value*, Value*>& match_vmap,
    const std::unordered_map<std::string, Value*>& vmap) {
  return toIValue(getValue(name, match_vmap, vmap));
}

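// Gathers the constant convolution parameters (transposed, output_padding,
// stride, padding, dilation) from a match. The optionals are unwrapped
// unconditionally with value(), so these pattern inputs must have been
// constant-folded first (see the ConstantPropagation call below).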
static std::unordered_map<std::string, c10::IValue> getConvParams(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  std::unordered_map<std::string, c10::IValue> calc_values;
  const auto& match_vmap = match.values_map;
  auto transposed_value = getIValue("transposed", match_vmap, vmap).value();
  calc_values["transposed"] = transposed_value;
  auto output_padding_value =
      getIValue("output_padding", match_vmap, vmap).value();
  calc_values["output_padding"] = output_padding_value;
  auto stride_value = getIValue("stride", match_vmap, vmap).value();
  calc_values["stride"] = stride_value;
  auto padding_value = getIValue("padding", match_vmap, vmap).value();
  calc_values["padding"] = padding_value;
  auto dilation_value = getIValue("dilation", match_vmap, vmap).value();
  calc_values["dilation"] = dilation_value;
  return calc_values;
}

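// Rewrites aten::_convolution nodes into the specific aten::conv{1,2,3}d or
// aten::conv_transpose{1,2,3}d ops. Both the current 13-argument schema
// (with %allow_tf32) and the deprecated 12-argument schema are matched; the
// filters defined below select the replacement from the spatial
// dimensionality of the parameter lists and the %transposed flag.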
void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
  // TODO: remove constant prop in the pass
  ConstantPropagation(graph);
  std::string convolution_deprecated = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
            %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled)
        return (%r) )";

  std::string convolution = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
            %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32)
        return (%r) )";

  std::string conv2d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";
  std::string conv2d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv1d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";
  std::string conv1d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv3d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";
  std::string conv3d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string conv_transpose1d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose1d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose2d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose2d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose3d_for_deprecated_conv = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool):
        %r = aten::conv_transpose3d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  std::string conv_transpose3d = R"(
      graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
          %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
          %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
        %r = aten::conv_transpose3d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
        return (%r) )";

  // Filters reject matches the target op cannot represent: the
  // stride/padding/dilation/output_padding lists must have exactly as many
  // elements as the op has spatial dimensions, and %transposed must agree.
  auto filter_conv1d = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    auto calc_value_map = getConvParams(match, vmap);
    if (calc_value_map["output_padding"].toIntList().size() != 1 ||
        calc_value_map["stride"].toIntList().size() != 1 ||
        calc_value_map["padding"].toIntList().size() != 1 ||
        calc_value_map["dilation"].toIntList().size() != 1) {
      return false;
    }
    return !calc_value_map["transposed"].toBool();
  };
  auto filter_conv2d = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    auto calc_value_map = getConvParams(match, vmap);
    if (calc_value_map["output_padding"].toIntList().size() != 2 ||
        calc_value_map["stride"].toIntList().size() != 2 ||
        calc_value_map["padding"].toIntList().size() != 2 ||
        calc_value_map["dilation"].toIntList().size() != 2) {
      return false;
    }
    return !calc_value_map["transposed"].toBool();
  };
  auto filter_conv3d = [](const Match& match,
                          const std::unordered_map<std::string, Value*>& vmap) {
    auto calc_value_map = getConvParams(match, vmap);
    if (calc_value_map["output_padding"].toIntList().size() != 3 ||
        calc_value_map["stride"].toIntList().size() != 3 ||
        calc_value_map["padding"].toIntList().size() != 3 ||
        calc_value_map["dilation"].toIntList().size() != 3) {
      return false;
    }
    return !calc_value_map["transposed"].toBool();
  };
  auto filter_conv_transpose1d =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        auto calc_value_map = getConvParams(match, vmap);
        if (calc_value_map["output_padding"].toIntList().size() != 1 ||
            calc_value_map["stride"].toIntList().size() != 1 ||
            calc_value_map["padding"].toIntList().size() != 1 ||
            calc_value_map["dilation"].toIntList().size() != 1) {
          return false;
        }
        return calc_value_map["transposed"].toBool();
      };
  auto filter_conv_transpose2d =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        auto calc_value_map = getConvParams(match, vmap);
        if (calc_value_map["output_padding"].toIntList().size() != 2 ||
            calc_value_map["stride"].toIntList().size() != 2 ||
            calc_value_map["padding"].toIntList().size() != 2 ||
            calc_value_map["dilation"].toIntList().size() != 2) {
          return false;
        }
        return calc_value_map["transposed"].toBool();
      };
  auto filter_conv_transpose3d =
      [](const Match& match,
         const std::unordered_map<std::string, Value*>& vmap) {
        auto calc_value_map = getConvParams(match, vmap);
        if (calc_value_map["output_padding"].toIntList().size() != 3 ||
            calc_value_map["stride"].toIntList().size() != 3 ||
            calc_value_map["padding"].toIntList().size() != 3 ||
            calc_value_map["dilation"].toIntList().size() != 3) {
          return false;
        }
        return calc_value_map["transposed"].toBool();
      };

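  // One rewriter per target op: each registers both _convolution schemas and
  // fires only on matches accepted by the corresponding filter.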
  SubgraphRewriter rewriter_conv1d;
  rewriter_conv1d.RegisterRewritePattern(convolution, conv1d);
  rewriter_conv1d.RegisterRewritePattern(
      convolution_deprecated, conv1d_for_deprecated_conv);
  rewriter_conv1d.runOnGraph(graph, filter_conv1d);

  SubgraphRewriter rewriter_conv2d;
  rewriter_conv2d.RegisterRewritePattern(convolution, conv2d);
  rewriter_conv2d.RegisterRewritePattern(
      convolution_deprecated, conv2d_for_deprecated_conv);
  rewriter_conv2d.runOnGraph(graph, filter_conv2d);

  SubgraphRewriter rewriter_conv3d;
  rewriter_conv3d.RegisterRewritePattern(convolution, conv3d);
  rewriter_conv3d.RegisterRewritePattern(
      convolution_deprecated, conv3d_for_deprecated_conv);
  rewriter_conv3d.runOnGraph(graph, filter_conv3d);

  SubgraphRewriter rewriter_conv_transpose1d;
  rewriter_conv_transpose1d.RegisterRewritePattern(
      convolution, conv_transpose1d);
  rewriter_conv_transpose1d.RegisterRewritePattern(
      convolution_deprecated, conv_transpose1d_for_deprecated_conv);
  rewriter_conv_transpose1d.runOnGraph(graph, filter_conv_transpose1d);

  SubgraphRewriter rewriter_conv_transpose2d;
  rewriter_conv_transpose2d.RegisterRewritePattern(
      convolution, conv_transpose2d);
  rewriter_conv_transpose2d.RegisterRewritePattern(
      convolution_deprecated, conv_transpose2d_for_deprecated_conv);
  rewriter_conv_transpose2d.runOnGraph(graph, filter_conv_transpose2d);

  SubgraphRewriter rewriter_conv_transpose3d;
  rewriter_conv_transpose3d.RegisterRewritePattern(
      convolution, conv_transpose3d);
  rewriter_conv_transpose3d.RegisterRewritePattern(
      convolution_deprecated, conv_transpose3d_for_deprecated_conv);
  rewriter_conv_transpose3d.runOnGraph(graph, filter_conv_transpose3d);
}
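// A minimal usage sketch (hypothetical caller; assumes `module` is an
// already-compiled torch::jit::Module with a `forward` method):
//
//   auto graph = module.get_method("forward").graph();
//   torch::jit::graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);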
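// Returns true when a matched clamp-style activation can be fused:
// %dummy_min_max must either be non-constant or a constant None, and, for
// the aten::hardtanh/aten::hardtanh_ patterns (which bind
// %output_min/%output_max; aten::relu does not), both bounds must be
// constants.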
bool isClampFusable(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  const auto& match_vmap = match.values_map;
  TORCH_CHECK(
      vmap.find("dummy_min_max") != vmap.end(),
      "Expected to find dummy_min_max Value in the subgraph to be replaced.");
  auto dummy_min_max =
      graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);

  auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();

  // Also check that the output_min and output_max values are actually
  // constant. If hardtanh's min/max Values are not constants, we would end
  // up rerouting those values to the prepack op, and then we could not
  // remove the prepacking ops.
  if (vmap.find("output_min") != vmap.end()) {
    // The aten::relu pattern does not have output_min/output_max; the
    // aten::hardtanh/aten::hardtanh_ patterns do.
    TORCH_CHECK(
        vmap.find("output_max") != vmap.end(),
        "Expected to find output_max as well given "
        "output_min exists in the pattern graph.");
    // If output_min/max are not constant, we get std::nullopt.
    auto output_min =
        graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
    auto output_max =
        graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
    is_fusable =
        is_fusable && (output_min.has_value() && output_max.has_value());
  }

  return is_fusable;
}

} // namespace torch::jit::graph_rewrite_helper