xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/quantization/quantization_patterns.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/irange.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <torch/csrc/jit/ir/subgraph_matcher.h>
6 #include <torch/csrc/jit/jit_log.h>
7 #include <torch/csrc/jit/passes/quantization/helper.h>
8 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
9 #include <string>
10 #include <unordered_map>
11 #include <utility>
12 
13 namespace torch {
14 namespace jit {
15 
16 struct QuantFusionInfo {
17   std::string quantized_op_name;
18   std::string pattern;
19   std::string replacement;
20   std::vector<MatchFilter> filters = {};
21 };
22 
23 namespace {
// Renders `extra_args` as the tail of an IR argument list: every element is
// prefixed with ", ", so {"%b", "%c"} becomes ", %b, %c" and an empty vector
// yields "". The leading separator lets callers append the result directly
// after a first argument such as "%a_quant".
//
// The string is built by appending in place. Folding with std::accumulate and
// `acc + ", " + arg` copies the growing string on every step and depends on
// <numeric>, which this header does not include.
std::string getExtraArgList(std::vector<std::string> extra_args) {
  std::string arg_list;
  for (const auto& arg : extra_args) {
    arg_list += ", ";
    arg_list += arg;
  }
  return arg_list;
}
31 
32 // Get the pattern we want to replace the match with
33 std::string getAtenOpPattern(
34     const std::string& graph_header,
35     const std::string& op_name,
36     const std::vector<std::string>& extra_op_args,
37     bool scalar_args = false) {
38   std::vector<std::string> _extra_op_args = extra_op_args;
39   std::string aten_op_pattern = graph_header;
40   if (scalar_args) {
41     for (const auto& extra_arg : _extra_op_args) {
42       aten_op_pattern
43           .append(R"(
44           )")
45           .append(extra_arg)
46           .append("_scalar = aten::item(")
47           .append(extra_arg)
48           .append(")");
49     }
50 
51     for (auto& _extra_op_arg : _extra_op_args) {
52       _extra_op_arg.append("_scalar");
53     }
54   }
55   const auto& extra_op_arg_list = getExtraArgList(std::move(_extra_op_args));
56   aten_op_pattern += R"(
57           %r = )";
58   aten_op_pattern += op_name + "(" + "%a_quant" + extra_op_arg_list + ")";
59   aten_op_pattern += R"(
60           return (%r) )";
61   return aten_op_pattern;
62 }
63 
64 // generate ops for quantize pattern for a scalar value
getQuantizeForScalar(const std::string & value)65 std::string getQuantizeForScalar(const std::string& value) {
66   // 6 is `torch.float` ScalarType, we are creating a float scalar
67   // tensor from a scalar value
68   std::string quantize_pattern = R"(
69           )" +
70       value + "_float_scalar_type : int = prim::Constant[value=6]()";
71   quantize_pattern += R"(
72           )" +
73       value + "_none : None = prim::Constant()";
74   quantize_pattern += R"(
75           )" +
76       value + "_tensor : Tensor = aten::scalar_tensor(" + value + ", " + value +
77       "_float_scalar_type";
78   for (const auto i : c10::irange(3)) {
79     (void)i; // Suppress unused variable warning
80     quantize_pattern += ", " + value + "_none";
81   }
82   quantize_pattern += ")";
83   quantize_pattern +=
84       R"(
85           )" +
86       value + "_quant = aten::quantize_per_tensor(" + value + "_tensor" +
87       getExtraArgList(
88           {value + "_scale", value + "_zero_point", value + "_dtype"}) +
89       ")";
90   return quantize_pattern;
91 }
92 
// Emits a single IR statement, preceded by a newline and the pattern
// indentation, that dequantizes `<value>_quant` into `<value>_dequant`.
std::string getDequantize(const std::string& value) {
  std::string statement = R"(
          )";
  statement += value;
  statement += "_dequant = aten::dequantize(";
  statement += value;
  statement += "_quant)";
  return statement;
}
98 
// Emits a single IR statement, preceded by a newline and the pattern
// indentation, that extracts `<value>_scalar` (typed float) from
// `<value>_dequant` with aten::item.
std::string getItem(const std::string& value) {
  std::string statement = R"(
          )";
  statement += value;
  statement += "_scalar : float = aten::item(";
  statement += value;
  statement += "_dequant)";
  return statement;
}
104 
105 // Patterns for the ops that inherit parameters from input
getInputTensorQParamOpPattern(const std::string & op_name,const std::vector<std::string> & extra_op_args)106 std::string getInputTensorQParamOpPattern(
107     const std::string& op_name,
108     const std::vector<std::string>& extra_op_args) {
109   const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
110   std::string op_pattern = "graph(%a_quant" + extra_op_arg_list + "):" + R"(
111           %a_dequant = aten::dequantize(%a_quant)
112           %r = )" +
113       op_name + "(" + "%a_dequant" + extra_op_arg_list + ")" + R"(
114           %r_scale : float = aten::q_scale(%a_quant)
115           %r_zero_point : int = aten::q_zero_point(%a_quant)
116           %r_dtype : int = prim::dtype(%a_quant)
117           %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
118           return (%r_quant) )";
119   return op_pattern;
120 }
121 
122 // QuantFusionInfo for the ops that inherit parameters from input
getInputTensorQParamOpFusionInfo(const std::string & op_name,const std::vector<std::string> & extra_op_args)123 QuantFusionInfo getInputTensorQParamOpFusionInfo(
124     const std::string& op_name,
125     const std::vector<std::string>& extra_op_args) {
126   std::string op_pattern =
127       getInputTensorQParamOpPattern(op_name, extra_op_args);
128   const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
129   std::string graph_header = "graph(%a_quant" + extra_op_arg_list + "):";
130   std::string op_replacement =
131       getAtenOpPattern(graph_header, op_name, extra_op_args);
132 
133   return {op_name, std::move(op_pattern), std::move(op_replacement)};
134 }
135 
136 // quant fusion for ops like `quantized::add_scalar`, `quantized::mul_scalar`
137 QuantFusionInfo getBinaryOpScalarFusionInfo(
138     const std::string& op_name,
139     const std::vector<std::string>& extra_op_args,
140     const std::string& quantized_op_name,
141     const std::vector<std::string>& extra_quantized_op_args,
142     const std::vector<MatchFilter>& filters = {}) {
143   std::string op_pattern =
144       getInputTensorQParamOpPattern(op_name, extra_op_args);
145 
146   const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
147   std::string graph_header = "graph(%a_quant" + extra_op_arg_list + "):";
148   std::string op_replacement = getAtenOpPattern(
149       graph_header, quantized_op_name, extra_quantized_op_args);
150 
151   return {op_name, std::move(op_pattern), std::move(op_replacement), filters};
152 }
153 
getClampOpFusionInfo(const std::string & op_name,const std::vector<std::string> & extra_op_args)154 QuantFusionInfo getClampOpFusionInfo(
155     const std::string& op_name,
156     const std::vector<std::string>& extra_op_args) {
157   std::vector<std::string> header_args = extra_op_args;
158   std::vector<std::string> input_qparams = {"_scale", "_zero_point", "_dtype"};
159   for (const auto& arg : extra_op_args) {
160     for (const auto& qparam : input_qparams) {
161       header_args.push_back(arg + qparam);
162     }
163   }
164   for (const auto& qparam : input_qparams) {
165     header_args.push_back("%r" + qparam);
166   }
167   const auto& extra_header_arg_list = getExtraArgList(std::move(header_args));
168   std::string graph_header = "graph(%a_quant" + extra_header_arg_list + "):";
169   std::string op_pattern = graph_header;
170   for (const auto& arg : extra_op_args) {
171     op_pattern += getQuantizeForScalar(arg);
172     op_pattern += getDequantize(arg);
173     op_pattern += getItem(arg);
174   }
175   op_pattern += getDequantize("%a");
176   op_pattern += R"(
177           %r = )";
178   std::vector<std::string> scalar_extra_args;
179   scalar_extra_args.reserve(extra_op_args.size());
180   for (const auto& arg : extra_op_args) {
181     scalar_extra_args.push_back(arg + "_scalar");
182   }
183   op_pattern += op_name + "(" + "%a_dequant" +
184       getExtraArgList(std::move(scalar_extra_args)) + ")";
185   // IR pattern common to all ops that inherit qparam from input
186   op_pattern += R"(
187           %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
188           return (%r_quant) )";
189 
190   std::string aten_op_pattern =
191       getAtenOpPattern(graph_header, op_name, extra_op_args);
192 
193   return {op_name, std::move(op_pattern), std::move(aten_op_pattern)};
194 }
195 
196 // Patterns for the ops that has fixed quantization parameters
getFixedQParamOpFusionInfo(const std::string & op_name,const std::vector<std::string> & extra_op_args,bool is_symmetric)197 QuantFusionInfo getFixedQParamOpFusionInfo(
198     const std::string& op_name,
199     const std::vector<std::string>& extra_op_args,
200     bool is_symmetric) {
201   const auto& extra_op_arg_list = getExtraArgList(extra_op_args);
202   std::string graph_header = "graph(%a_quant" + extra_op_arg_list + "):";
203   std::string op_pattern = graph_header;
204   op_pattern += R"(
205           %a_dequant = aten::dequantize(%a_quant)
206           %r = )";
207   op_pattern += op_name + "(" + "%a_dequant" + extra_op_arg_list + ")";
208   // IR pattern common to all ops with fixed quantization parameters for
209   // asymetric quantization
210   std::string asym_fixed_qparam_op_suffix = R"(
211           %r_scale : float = prim::Constant[value=0.00390625]()
212           %r_zero_point : int = prim::Constant[value=0]()
213           %r_dtype : int = prim::Constant[value=13]()
214           %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
215           return (%r_quant) )";
216 
217   std::string sym_fixed_qparam_op_suffix = R"(
218           %r_scale : float = prim::Constant[value=0.0078125]()
219           %r_zero_point : int = prim::Constant[value=128]()
220           %r_dtype : int = prim::Constant[value=13]()
221           %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
222           return (%r_quant) )";
223   op_pattern +=
224       is_symmetric ? sym_fixed_qparam_op_suffix : asym_fixed_qparam_op_suffix;
225 
226   std::string aten_op_pattern =
227       getAtenOpPattern(graph_header, op_name, extra_op_args);
228 
229   return {op_name, std::move(op_pattern), std::move(aten_op_pattern)};
230 }
231 
232 // filter that checks %b_scalar is a scalar
input_b_is_scalar(const Match & match,const std::unordered_map<std::string,Value * > & vmap)233 bool input_b_is_scalar(
234     const Match& match,
235     const std::unordered_map<std::string, Value*>& vmap) {
236   const auto& match_vmap = match.values_map;
237   auto b_scalar = match_vmap.at(vmap.at("b_scalar"));
238   return isScalar(b_scalar);
239 }
240 
241 // Patterns for ops that require observation for output quantization parameters
242 // Example:
243 //
244 // before fusion:
245 //
246 // graph(%a_quant, %r_scale, %r_zero_point, %r_dtype):
247 //     %a_dequant = aten::dequantize(%a_quant)
248 //     %r = {op_name}(%a_dequant, {extra_args})
249 //     %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point,
250 //     %r_dtype) return (%r_quant)
251 //
252 // after fusion:
253 //
254 // graph(%a_quant, %r_scale, %r_zero_point, %r_dtype):
255 //     %r_quant = {quantized_op_name}(%a_quant, {extra_args}, %r_scale,
256 //     %r_zero_point) return (%r_quant)
getObservedQParamOpFusionInfo(const std::string & fp_op_name,const std::string & q_op_name,const std::vector<std::string> & fp_extra_args,const std::vector<std::string> & q_extra_args)257 QuantFusionInfo getObservedQParamOpFusionInfo(
258     const std::string& fp_op_name,
259     const std::string& q_op_name,
260     const std::vector<std::string>& fp_extra_args,
261     const std::vector<std::string>& q_extra_args) {
262   const auto& fp_extra_arg_list = getExtraArgList(fp_extra_args);
263   const auto& q_extra_arg_list = getExtraArgList(q_extra_args);
264 
265   std::string op_pattern = "graph(%a_quant" + fp_extra_arg_list +
266       ", %r_scale, %r_zero_point, %r_dtype):" + R"(
267           %a_dequant = aten::dequantize(%a_quant)
268           %r = )" +
269       fp_op_name + "(" + "%a_dequant" + fp_extra_arg_list + ")" + R"(
270           %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
271           return (%r_quant) )";
272 
273   std::string aten_op_pattern = "graph(%a_quant" + fp_extra_arg_list +
274       ", %r_scale, %r_zero_point, %r_dtype):" + R"(
275           %r_quant = )" +
276       q_op_name + "(%a_quant" + q_extra_arg_list +
277       ", %r_scale, %r_zero_point)" + R"(
278           return (%r_quant) )";
279 
280   return {q_op_name, std::move(op_pattern), std::move(aten_op_pattern)};
281 }
282 
283 } // namespace
284 
quant_fusion_pattern_and_replacements()285 static std::vector<QuantFusionInfo> quant_fusion_pattern_and_replacements() {
286   // aten::conv1d
287   std::string conv1d = R"(
288 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
289         %a_dequant = aten::dequantize(%a_quant)
290         %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params)
291         %w_dequant = aten::dequantize(%w_quant)
292         %r = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
293         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
294         return (%r_quant) )";
295 
296   // aten::conv1d - aten::relu
297   std::string conv1d_relu = R"(
298 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
299         %a_dequant = aten::dequantize(%a_quant)
300         %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params)
301         %w_dequant = aten::dequantize(%w_quant)
302         %conv_out = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
303         %r = aten::relu(%conv_out)
304         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
305         return (%r_quant) )";
306 
307   // aten::conv1d - aten::relu_
308   std::string conv1d_inplace_relu = R"(
309 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
310         %a_dequant = aten::dequantize(%a_quant)
311         %w_quant : Tensor, %b : Tensor? = quantized::conv1d_unpack(%packed_params)
312         %w_dequant = aten::dequantize(%w_quant)
313         %conv_out = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
314         %r = aten::relu_(%conv_out)
315         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
316         return (%r_quant) )";
317 
318   // quantized::conv1d
319   std::string quantized_conv1d = R"(
320 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
321         %r_quant = quantized::conv1d(%a_quant, %packed_params, %r_scale, %r_zero_point)
322         return (%r_quant) )";
323 
324   // quantized::conv1d_relu
325   std::string quantized_conv1d_relu = R"(
326 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
327         %r_quant = quantized::conv1d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
328         return (%r_quant) )";
329 
330   // aten::conv2d
331   std::string conv2d = R"(
332 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
333         %a_dequant = aten::dequantize(%a_quant)
334         %w_quant : Tensor, %b : Tensor? = quantized::conv2d_unpack(%packed_params)
335         %w_dequant = aten::dequantize(%w_quant)
336         %r = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
337         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
338         return (%r_quant) )";
339 
340   // aten::conv2d - aten::relu
341   std::string conv2d_relu = R"(
342 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
343         %a_dequant = aten::dequantize(%a_quant)
344         %w_quant : Tensor, %b : Tensor? = quantized::conv2d_unpack(%packed_params)
345         %w_dequant = aten::dequantize(%w_quant)
346         %conv_out = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
347         %r = aten::relu(%conv_out)
348         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
349         return (%r_quant) )";
350 
351   // aten::conv2d - aten::relu_
352   std::string conv2d_inplace_relu = R"(
353 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
354         %a_dequant = aten::dequantize(%a_quant)
355         %w_quant : Tensor, %b : Tensor? = quantized::conv2d_unpack(%packed_params)
356         %w_dequant = aten::dequantize(%w_quant)
357         %conv_out = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
358         %r = aten::relu_(%conv_out)
359         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
360         return (%r_quant) )";
361 
362   // quantized::conv2d
363   std::string quantized_conv2d = R"(
364 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
365         %r_quant = quantized::conv2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
366         return (%r_quant) )";
367 
368   // quantized::conv2d_relu
369   std::string quantized_conv2d_relu = R"(
370 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
371         %r_quant = quantized::conv2d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
372         return (%r_quant) )";
373 
374   // aten::conv3d
375   std::string conv3d = R"(
376 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
377         %a_dequant = aten::dequantize(%a_quant)
378         %w_quant : Tensor, %b : Tensor? = quantized::conv3d_unpack(%packed_params)
379         %w_dequant = aten::dequantize(%w_quant)
380         %r = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
381         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
382         return (%r_quant) )";
383 
384   // aten::conv3d - aten::relu
385   std::string conv3d_relu = R"(
386 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
387         %a_dequant = aten::dequantize(%a_quant)
388         %w_quant : Tensor, %b : Tensor? = quantized::conv3d_unpack(%packed_params)
389         %w_dequant = aten::dequantize(%w_quant)
390         %conv_out = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
391         %r = aten::relu(%conv_out)
392         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
393         return (%r_quant) )";
394 
395   // aten::conv3d - aten::relu_
396   std::string conv3d_inplace_relu = R"(
397 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
398         %a_dequant = aten::dequantize(%a_quant)
399         %w_quant : Tensor, %b : Tensor? = quantized::conv3d_unpack(%packed_params)
400         %w_dequant = aten::dequantize(%w_quant)
401         %conv_out = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
402         %r = aten::relu_(%conv_out)
403         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
404         return (%r_quant) )";
405 
406   // quantized::conv3d
407   std::string quantized_conv3d = R"(
408 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
409         %r_quant = quantized::conv3d(%a_quant, %packed_params, %r_scale, %r_zero_point)
410         return (%r_quant) )";
411 
412   // quantized::conv3d_relu
413   std::string quantized_conv3d_relu = R"(
414 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
415         %r_quant = quantized::conv3d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
416         return (%r_quant) )";
417 
418   // aten::conv_transpose1d
419   std::string conv_transpose1d = R"(
420 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
421         %a_dequant = aten::dequantize(%a_quant)
422         %w_quant : Tensor, %b : Tensor? = quantized::conv_transpose1d_unpack(%packed_params)
423         %w_dequant = aten::dequantize(%w_quant)
424         %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
425         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
426         return (%r_quant) )";
427 
428   // quantized::conv_transpose1d
429   std::string quantized_conv_transpose1d = R"(
430 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
431         %r_quant = quantized::conv_transpose1d(%a_quant, %packed_params, %r_scale, %r_zero_point)
432         return (%r_quant) )";
433 
434   // aten::conv_transpose2d
435   std::string conv_transpose2d = R"(
436 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
437         %a_dequant = aten::dequantize(%a_quant)
438         %w_quant : Tensor, %b : Tensor? = quantized::conv_transpose2d_unpack(%packed_params)
439         %w_dequant = aten::dequantize(%w_quant)
440         %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
441         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
442         return (%r_quant) )";
443 
444   // quantized::conv_transpose2d
445   std::string quantized_conv_transpose2d = R"(
446 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
447         %r_quant = quantized::conv_transpose2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
448         return (%r_quant) )";
449 
450   std::string add_relu = R"(
451 graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
452          %a_dequant = aten::dequantize(%a_quant)
453          %b_dequant = aten::dequantize(%b_quant)
454          %r_add = aten::add(%a_dequant, %b_dequant, %alpha)
455          %r_relu = aten::relu(%r_add)
456          %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
457          return (%r) )";
458 
459   std::string add_inplace_relu = R"(
460 graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
461          %a_dequant = aten::dequantize(%a_quant)
462          %b_dequant = aten::dequantize(%b_quant)
463          %r_add = aten::add(%a_dequant, %b_dequant, %alpha)
464          %r_relu = aten::relu_(%r_add)
465          %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
466          return (%r) )";
467 
468   std::string inplace_add_relu = R"(
469 graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
470          %a_dequant = aten::dequantize(%a_quant)
471          %b_dequant = aten::dequantize(%b_quant)
472          %r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
473          %r_relu = aten::relu(%r_add)
474          %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
475          return (%r) )";
476 
477   std::string inplace_add_inplace_relu = R"(
478 graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
479          %a_dequant = aten::dequantize(%a_quant)
480          %b_dequant = aten::dequantize(%b_quant)
481          %r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
482          %r_relu = aten::relu_(%r_add)
483          %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
484          return (%r) )";
485 
486   std::string quantized_add_relu = R"(
487 graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
488          %r = quantized::add_relu(%a_quant, %b_quant, %scale, %zero_point)
489          return (%r) )";
490 
491   // aten::linear
492   std::string linear = R"(
493 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
494         %a_dequant = aten::dequantize(%a_quant)
495         %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
496         %w_dequant = aten::dequantize(%w_quant)
497         %r = aten::linear(%a_dequant, %w_dequant, %b)
498         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
499         return (%r_quant) )";
500 
501   std::string linear_relu = R"(
502 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
503         %a_dequant = aten::dequantize(%a_quant)
504         %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
505         %w_dequant = aten::dequantize(%w_quant)
506         %linear_out = aten::linear(%a_dequant, %w_dequant, %b)
507         %r = aten::relu(%linear_out)
508         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
509         return (%r_quant) )";
510 
511   std::string linear_inplace_relu = R"(
512 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
513         %a_dequant = aten::dequantize(%a_quant)
514         %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
515         %w_dequant = aten::dequantize(%w_quant)
516         %linear_out = aten::linear(%a_dequant, %w_dequant, %b)
517         %r = aten::relu_(%linear_out)
518         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
519         return (%r_quant) )";
520 
521   // quantized::linear
522   std::string quantized_linear = R"(
523 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
524         %r = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point)
525         return (%r) )";
526 
527   std::string quantized_linear_relu = R"(
528 graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype):
529         %r = quantized::linear_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
530         return (%r) )";
531 
532   std::string cat = R"(
533 graph(%input_quant, %dim, %r_scale, %r_zero_point, %r_dtype):
534         %input_dequant = aten::dequantize(%input_quant)
535         %r = aten::cat(%input_dequant, %dim)
536         %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
537         return (%r_quant) )";
538 
539   std::string quantized_cat = R"(
540 graph(%input_quant, %dim, %r_scale, %r_zero_point, %r_dtype):
541          %r_quant = quantized::cat(%input_quant, %dim, %r_scale, %r_zero_point)
542          return (%r_quant) )";
543 
544   // aten::add
545   std::string add = R"(
546 graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
547          %a_dequant = aten::dequantize(%a_quant)
548          %b_dequant = aten::dequantize(%b_quant)
549          %r_add = aten::add(%a_dequant, %b_dequant, %alpha)
550          %r = aten::quantize_per_tensor(%r_add, %scale, %zero_point, %dtype)
551          return (%r) )";
552 
553   // TODO: add %dtype after when https://github.com/pytorch/pytorch/issues/34351
554   // is fixed
555   // quantized::add
556   std::string quantized_add = R"(
557 graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
558          %r = quantized::add(%a_quant, %b_quant, %scale, %zero_point)
559          return (%r) )";
560 
561   // aten::add_
562   std::string inplace_add = R"(
563 graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
564          %a_dequant = aten::dequantize(%a_quant)
565          %b_dequant = aten::dequantize(%b_quant)
566          %r_add = aten::add_(%a_dequant, %b_dequant, %alpha)
567          %r = aten::quantize_per_tensor(%r_add, %scale, %zero_point, %dtype)
568          return (%r) )";
569 
570   auto add_scalar = getBinaryOpScalarFusionInfo(
571       "aten::add",
572       {"%b_scalar", "%alpha"},
573       "quantized::add_scalar",
574       {"%b_scalar"},
575       {aten_add_alpha_is_one, input_b_is_scalar});
576 
577   auto add_scalar_out = getBinaryOpScalarFusionInfo(
578       "aten::add_",
579       {"%b_scalar", "%alpha"},
580       "quantized::add_scalar_out",
581       {"%b_scalar", "%a_quant"},
582       {aten_add_alpha_is_one, input_b_is_scalar});
583 
584   // quantized::add_scalar_relu -- fusing quantized::add_scalar
585   // and aten::relu
586   auto quantized_add_scalar_relu_pattern = R"(
587 graph(%a_quant, %b_scalar):
588          %r_add = quantized::add_scalar(%a_quant, %b_scalar)
589          %r = aten::relu(%r_add)
590          return (%r) )";
591 
592   auto quantized_add_scalar_inplace_relu_pattern = R"(
593 graph(%a_quant, %b_scalar):
594          %r_add = quantized::add_scalar(%a_quant, %b_scalar)
595          %r = aten::relu_(%r_add)
596          return (%r) )";
597 
598   auto quantized_add_scalar_relu_replacement = R"(
599 graph(%a_quant, %b_scalar):
600          %r = quantized::add_scalar_relu(%a_quant, %b_scalar)
601          return (%r) )";
602 
603   // quantized::add_scalar_relu_out -- fusing quantized::add_scalar_out
604   // and aten::relu
605   auto quantized_add_scalar_relu_out_pattern = R"(
606 graph(%a_quant, %b_scalar):
607          %r_add = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant)
608          %r = aten::relu(%r_add)
609          return (%r) )";
610 
611   auto quantized_add_scalar_inplace_relu_out_pattern = R"(
612 graph(%a_quant, %b_scalar):
613          %r_add = quantized::add_scalar_out(%a_quant, %b_scalar, %a_quant)
614          %r = aten::relu_(%r_add)
615          return (%r) )";
616 
617   auto quantized_add_scalar_relu_out_replacement = R"(
618 graph(%a_quant, %b_scalar):
619          %r = quantized::add_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
620          return (%r) )";
621 
622   // quantized::batch_norm
623   std::string batch_norm = R"(
624 graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
625          %a_dequant = aten::dequantize(%a_quant)
626          %r_bn = aten::batch_norm(%a_dequant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7)
627          %r = aten::quantize_per_tensor(%r_bn, %scale, %zero_point, %scalar_type)
628          return (%r) )";
629   std::string quantized_batch_norm = R"(
630 graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
631          %r = quantized::batch_norm(%a_quant, %weight, %bias, %mean, %var, %eps, %scale, %zero_point)
632          return (%r) )";
633 
634   std::string batch_norm_relu = R"(
635 graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
636          %a_dequant = aten::dequantize(%a_quant)
637          %bn_out = aten::batch_norm(%a_dequant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7)
638          %relu = aten::relu(%bn_out)
639          %r = aten::quantize_per_tensor(%relu, %scale, %zero_point, %scalar_type)
640          return (%r) )";
641   std::string batch_norm_inplace_relu = R"(
642 graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
643          %a_dequant = aten::dequantize(%a_quant)
644          %bn_out = aten::batch_norm(%a_dequant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7)
645          %relu = aten::relu_(%bn_out)
646          %r = aten::quantize_per_tensor(%relu, %scale, %zero_point, %scalar_type)
647          return (%r) )";
648 
649   std::string quantized_batch_norm_relu = R"(
650 graph(%a_quant, %weight, %bias, %mean, %var, %training, %eaf, %eps, %7, %scale, %zero_point, %scalar_type):
651          %r = quantized::batch_norm_relu(%a_quant, %weight, %bias, %mean, %var, %eps, %scale, %zero_point)
652          return (%r) )";
653 
654   // aten::mul
655   std::string mul = R"(
656 graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
657          %a_dequant = aten::dequantize(%a_quant)
658          %b_dequant = aten::dequantize(%b_quant)
659          %r_mul = aten::mul(%a_dequant, %b_dequant)
660          %r = aten::quantize_per_tensor(%r_mul, %scale, %zero_point, %dtype)
661          return (%r) )";
662 
663   // aten::mul_
664   std::string inplace_mul = R"(
665 graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
666          %a_dequant = aten::dequantize(%a_quant)
667          %b_dequant = aten::dequantize(%b_quant)
668          %r_mul = aten::mul_(%a_dequant, %b_dequant)
669          %r = aten::quantize_per_tensor(%r_mul, %scale, %zero_point, %dtype)
670          return (%r) )";
671 
672   // quantized::mul
673   std::string quantized_mul = R"(
674 graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
675          %r = quantized::mul(%a_quant, %b_quant, %scale, %zero_point)
676          return (%r) )";
677 
678   auto mul_scalar = getBinaryOpScalarFusionInfo(
679       "aten::mul",
680       {"%b_scalar"},
681       "quantized::mul_scalar",
682       {"%b_scalar"},
683       {input_b_is_scalar});
684 
685   auto mul_scalar_out = getBinaryOpScalarFusionInfo(
686       "aten::mul_",
687       {"%b_scalar"},
688       "quantized::mul_scalar_out",
689       {"%b_scalar", "%a_quant"},
690       {input_b_is_scalar});
691 
692   // quantized::mul_relu
693   std::string mul_relu = R"(
694 graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
695          %a_dequant = aten::dequantize(%a_quant)
696          %b_dequant = aten::dequantize(%b_quant)
697          %r_mul = aten::mul(%a_dequant, %b_dequant)
698          %r_relu = aten::relu(%r_mul)
699          %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
700          return (%r) )";
701 
702   std::string mul_inplace_relu = R"(
703 graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
704          %a_dequant = aten::dequantize(%a_quant)
705          %b_dequant = aten::dequantize(%b_quant)
706          %r_mul = aten::mul(%a_dequant, %b_dequant)
707          %r_relu = aten::relu_(%r_mul)
708          %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
709          return (%r) )";
710 
711   std::string inplace_mul_relu = R"(
712 graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
713          %a_dequant = aten::dequantize(%a_quant)
714          %b_dequant = aten::dequantize(%b_quant)
715          %r_mul = aten::mul_(%a_dequant, %b_dequant)
716          %r_relu = aten::relu(%r_mul)
717          %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
718          return (%r) )";
719 
720   std::string inplace_mul_inplace_relu = R"(
721 graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
722          %a_dequant = aten::dequantize(%a_quant)
723          %b_dequant = aten::dequantize(%b_quant)
724          %r_mul = aten::mul_(%a_dequant, %b_dequant)
725          %r_relu = aten::relu_(%r_mul)
726          %r = aten::quantize_per_tensor(%r_relu, %scale, %zero_point, %dtype)
727          return (%r) )";
728 
729   std::string quantized_mul_relu = R"(
730 graph(%a_quant, %b_quant, %scale, %zero_point, %dtype):
731          %r = quantized::mul_relu(%a_quant, %b_quant, %scale, %zero_point)
732          return (%r) )";
733 
734   // quantized::mul_scalar_relu -- fusing quantized::mul_scalar
735   // and aten::relu
736   auto quantized_mul_scalar_relu_pattern = R"(
737 graph(%a_quant, %b_scalar):
738          %r_mul = quantized::mul_scalar(%a_quant, %b_scalar)
739          %r = aten::relu(%r_mul)
740          return (%r) )";
741 
742   auto quantized_mul_scalar_inplace_relu_pattern = R"(
743 graph(%a_quant, %b_scalar):
744          %r_mul = quantized::mul_scalar(%a_quant, %b_scalar)
745          %r = aten::relu_(%r_mul)
746          return (%r) )";
747 
748   auto quantized_mul_scalar_relu_replacement = R"(
749 graph(%a_quant, %b_scalar):
750          %r = quantized::mul_scalar_relu(%a_quant, %b_scalar)
751          return (%r) )";
752 
753   // quantized::mul_scalar_relu_out -- fusing quantized::mul_scalarOut
754   // and aten::relu
755   auto quantized_mul_scalar_relu_out_pattern = R"(
756 graph(%a_quant, %b_scalar):
757          %r_mul = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant)
758          %r = aten::relu(%r_mul)
759          return (%r) )";
760 
761   auto quantized_mul_scalar_inplace_relu_out_pattern = R"(
762 graph(%a_quant, %b_scalar):
763          %r_mul = quantized::mul_scalar_out(%a_quant, %b_scalar, %a_quant)
764          %r = aten::relu_(%r_mul)
765          return (%r) )";
766 
767   auto quantized_mul_scalar_relu_out_replacement = R"(
768 graph(%a_quant, %b_scalar):
769          %r = quantized::mul_scalar_relu_out(%a_quant, %b_scalar, %a_quant)
770          return (%r) )";
771 
772   // quantized::elu
773   std::string elu = R"(
774 graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype):
775          %a_dequant = aten::dequantize(%a_quant)
776          %r = aten::elu(%a_dequant, %alpha, %scale, %input_scale)
777          %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
778          return (%r_quant) )";
779 
780   std::string quantized_elu = R"(
781 graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype):
782          %r_quant = quantized::elu(%a_quant, %r_scale, %r_zero_point, %alpha, %scale, %input_scale)
783          return (%r_quant) )";
784 
785   std::string elu_ = R"(
786 graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype):
787          %a_dequant = aten::dequantize(%a_quant)
788          %r = aten::elu_(%a_dequant, %alpha, %scale, %input_scale)
789          %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
790          return (%r_quant) )";
791 
792   // ============= General Ops that inherit quantization parameters from input
793   // tensor =============
794   auto avg_pool1d = getInputTensorQParamOpFusionInfo(
795       "aten::avg_pool1d",
796       {"%kernel_size",
797        "%stride",
798        "%padding",
799        "%ceil_mode",
800        "%count_include_pad"});
801 
802   auto avg_pool2d = getInputTensorQParamOpFusionInfo(
803       "aten::avg_pool2d",
804       {"%kernel_size",
805        "%stride",
806        "%padding",
807        "%ceil_mode",
808        "%count_include_pad",
809        "%divisor_override"});
810 
811   std::string common_general_value_op = R"(
812           %r_scale : float = aten::q_scale(%a_quant)
813           %r_zero_point : int = aten::q_zero_point(%a_quant)
814           %r_dtype : int = prim::dtype(%a_quant)
815           %r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
816           return (%r_quant) )";
817 
818   auto avg_pool3d = getInputTensorQParamOpFusionInfo(
819       "aten::avg_pool3d",
820       {"%kernel_size",
821        "%stride",
822        "%padding",
823        "%ceil_mode",
824        "%count_include_pad",
825        "%divisor_override"});
826 
827   auto adaptive_avg_pool1d = getInputTensorQParamOpFusionInfo(
828       "aten::adaptive_avg_pool1d", {"%output_size"});
829 
830   auto adaptive_avg_pool2d = getInputTensorQParamOpFusionInfo(
831       "aten::adaptive_avg_pool2d", {"%output_size"});
832 
833   auto adaptive_avg_pool3d = getInputTensorQParamOpFusionInfo(
834       "aten::adaptive_avg_pool3d", {"%output_size"});
835 
836   auto mean1 = getInputTensorQParamOpFusionInfo("aten::mean", {"%dim"});
837 
838   auto mean2 = getInputTensorQParamOpFusionInfo(
839       "aten::mean", {"%dim", "%keepdim", "%out"});
840 
841   auto upsample_nearest1d_vec = getInputTensorQParamOpFusionInfo(
842       "aten::upsample_nearest1d", {"%output_size", "%scale_factors"});
843 
844   auto upsample_nearest2d_vec = getInputTensorQParamOpFusionInfo(
845       "aten::upsample_nearest2d", {"%output_size", "%scale_factors"});
846 
847   auto upsample_nearest3d_vec = getInputTensorQParamOpFusionInfo(
848       "aten::upsample_nearest3d", {"%output_size", "%scale_factors"});
849 
850   auto upsample_linear1d_vec = getInputTensorQParamOpFusionInfo(
851       "aten::upsample_linear1d",
852       {"%output_size", "%align_corners", "%scale_factors"});
853 
854   auto upsample_bilinear2d_vec = getInputTensorQParamOpFusionInfo(
855       "aten::upsample_bilinear2d",
856       {"%output_size", "%align_corners", "%scale_factors"});
857 
858   auto upsample_trilinear3d_vec = getInputTensorQParamOpFusionInfo(
859       "aten::upsample_trilinear3d",
860       {"%output_size", "%align_corners", "%scale_factors"});
861 
862   auto upsample_nearest1d = getInputTensorQParamOpFusionInfo(
863       "aten::upsample_nearest1d", {"%output_size", "%scales"});
864 
865   auto upsample_nearest2d = getInputTensorQParamOpFusionInfo(
866       "aten::upsample_nearest2d", {"%output_size", "%scale_h", "%scale_w"});
867 
868   auto upsample_nearest3d = getInputTensorQParamOpFusionInfo(
869       "aten::upsample_nearest3d",
870       {"%output_size", "%scale_d", "%scale_h", "%scale_w"});
871 
872   auto upsample_linear1d = getInputTensorQParamOpFusionInfo(
873       "aten::upsample_linear1d", {"%output_size", "%align_corners", "%scales"});
874 
875   auto upsample_bilinear2d = getInputTensorQParamOpFusionInfo(
876       "aten::upsample_bilinear2d",
877       {"%output_size", "%align_corners", "%scale_h", "%scale_w"});
878 
879   auto upsample_trilinear3d = getInputTensorQParamOpFusionInfo(
880       "aten::upsample_trilinear3d",
881       {"%output_size", "%align_corners", "%scale_d", "%scale_h", "%scale_w"});
882 
883   auto clamp = getClampOpFusionInfo("aten::clamp", {"%min", "%max"});
884 
885   auto hardtanh = getClampOpFusionInfo("aten::hardtanh", {"%min", "%max"});
886 
887   auto hardtanh_ = getClampOpFusionInfo("aten::hardtanh_", {"%min", "%max"});
888 
889   auto leaky_relu =
890       getInputTensorQParamOpFusionInfo("aten::leaky_relu", {"%negative_slope"});
891 
892   auto leaky_relu_ = getInputTensorQParamOpFusionInfo(
893       "aten::leaky_relu_", {"%negative_slope"});
894 
895   // Ops with fixed quantization parameters
896   auto hardsigmoid = getFixedQParamOpFusionInfo("aten::hardsigmoid", {}, false);
897 
898   auto hardsigmoid_ =
899       getFixedQParamOpFusionInfo("aten::hardsigmoid_", {}, false);
900 
901   auto sigmoid = getFixedQParamOpFusionInfo("aten::sigmoid", {}, false);
902 
903   auto sigmoid_ = getFixedQParamOpFusionInfo("aten::sigmoid_", {}, false);
904 
905   auto tanh = getFixedQParamOpFusionInfo("aten::tanh", {}, true);
906 
907   auto tanh_ = getFixedQParamOpFusionInfo("aten::tanh_", {}, true);
908 
909   auto hardswish = getObservedQParamOpFusionInfo(
910       "aten::hardswish", "quantized::hardswish", {}, {});
911 
912   auto hardswish_ = getObservedQParamOpFusionInfo(
913       "aten::hardswish_", "quantized::hardswish", {}, {});
914 
915   auto layer_norm = getObservedQParamOpFusionInfo(
916       "aten::layer_norm",
917       "quantized::layer_norm",
918       {"%normalized_shape", "%weight", "%bias", "%eps", "%cudnn_enabled"},
919       {"%normalized_shape", "%weight", "%bias", "%eps"});
920 
921   auto group_norm = getObservedQParamOpFusionInfo(
922       "aten::group_norm",
923       "quantized::group_norm",
924       {"%num_groups", "%weight", "%bias", "%eps", "%cudnn_enabled"},
925       {"%num_groups", "%weight", "%bias", "%eps"});
926 
927   auto instance_norm = getObservedQParamOpFusionInfo(
928       "aten::instance_norm",
929       "quantized::instance_norm",
930       {"%weight",
931        "%bias",
932        "%running_mean",
933        "%running_var",
934        "%use_input_stats",
935        "%momentum",
936        "%eps",
937        "%cudnn_enabled"},
938       {"%weight", "%bias", "%eps"});
939 
940   return {
941       {"quantized::conv1d", std::move(conv1d), std::move(quantized_conv1d)},
942       {"quantized::conv1d_relu", std::move(conv1d_relu), quantized_conv1d_relu},
943       {"quantized::conv1d_relu",
944        std::move(conv1d_inplace_relu),
945        std::move(quantized_conv1d_relu)},
946       {"quantized::conv2d", std::move(conv2d), std::move(quantized_conv2d)},
947       {"quantized::conv2d_relu", std::move(conv2d_relu), quantized_conv2d_relu},
948       {"quantized::conv2d_relu",
949        std::move(conv2d_inplace_relu),
950        std::move(quantized_conv2d_relu)},
951       {"quantized::conv3d", std::move(conv3d), std::move(quantized_conv3d)},
952       {"quantized::conv3d_relu", std::move(conv3d_relu), quantized_conv3d_relu},
953       {"quantized::conv3d_relu",
954        std::move(conv3d_inplace_relu),
955        std::move(quantized_conv3d_relu)},
956       {"quantized::conv_transpose1d",
957        std::move(conv_transpose1d),
958        std::move(quantized_conv_transpose1d)},
959       {"quantized::conv_transpose2d",
960        std::move(conv_transpose2d),
961        std::move(quantized_conv_transpose2d)},
962       {"quantized::linear", std::move(linear), std::move(quantized_linear)},
963       {"quantized::linear_relu", std::move(linear_relu), quantized_linear_relu},
964       {"quantized::linear_relu",
965        std::move(linear_inplace_relu),
966        std::move(quantized_linear_relu)},
967       {"quantized::add_relu",
968        std::move(add_relu),
969        quantized_add_relu,
970        {aten_add_alpha_is_one}},
971       {"quantized::add_relu",
972        std::move(add_inplace_relu),
973        quantized_add_relu,
974        {aten_add_alpha_is_one}},
975       {"quantized::add_relu",
976        std::move(inplace_add_relu),
977        quantized_add_relu,
978        {aten_add_alpha_is_one}},
979       {"quantized::add_relu",
980        std::move(inplace_add_inplace_relu),
981        std::move(quantized_add_relu),
982        {aten_add_alpha_is_one}},
983       std::move(add_scalar),
984       std::move(add_scalar_out),
985       // note that these must come after quantized::add_scalar and
986       // quantized::add_scalar_out patterns
987       {"quantized::add_scalar_relu",
988        quantized_add_scalar_relu_pattern,
989        quantized_add_scalar_relu_replacement},
990       {"quantized::add_scalar_relu",
991        quantized_add_scalar_inplace_relu_pattern,
992        quantized_add_scalar_relu_replacement},
993       {"quantized::add_scalar_relu_out",
994        quantized_add_scalar_relu_out_pattern,
995        quantized_add_scalar_relu_out_replacement},
996       {"quantized::add_scalar_relu_out",
997        quantized_add_scalar_inplace_relu_out_pattern,
998        quantized_add_scalar_relu_out_replacement},
999       {"quantized::add",
1000        std::move(add),
1001        quantized_add,
1002        {aten_add_alpha_is_one}},
1003       {"quantized::add",
1004        std::move(inplace_add),
1005        std::move(quantized_add),
1006        {aten_add_alpha_is_one}},
1007       {"quantized::cat", std::move(cat), std::move(quantized_cat)},
1008       {"quantized::batch_norm",
1009        std::move(batch_norm),
1010        std::move(quantized_batch_norm)},
1011       {"quantized::batch_norm_relu",
1012        std::move(batch_norm_relu),
1013        quantized_batch_norm_relu},
1014       {"quantized::batch_norm_relu",
1015        std::move(batch_norm_inplace_relu),
1016        std::move(quantized_batch_norm_relu)},
1017       std::move(mul_scalar),
1018       std::move(mul_scalar_out),
1019       // note that these must come after quantized::mul_scalar and
1020       // quantized::mul_scalar_out patterns
1021       {"quantized::mul_scalar_relu",
1022        quantized_mul_scalar_relu_pattern,
1023        quantized_mul_scalar_relu_replacement},
1024       {"quantized::mul_scalar_relu",
1025        quantized_mul_scalar_inplace_relu_pattern,
1026        quantized_mul_scalar_relu_replacement},
1027       {"quantized::mul_scalar_relu_out",
1028        quantized_mul_scalar_relu_out_pattern,
1029        quantized_mul_scalar_relu_out_replacement},
1030       {"quantized::mul_scalar_relu_out",
1031        quantized_mul_scalar_inplace_relu_out_pattern,
1032        quantized_mul_scalar_relu_out_replacement},
1033       {"quantized::mul_relu", std::move(mul_relu), quantized_mul_relu},
1034       {"quantized::mul_relu", std::move(mul_inplace_relu), quantized_mul_relu},
1035       {"quantized::mul_relu", std::move(inplace_mul_relu), quantized_mul_relu},
1036       {"quantized::mul_relu",
1037        std::move(inplace_mul_inplace_relu),
1038        std::move(quantized_mul_relu)},
1039       {"quantized::mul", std::move(mul), quantized_mul},
1040       {"quantized::mul", std::move(inplace_mul), std::move(quantized_mul)},
1041       std::move(hardswish),
1042       std::move(hardswish_),
1043       std::move(layer_norm),
1044       std::move(group_norm),
1045       std::move(instance_norm),
1046       {"quantized::elu", std::move(elu), quantized_elu},
1047       {"quantized::elu_", std::move(elu_), std::move(quantized_elu)},
1048       std::move(avg_pool1d),
1049       std::move(avg_pool2d),
1050       std::move(avg_pool3d),
1051       std::move(adaptive_avg_pool1d),
1052       std::move(adaptive_avg_pool2d),
1053       std::move(adaptive_avg_pool3d),
1054       std::move(mean1),
1055       std::move(mean2),
1056       std::move(upsample_nearest1d),
1057       std::move(upsample_nearest2d),
1058       std::move(upsample_nearest3d),
1059       std::move(upsample_linear1d),
1060       std::move(upsample_bilinear2d),
1061       std::move(upsample_trilinear3d),
1062       std::move(upsample_nearest1d_vec),
1063       std::move(upsample_nearest2d_vec),
1064       std::move(upsample_nearest3d_vec),
1065       std::move(upsample_linear1d_vec),
1066       std::move(upsample_bilinear2d_vec),
1067       std::move(upsample_trilinear3d_vec),
1068       std::move(clamp),
1069       std::move(hardtanh),
1070       std::move(hardtanh_),
1071       std::move(leaky_relu),
1072       std::move(leaky_relu_),
1073       // fixed qparam ops
1074       std::move(hardsigmoid),
1075       std::move(hardsigmoid_),
1076       std::move(sigmoid),
1077       std::move(sigmoid_),
1078       std::move(tanh),
1079       std::move(tanh_),
1080   };
1081 }
1082 
1083 inline std::vector<QuantFusionInfo>
dynamic_quantized_linear_pattern_and_replacements()1084 dynamic_quantized_linear_pattern_and_replacements() {
1085   std::string linear_dynamic = R"(
1086 graph(%packed_params, %a):
1087         %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
1088         %w_dequant = aten::dequantize(%w_quant)
1089         %r = aten::linear(%a, %w_dequant, %b)
1090         return (%r) )";
1091 
1092   // This pattern ignores reduce range
1093   // Set the reduce range to default to true, since qnnpack backend ignores this
1094   // argument.
1095   std::string quantized_linear_dynamic = R"(
1096 graph(%packed_params, %a):
1097         %reduce_range : bool = prim::Constant[value=1]()
1098         %r = quantized::linear_dynamic(%a, %packed_params, %reduce_range)
1099         return (%r) )";
1100 
1101   return {
1102       {"quantized::linear_dynamic",
1103        std::move(linear_dynamic),
1104        std::move(quantized_linear_dynamic)},
1105   };
1106 }
1107 
1108 static std::vector<QuantFusionInfo>
dynamic_quant_fusion_pattern_and_replacements()1109 dynamic_quant_fusion_pattern_and_replacements() {
1110   std::string linear_dynamic = R"(
1111 graph(%packed_params, %a, %reduce_range, %a_dtype):
1112         %a_scale : float, %a_zero_point : int = aten::_choose_qparams_per_tensor(%a, %reduce_range)
1113         %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
1114         %a_dequant = aten::dequantize(%a_quant)
1115         %w_quant : Tensor, %b : Tensor? = quantized::linear_unpack(%packed_params)
1116         %w_dequant = aten::dequantize(%w_quant)
1117         %r = aten::linear(%a_dequant, %w_dequant, %b)
1118         return (%r) )";
1119 
1120   std::string quantized_linear_dynamic = R"(
1121 graph(%packed_params, %a, %reduce_range, %a_dtype):
1122         %r = quantized::linear_dynamic(%a, %packed_params, %reduce_range)
1123         return (%r) )";
1124 
1125   std::string linear_dynamic_fp16 = R"(
1126 graph(%packed_params, %a):
1127         %w_unpacked : Tensor, %b : Tensor? = quantized::linear_unpack_fp16(%packed_params)
1128         %r = aten::linear(%a, %w_unpacked, %b)
1129         return (%r) )";
1130 
1131   std::string quantized_linear_dynamic_fp16 = R"(
1132 graph(%packed_params, %a):
1133         %r = quantized::linear_dynamic_fp16(%a, %packed_params)
1134         return (%r) )";
1135 
1136   return {
1137       {"quantized::linear_dynamic",
1138        std::move(linear_dynamic),
1139        std::move(quantized_linear_dynamic)},
1140       {"quantized::linear_dynamic_fp16",
1141        std::move(linear_dynamic_fp16),
1142        std::move(quantized_linear_dynamic_fp16)},
1143   };
1144 }
1145 
linear_prepack_unpack_patterns()1146 static std::vector<QuantFusionInfo> linear_prepack_unpack_patterns() {
1147   std::string linear_with_quant = R"(
1148 graph(%a_dequant, %w_quant, %b):
1149         %w_dequant = aten::dequantize(%w_quant)
1150         %r = aten::linear(%a_dequant, %w_dequant, %b)
1151         return (%r) )";
1152 
1153   std::string linear_with_quant_prepack = R"(
1154 graph(%a_dequant, %w_quant, %b):
1155         %packed_params = quantized::linear_prepack(%w_quant, %b)
1156         %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::linear_unpack(%packed_params)
1157         %w_dequant = aten::dequantize(%w_quant_unpacked)
1158         %r = aten::linear(%a_dequant, %w_dequant, %b_unpacked)
1159         return (%r) )";
1160   std::string linear_fp16_with_cast = R"(
1161 graph(%w, %a_dq, %b):
1162         %fp16_tensor = aten::_saturate_weight_to_fp16(%w)
1163         %r = aten::linear(%a_dq, %fp16_tensor, %b)
1164         return (%r) )";
1165   std::string linear_fp16_with_prepack = R"(
1166 graph(%w, %a_dq, %b):
1167         %packed_params = quantized::linear_prepack_fp16(%w, %b)
1168         %w_unpacked : Tensor, %b_unpacked : Tensor? = quantized::linear_unpack_fp16(%packed_params)
1169         %r = aten::linear(%a_dq, %w_unpacked, %b_unpacked)
1170         return (%r) )";
1171 
1172   return {
1173       {"linear_prepack_unpack",
1174        std::move(linear_with_quant),
1175        std::move(linear_with_quant_prepack)},
1176       {"linear_fp16_prepack_unpack",
1177        std::move(linear_fp16_with_cast),
1178        std::move(linear_fp16_with_prepack)},
1179   };
1180 }
1181 
conv_prepack_unpack_patterns()1182 static std::vector<QuantFusionInfo> conv_prepack_unpack_patterns() {
1183   std::string conv1d_with_quant = R"(
1184 graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
1185         %w_dequant = aten::dequantize(%w_quant)
1186         %r = aten::conv1d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
1187         return (%r) )";
1188 
1189   std::string conv1d_with_quant_prepack = R"(
1190 graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
1191         %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv1d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
1192         %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv1d_unpack(%packed_params)
1193         %w_dequant = aten::dequantize(%w_quant_unpacked)
1194         %r = aten::conv1d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
1195         return (%r) )";
1196 
1197   std::string conv2d_with_quant = R"(
1198 graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
1199         %w_dequant = aten::dequantize(%w_quant)
1200         %r = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
1201         return (%r) )";
1202 
1203   std::string conv2d_with_quant_prepack = R"(
1204 graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
1205         %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv2d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
1206         %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv2d_unpack(%packed_params)
1207         %w_dequant = aten::dequantize(%w_quant_unpacked)
1208         %r = aten::conv2d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
1209         return (%r) )";
1210 
1211   std::string conv3d_with_quant = R"(
1212 graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
1213         %w_dequant = aten::dequantize(%w_quant)
1214         %r = aten::conv3d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
1215         return (%r) )";
1216 
1217   std::string conv3d_with_quant_prepack = R"(
1218 graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
1219         %packed_params : __torch__.torch.classes.quantized.Conv3dPackedParamsBase = quantized::conv3d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
1220         %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv3d_unpack(%packed_params)
1221         %w_dequant = aten::dequantize(%w_quant_unpacked)
1222         %r = aten::conv3d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
1223         return (%r) )";
1224 
1225   std::string conv_transpose1d_with_quant = R"(
1226 graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
1227         %w_dequant = aten::dequantize(%w_quant)
1228         %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
1229         return (%r) )";
1230 
1231   std::string conv_transpose1d_with_quant_prepack = R"(
1232 graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
1233         %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv_transpose1d_prepack(%w_quant, %b, %stride, %padding, %output_padding, %dilation, %groups)
1234         %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv_transpose1d_unpack(%packed_params)
1235         %w_dequant = aten::dequantize(%w_quant_unpacked)
1236         %r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %output_padding, %groups, %dilation)
1237         return (%r) )";
1238 
1239   std::string conv_transpose2d_with_quant = R"(
1240 graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
1241         %w_dequant = aten::dequantize(%w_quant)
1242         %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
1243         return (%r) )";
1244 
1245   std::string conv_transpose2d_with_quant_prepack = R"(
1246 graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
1247         %packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv_transpose2d_prepack(%w_quant, %b, %stride, %padding, %output_padding, %dilation, %groups)
1248         %w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv_transpose2d_unpack(%packed_params)
1249         %w_dequant = aten::dequantize(%w_quant_unpacked)
1250         %r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %output_padding, %groups, %dilation)
1251         return (%r) )";
1252 
1253   return {
1254       {"conv1d_prepack_unpack",
1255        std::move(conv1d_with_quant),
1256        std::move(conv1d_with_quant_prepack)},
1257       {"conv2d_prepack_unpack",
1258        std::move(conv2d_with_quant),
1259        std::move(conv2d_with_quant_prepack)},
1260       {"conv3d_prepack_unpack",
1261        std::move(conv3d_with_quant),
1262        std::move(conv3d_with_quant_prepack)},
1263       {"conv_transpose1d_prepack_unpack",
1264        std::move(conv_transpose1d_with_quant),
1265        std::move(conv_transpose1d_with_quant_prepack)},
1266       {"conv_transpose2d_prepack_unpack",
1267        std::move(conv_transpose2d_with_quant),
1268        std::move(conv_transpose2d_with_quant_prepack)}};
1269 }
1270 
1271 } // namespace jit
1272 } // namespace torch
1273