xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/graph_rewrite_helper.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
2 
3 #include <torch/csrc/jit/ir/subgraph_matcher.h>
4 #include <torch/csrc/jit/passes/constant_propagation.h>
5 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
6 
7 namespace torch::jit::graph_rewrite_helper {
8 
getFuncName(Value * func_value)9 std::string getFuncName(Value* func_value) {
10   auto func = func_value->type()->expectRef<FunctionType>().function();
11   const auto& qname = func->qualname();
12   const auto& name = qname.qualifiedName();
13   auto rdot_idx = name.rfind('.');
14   if (rdot_idx != std::string::npos) {
15     return name.substr(rdot_idx + 1, name.length());
16   } else {
17     return name;
18   }
19 }
20 
getValue(const std::string & name,const std::unordered_map<const Value *,Value * > & match_vmap,const std::unordered_map<std::string,Value * > & vmap)21 Value* getValue(
22     const std::string& name,
23     const std::unordered_map<const Value*, Value*>& match_vmap,
24     const std::unordered_map<std::string, Value*>& vmap) {
25   return match_vmap.at(vmap.at(name));
26 }
27 
getIValue(const std::string & name,const std::unordered_map<const Value *,Value * > & match_vmap,const std::unordered_map<std::string,Value * > & vmap)28 std::optional<IValue> getIValue(
29     const std::string& name,
30     const std::unordered_map<const Value*, Value*>& match_vmap,
31     const std::unordered_map<std::string, Value*>& vmap) {
32   return toIValue(getValue(name, match_vmap, vmap));
33 }
34 
getConvParams(const Match & match,const std::unordered_map<std::string,Value * > & vmap)35 static std::unordered_map<std::string, c10::IValue> getConvParams(
36     const Match& match,
37     const std::unordered_map<std::string, Value*>& vmap) {
38   std::unordered_map<std::string, c10::IValue> calc_values;
39   const auto& match_vmap = match.values_map;
40   auto transposed_value = getIValue("transposed", match_vmap, vmap).value();
41   calc_values["transposed"] = transposed_value;
42   auto output_padding_value =
43       getIValue("output_padding", match_vmap, vmap).value();
44   calc_values["output_padding"] = output_padding_value;
45   auto stride_value = getIValue("stride", match_vmap, vmap).value();
46   calc_values["stride"] = stride_value;
47   auto padding_value = getIValue("padding", match_vmap, vmap).value();
48   calc_values["padding"] = padding_value;
49   auto dilation_value = getIValue("dilation", match_vmap, vmap).value();
50   calc_values["dilation"] = dilation_value;
51   return calc_values;
52 }
53 
replaceConvolutionWithAtenConv(std::shared_ptr<Graph> & graph)54 void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
55   // TODO: remove constant prop in the pass
56   ConstantPropagation(graph);
57   std::string convolution_deprecated = R"(
58       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
59           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
60           %deterministic:bool, %cudnn_enabled:bool):
61         %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
62             %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled)
63         return (%r) )";
64 
65   std::string convolution = R"(
66       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
67           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
68           %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
69         %r = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
70             %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32)
71         return (%r) )";
72 
73   std::string conv2d_for_deprecated_conv = R"(
74       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
75           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
76           %deterministic:bool, %cudnn_enabled:bool):
77         %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
78         return (%r) )";
79   std::string conv2d = R"(
80       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
81           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
82           %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
83         %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
84         return (%r) )";
85 
86   std::string conv1d_for_deprecated_conv = R"(
87       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
88           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
89           %deterministic:bool, %cudnn_enabled:bool):
90         %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
91         return (%r) )";
92   std::string conv1d = R"(
93       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
94           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
95           %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
96         %r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
97         return (%r) )";
98 
99   std::string conv3d_for_deprecated_conv = R"(
100       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
101           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
102           %deterministic:bool, %cudnn_enabled:bool):
103         %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
104         return (%r) )";
105   std::string conv3d = R"(
106       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
107           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
108           %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
109         %r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
110         return (%r) )";
111 
112   std::string conv_transpose1d_for_deprecated_conv = R"(
113       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
114           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
115           %deterministic:bool, %cudnn_enabled:bool):
116         %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
117         return (%r) )";
118 
119   std::string conv_transpose1d = R"(
120       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
121           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
122           %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
123         %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
124         return (%r) )";
125 
126   std::string conv_transpose2d_for_deprecated_conv = R"(
127       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
128           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
129           %deterministic:bool, %cudnn_enabled:bool):
130         %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
131         return (%r) )";
132 
133   std::string conv_transpose2d = R"(
134       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
135           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
136           %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
137         %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
138         return (%r) )";
139 
140   std::string conv_transpose3d_for_deprecated_conv = R"(
141       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
142           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
143           %deterministic:bool, %cudnn_enabled:bool):
144         %r = aten::conv_transpose3d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
145         return (%r) )";
146 
147   std::string conv_transpose3d = R"(
148       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
149           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
150           %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
151         %r = aten::conv_transpose3d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
152         return (%r) )";
153 
154   // Filter the unsupported case
155   auto filter_conv1d = [](const Match& match,
156                           const std::unordered_map<std::string, Value*>& vmap) {
157     auto calc_value_map = getConvParams(match, vmap);
158     if (calc_value_map["output_padding"].toIntList().size() != 1 ||
159         calc_value_map["stride"].toIntList().size() != 1 ||
160         calc_value_map["padding"].toIntList().size() != 1 ||
161         calc_value_map["dilation"].toIntList().size() != 1) {
162       return false;
163     }
164     return !calc_value_map["transposed"].toBool();
165   };
166   auto filter_conv2d = [](const Match& match,
167                           const std::unordered_map<std::string, Value*>& vmap) {
168     auto calc_value_map = getConvParams(match, vmap);
169     if (calc_value_map["output_padding"].toIntList().size() != 2 ||
170         calc_value_map["stride"].toIntList().size() != 2 ||
171         calc_value_map["padding"].toIntList().size() != 2 ||
172         calc_value_map["dilation"].toIntList().size() != 2) {
173       return false;
174     }
175     return !calc_value_map["transposed"].toBool();
176   };
177   auto filter_conv3d = [](const Match& match,
178                           const std::unordered_map<std::string, Value*>& vmap) {
179     auto calc_value_map = getConvParams(match, vmap);
180     if (calc_value_map["output_padding"].toIntList().size() != 3 ||
181         calc_value_map["stride"].toIntList().size() != 3 ||
182         calc_value_map["padding"].toIntList().size() != 3 ||
183         calc_value_map["dilation"].toIntList().size() != 3) {
184       return false;
185     }
186     return !calc_value_map["transposed"].toBool();
187   };
188   auto filter_conv_transpose1d =
189       [](const Match& match,
190          const std::unordered_map<std::string, Value*>& vmap) {
191         auto calc_value_map = getConvParams(match, vmap);
192         if (calc_value_map["output_padding"].toIntList().size() != 1 ||
193             calc_value_map["stride"].toIntList().size() != 1 ||
194             calc_value_map["padding"].toIntList().size() != 1 ||
195             calc_value_map["dilation"].toIntList().size() != 1) {
196           return false;
197         }
198         return calc_value_map["transposed"].toBool();
199       };
200   auto filter_conv_transpose2d =
201       [](const Match& match,
202          const std::unordered_map<std::string, Value*>& vmap) {
203         auto calc_value_map = getConvParams(match, vmap);
204         if (calc_value_map["output_padding"].toIntList().size() != 2 ||
205             calc_value_map["stride"].toIntList().size() != 2 ||
206             calc_value_map["padding"].toIntList().size() != 2 ||
207             calc_value_map["dilation"].toIntList().size() != 2) {
208           return false;
209         }
210         return calc_value_map["transposed"].toBool();
211       };
212   auto filter_conv_transpose3d =
213       [](const Match& match,
214          const std::unordered_map<std::string, Value*>& vmap) {
215         auto calc_value_map = getConvParams(match, vmap);
216         if (calc_value_map["output_padding"].toIntList().size() != 3 ||
217             calc_value_map["stride"].toIntList().size() != 3 ||
218             calc_value_map["padding"].toIntList().size() != 3 ||
219             calc_value_map["dilation"].toIntList().size() != 3) {
220           return false;
221         }
222         return calc_value_map["transposed"].toBool();
223       };
224 
225   SubgraphRewriter rewriter_conv1d;
226   rewriter_conv1d.RegisterRewritePattern(convolution, conv1d);
227   rewriter_conv1d.RegisterRewritePattern(
228       convolution_deprecated, conv1d_for_deprecated_conv);
229   rewriter_conv1d.runOnGraph(graph, filter_conv1d);
230 
231   SubgraphRewriter rewriter_conv2d;
232   rewriter_conv2d.RegisterRewritePattern(convolution, conv2d);
233   rewriter_conv2d.RegisterRewritePattern(
234       convolution_deprecated, conv2d_for_deprecated_conv);
235   rewriter_conv2d.runOnGraph(graph, filter_conv2d);
236 
237   SubgraphRewriter rewriter_conv3d;
238   rewriter_conv3d.RegisterRewritePattern(convolution, conv3d);
239   rewriter_conv3d.RegisterRewritePattern(
240       convolution_deprecated, conv3d_for_deprecated_conv);
241   rewriter_conv3d.runOnGraph(graph, filter_conv3d);
242 
243   SubgraphRewriter rewriter_conv_transpose1d;
244   rewriter_conv_transpose1d.RegisterRewritePattern(
245       convolution, conv_transpose1d);
246   rewriter_conv_transpose1d.RegisterRewritePattern(
247       convolution_deprecated, conv_transpose1d_for_deprecated_conv);
248   rewriter_conv_transpose1d.runOnGraph(graph, filter_conv_transpose1d);
249 
250   SubgraphRewriter rewriter_conv_transpose2d;
251   rewriter_conv_transpose2d.RegisterRewritePattern(
252       convolution, conv_transpose2d);
253   rewriter_conv_transpose2d.RegisterRewritePattern(
254       convolution_deprecated, conv_transpose2d_for_deprecated_conv);
255   rewriter_conv_transpose2d.runOnGraph(graph, filter_conv_transpose2d);
256 
257   SubgraphRewriter rewriter_conv_transpose3d;
258   rewriter_conv_transpose3d.RegisterRewritePattern(
259       convolution, conv_transpose3d);
260   rewriter_conv_transpose3d.RegisterRewritePattern(
261       convolution_deprecated, conv_transpose3d_for_deprecated_conv);
262   rewriter_conv_transpose3d.runOnGraph(graph, filter_conv_transpose3d);
263 }
264 
isClampFusable(const Match & match,const std::unordered_map<std::string,Value * > & vmap)265 bool isClampFusable(
266     const Match& match,
267     const std::unordered_map<std::string, Value*>& vmap) {
268   const auto& match_vmap = match.values_map;
269   TORCH_CHECK(
270       vmap.find("dummy_min_max") != vmap.end(),
271       "Expected to find dummy_min_max Value in the subgraph to be replaced.");
272   auto dummy_min_max =
273       graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);
274 
275   auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();
276 
277   // Also check if the output_min and output_max values are actually constant.
278   // If hardtanh's min/max Value's are not actually constants, we will end up
279   // rerouting those values to prepack op. And if they are not constants
280   // we will not be able to remove prepacking ops.
281   if (vmap.find("output_min") != vmap.end()) {
282     // aten::relu pattern does not have output_min/output_max.
283     // aten::hardtanh/_ does.
284     TORCH_CHECK(
285         vmap.find("output_max") != vmap.end(),
286         "Expected to find output_max as well given "
287         "output_min exist in pattern graph.");
288     // If output_min/max are not constant, we get std::nullopt.
289     auto output_min =
290         graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
291     auto output_max =
292         graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
293     is_fusable =
294         is_fusable && (output_min.has_value() && output_max.has_value());
295   }
296 
297   return is_fusable;
298 }
299 
300 } // namespace torch::jit::graph_rewrite_helper
301