#include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h>

#include <ATen/native/quantized/PackedParams.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

// TODO: Switch to per operator headers after
// https://github.com/pytorch/pytorch/pull/68693 is merged
#include <ATen/Functions.h>

using ::c10::Dispatcher;

namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;

}

// Get the scale of the input to a quantized op. There are two cases here:
// 1. For ops with an output_scale specified in the op signature, we return
//    that output scale.
// 2. For ops with no output scale in the op signature (like quantized::relu),
//    we traverse up the graph to get the scale from their input until we hit
//    a node where the scale is explicitly specified.
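//
// For example, given the fragment
//   %q = aten::quantize_per_tensor(%x, %x_scale, %x_zero_point, %dtype)
//   %r = aten::relu(%q)
// asking for the scale of %r recurses through aten::relu (a "noscale" op)
// into aten::quantize_per_tensor and returns %x_scale.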
double getScaleFromInput(Node* input_node) {
  std::optional<IValue> scale;
  std::string input_name = input_node->kind().toQualString();
  std::unordered_set<std::string> noscale_ops = {
      "quantized::max_pool2d",
      "aten::max_pool2d",
      "aten::relu",
      "prim::ListUnpack",
      "aten::split_with_sizes",
      "quantized::nchw2nhwc",
      "quantized::nhwc2nchw",
      "aten::slice",
      "aten::avg_pool2d",
      "quantized::cat",
      "prim::ListConstruct",
      "aten::upsample_nearest2d",
      "aten::sigmoid",
      "aten::reshape"};
  if (input_name == "aten::quantize_per_tensor") {
    TORCH_CHECK(
        input_node->inputs().size() > 1,
        "aten::quantize_per_tensor expected scale to be 2nd input");
    scale = toIValue(input_node->inputs()[1]);
    return scale.value().toDouble();
  } else if (input_name == "quantized::linear") {
    // %r = quantized::linear(%input, %packed_weight, %w_scale, %w_zero_point)
    TORCH_CHECK(
        input_node->inputs().size() > 2,
        "quantized::linear expected scale to be 3rd input");
    scale = toIValue(input_node->inputs()[2]);
    return scale.value().toDouble();
  } else if (input_name == "quantized::conv2d") {
    // %r = quantized::conv2d(%input, %packed_weight, %w_scale, %w_zero_point)
    TORCH_CHECK(
        input_node->inputs().size() > 2,
        "quantized::conv2d expected scale to be 3rd input");
    auto num_inputs = input_node->inputs().size();
    scale = toIValue(input_node->inputs()[num_inputs - 2]);
    return scale.value().toDouble();
  } else if (input_name == "quantized::conv2d_relu") {
    // %r = quantized::conv2d_relu(%input, %packed_weight, %w_scale,
    // %w_zero_point)
    TORCH_CHECK(
        input_node->inputs().size() > 2,
        "quantized::conv2d_relu expected scale to be 3rd input");
    auto num_inputs = input_node->inputs().size();
    scale = toIValue(input_node->inputs()[num_inputs - 2]);
    return scale.value().toDouble();
  } else if (input_name == "quantized::add") {
    // %r = quantized::add(%input_a, %input_b, %w_scale, %w_zero_point)
    TORCH_CHECK(
        input_node->inputs().size() > 2,
        "quantized::add expected scale to be 3rd input");
    scale = toIValue(input_node->inputs()[2]);
    return scale.value().toDouble();
  } else if (input_name == "aten::sigmoid") {
    // For the _caffe2::Int8Sigmoid op the output scale is 1.0/256
    // and the output zero_point is set to 0 (quint8 type).
    return 1.0L / 256;
  }
  // For the ops below the scale is not part of the op signature, so we traverse
  // up the graph to get the scale from their input when defined in the graph.
  else if (noscale_ops.find(input_name) != noscale_ops.end()) {
    return getScaleFromInput(input_node->inputs()[0]->node());
  }
  TORCH_INTERNAL_ASSERT(
      false,
      "Unrecognized quantized operator while trying to compute q_scale for operator ",
      input_name);
}

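// Builds prim::Constant nodes holding the unpacked weight data, its scale(s),
// zero_point(s) and (for per-channel quantization) the axis. The nodes are
// returned in that order so the caller can group them into a single tuple
// input; for per-tensor quantization the axis constant is given a None type
// instead.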
std::vector<Node*> CreateQuantizedWeights(
    std::shared_ptr<Graph>& graph,
    const at::Tensor& weight,
    int8_t* data,
    const std::vector<int64_t>& shapes,
    const std::vector<int64_t>& strides) {
  auto qscheme = weight.qscheme();
  std::vector<Node*> unpacked_wt;

  // Retrieve scales and zero_points. Their formats are different depending on
  // different weight qscheme.
  std::vector<float> scale_data;
  std::vector<int64_t> scale_shapes;
  std::vector<int64_t> zero_point_data;
  std::vector<int64_t> zero_point_shapes;
  std::vector<int64_t> axis_data;
  switch (qscheme) {
    case c10::kPerTensorAffine: {
      // Cast to float since ONNX (De)QuantizeLinear only supports float scale.
      scale_data = {static_cast<float>(weight.q_scale())};
      scale_shapes = {1};
      zero_point_data = {weight.q_zero_point()};
      zero_point_shapes = {1};
      break;
    }
    case c10::kPerChannelAffine:
    case c10::kPerChannelAffineFloatQParams: {
      auto q_scales = weight.q_per_channel_scales();
      auto* scale_data_raw = q_scales.const_data_ptr<double>();
      scale_shapes = q_scales.sizes().vec();
      TORCH_INTERNAL_ASSERT(
          scale_shapes.size() == 1,
          "quantized per channel scales are expected as 1-d array.");
      scale_data.resize(scale_shapes[0]);
      // Cast to float since ONNX (De)QuantizeLinear only supports float scale.
      std::transform(
          scale_data_raw,
          scale_data_raw + scale_shapes[0],
          scale_data.begin(),
          [](double x) { return static_cast<float>(x); });

      auto q_zero_points = weight.q_per_channel_zero_points();
      auto* zero_point_data_raw = q_zero_points.const_data_ptr<int64_t>();
      zero_point_shapes = q_zero_points.sizes().vec();
      TORCH_INTERNAL_ASSERT(
          zero_point_shapes.size() == 1,
          "quantized per channel zero points are expected as 1-d array.");
      zero_point_data = std::vector<int64_t>(
          zero_point_data_raw, zero_point_data_raw + zero_point_shapes[0]);
      axis_data = {weight.q_per_channel_axis()};
      break;
    }
    default:
      TORCH_CHECK(
          false, "Unsupported qscheme for weight, got ", toString(qscheme));
  }

  Node* data_node = graph->create(prim::Constant);
  auto data_value =
      at::from_blob(
          data, c10::IntArrayRef(shapes), c10::IntArrayRef(strides), at::kChar)
          .to(at::kCPU);
  // Need clone because at::from_blob does not take ownership of data.
  data_node->t_(Symbol::attr("value"), data_value.clone());

  Node* scale_node = graph->create(prim::Constant);
  auto scale_value =
      at::from_blob(
          scale_data.data(), c10::IntArrayRef(scale_shapes), at::kFloat)
          .to(at::kCPU);
  scale_node->t_(Symbol::attr("value"), scale_value.clone());

  Node* zero_point_node = graph->create(prim::Constant);
  auto zero_point_value =
      at::from_blob(
          zero_point_data.data(), c10::IntArrayRef(zero_point_shapes), at::kInt)
          .to(at::kCPU);
  zero_point_node->t_(Symbol::attr("value"), zero_point_value.clone());

  Node* axis_node = graph->create(prim::Constant);
  if (!axis_data.empty()) {
    auto axis_value =
        at::from_blob(
            axis_data.data(), c10::IntArrayRef(axis_data.size()), at::kLong)
            .to(at::kCPU);
    axis_node->t_(attr::value, axis_value.clone());
  } else {
    axis_node->output()->setType(NoneType::get());
  }

  return {data_node, scale_node, zero_point_node, axis_node};
}

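// Wraps the float bias values in a prim::Constant node. The values are copied
// into a freshly allocated CPU tensor because at::from_blob does not take
// ownership of the underlying buffer.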
Node* CreateQuantizedBias(
    std::vector<float> data,
    std::shared_ptr<Graph>& graph,
    const std::vector<int64_t>& shapes) {
  Node* const_node_1 = graph->create(prim::Constant);
  auto const_bias =
      at::from_blob(data.data(), c10::IntArrayRef(shapes), at::kFloat)
          .to(at::kCPU);
  auto options = c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
  at::Tensor const_bias_copy = at::empty(c10::IntArrayRef(shapes), options);
  const_bias_copy.copy_(const_bias);
  const_node_1->t_(Symbol::attr("value"), const_bias_copy);
  return const_node_1;
}

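// Helpers that create onnx::Constant nodes holding an int list and a single
// int respectively (used below for the conv stride/padding/dilation/groups
// arguments).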
Node* createIntTuple(
    const std::vector<int64_t>& is,
    std::shared_ptr<Graph>& graph) {
  Node* const_node = graph->create(Symbol::onnx("Constant"));
  const_node->is_(Symbol::attr("value"), is);
  return const_node;
}

Node* createInt(int64_t i, std::shared_ptr<Graph>& graph) {
  Node* const_node = graph->create(Symbol::onnx("Constant"));
  const_node->i_(Symbol::attr("value"), i);
  return const_node;
}

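// Replaces the packed_params input of `node` (input index 1) with a
// prim::TupleConstruct of the unpacked weight constants produced by
// CreateQuantizedWeights.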
void ConvertQuantizedWeight(
    std::shared_ptr<Graph>& graph,
    Node* node,
    at::Tensor& weight) {
  std::vector<int64_t> wt_sizes = weight.sizes().vec();
  std::vector<int64_t> wt_strides = weight.strides().vec();
  // Remove packed_params
  node->removeInput(1);

  auto* wt_data =
      reinterpret_cast<int8_t*>(weight.mutable_data_ptr<c10::qint8>());

  std::vector<Node*> unpacked_wt =
      CreateQuantizedWeights(graph, weight, wt_data, wt_sizes, wt_strides);
  graph->setInsertPoint(node);
  Node* quant_node = graph->create(prim::TupleConstruct);
  for (auto* n : unpacked_wt) {
    n->insertBefore(node);
    quant_node->addInput(n->output());
  }
  quant_node->insertBefore(node);
  node->insertInput(1, quant_node->output());
}

// CONV1D needs a different unpacking from CONV, since it is intentionally
// packed as CONV2D in the first place.
// See: https://github.com/pytorch/pytorch/pull/38248
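// For example, a Conv1d weight of shape (out_channels, in_channels, kernel)
// is stored packed with shape (out_channels, in_channels, 1, kernel); the
// CONV1D path below drops the leading spatial entry of each conv attribute
// and squeezes dim 2 of the weight to recover the 1-d form.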
enum class QuantizedParamsType { CONV1D, CONV, LINEAR };

// This is called before the onnx pass. Using pattern matching we
// find the relevant nodes and extract the packed_params. The packed_params are
// passed to the appropriate unpack function using c10::Dispatcher. We insert
// the unpacked weights and bias into the graph as prim::Constant nodes.
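//
// For example, a matched
//   %r = quantized::linear(%input, %packed_weight, %w_scale, %w_zero_point)
// is rewritten so that %packed_weight is replaced by a prim::TupleConstruct of
// (weight, weight_scale, weight_zero_point, axis) constants and an explicit
// bias constant is inserted as the third input. For conv patterns the
// stride/padding/(output_padding)/dilation lists and the groups value are
// inserted as additional inputs as well.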
void unpackQuantizedWeightsHelper(
    std::shared_ptr<Graph>& graph,
    std::map<std::string, IValue>& paramsDict,
    const std::string& pattern,
    const std::string& unpack_fn,
    QuantizedParamsType params_type,
    bool expect_output_padding = false) {
  Graph pattern_graph;
  std::unordered_map<std::string, Value*> vmap;
  parseIR(pattern, &pattern_graph, vmap);
  const auto& matches = findPatternMatches(pattern_graph, *graph);

  for (const auto& match : matches) {
    auto match_vmap = match.values_map;
    auto qlinear_node = match_vmap.at(vmap.at("r"))->node();
    std::string quantized_weight =
        match_vmap.at(vmap.at("r"))->node()->inputs()[1]->debugName();

    auto itr = paramsDict.find(quantized_weight);
    if (itr == paramsDict.end()) {
      throw std::runtime_error(
          "getValues: Quantized weight value not found amongst constant parameters.");
    }
    at::Tensor unpacked_weight;
    std::optional<at::Tensor> bias;
    constexpr int64_t stride_idx = 2;
    constexpr int64_t padding_idx = 3;
    int64_t output_padding_idx = 0;
    int64_t dilation_idx = 0;
    int64_t groups_idx = 0;
    if (expect_output_padding) {
      output_padding_idx = 4;
      dilation_idx = 5;
      groups_idx = 6;
    } else {
      dilation_idx = 4;
      groups_idx = 5;
    }
    std::optional<torch::List<int64_t>> stride, padding, dilation,
        output_padding;
    std::optional<int64_t> groups;
    std::optional<int64_t> transpose;

    torch::List<int64_t> stride_int, padding_int, dilation_int,
        output_padding_int;

    if (itr->second.isTuple()) {
      // Pre-unpacked weights. Comes from Conv/Linear weights which are
      // stored as bound C++ classes.
      auto ser_tup = itr->second.toTuple();

      if (params_type == QuantizedParamsType::CONV &&
          ser_tup->elements()[0].isInt()) {
        const auto& elements = ser_tup->elements();
        auto version = elements[0].toInt();
        TORCH_INTERNAL_ASSERT(version == 3, "Unknown serialization version");
        TORCH_INTERNAL_ASSERT(elements.size() == 3, "Wrong tuple size.");

        auto config_vals = elements[1].to<std::vector<int64_t>>();
        auto tensors = elements[2].to<std::vector<std::optional<at::Tensor>>>();

        std::optional<at::Tensor> weight = tensors[1];
        TORCH_INTERNAL_ASSERT(
            weight, "Weight should always be present in serialized qconv.");
        unpacked_weight = *weight;
        bias = tensors[2];

        const int64_t kSpatialDim = config_vals.at(0);
        // skip kSpatialDim
        unsigned idx = 1;
        for (const auto i : c10::irange(kSpatialDim)) {
          (void)i; // Suppress unused variable warning
          stride_int.emplace_back(config_vals.at(idx));
          idx++;
        }
        for (const auto i : c10::irange(kSpatialDim)) {
          (void)i; // Suppress unused variable warning
          padding_int.emplace_back(config_vals.at(idx));
          idx++;
        }
        for (const auto i : c10::irange(kSpatialDim)) {
          (void)i; // Suppress unused variable warning
          dilation_int.emplace_back(config_vals.at(idx));
          idx++;
        }
        for (const auto i : c10::irange(kSpatialDim)) {
          (void)i; // Suppress unused variable warning
          output_padding_int.emplace_back(config_vals.at(idx));
          idx++;
        }
        int64_t groups_int = config_vals.at(idx);
        idx++;
        int64_t flags = config_vals.at(idx);
        idx++;
        TORCH_INTERNAL_ASSERT(
            idx == config_vals.size(),
            "Unexpected length of config_vals, expected ",
            idx,
            " got ",
            config_vals.size());

        bool transpose_int = flags & (1 << 0);

        int64_t other_flags = flags & ~(1 << 0);
        TORCH_CHECK(other_flags == 0, "Unexpected flags set in ", flags, ".");

        stride = stride_int;
        padding = padding_int;
        dilation = dilation_int;
        groups = groups_int;
        transpose = transpose_int;
        if (expect_output_padding) {
          output_padding = output_padding_int;
        }
      } else if (
          (params_type == QuantizedParamsType::CONV ||
           params_type == QuantizedParamsType::CONV1D) &&
          ser_tup->elements()[0].isString()) {
        const auto& elements = ser_tup->elements();
        auto version = elements[0].toStringRef();
        TORCH_INTERNAL_ASSERT(version == "2", "Unknown serialization version");
        std::vector<at::Tensor> non_optional = elements[1].toTensorVector();

        at::Tensor conv_params_packed = non_optional[0];
        unpacked_weight = non_optional[1];

        const int64_t kSpatialDim = conv_params_packed[0].item<int64_t>();
        // skip kSpatialDim
        int64_t idx = 1;
        // kSpatialDim is 2 even for Conv1D, since the torch op packs Conv1D as
        // Conv2D, so we need a special unpack for Conv1D that drops the extra
        // Conv2D dim.
        // See: https://github.com/pytorch/pytorch/pull/38248
        for (const auto i : c10::irange(kSpatialDim)) {
          if (params_type != QuantizedParamsType::CONV1D || i != 0) {
            stride_int.emplace_back(conv_params_packed[idx].item<int64_t>());
          }
          idx++;
        }
        for (const auto i : c10::irange(kSpatialDim)) {
          if (params_type != QuantizedParamsType::CONV1D || i != 0) {
            padding_int.emplace_back(conv_params_packed[idx].item<int64_t>());
          }
          idx++;
        }
        for (const auto i : c10::irange(kSpatialDim)) {
          if (params_type != QuantizedParamsType::CONV1D || i != 0) {
            dilation_int.emplace_back(conv_params_packed[idx].item<int64_t>());
          }
          idx++;
        }
        for (const auto i : c10::irange(kSpatialDim)) {
          if (params_type != QuantizedParamsType::CONV1D || i != 0) {
            output_padding_int.emplace_back(
                conv_params_packed[idx].item<int64_t>());
          }
          idx++;
        }
        auto groups_int = conv_params_packed[idx].item<int64_t>();
        idx++;
        auto transpose_int = conv_params_packed[idx].item<int64_t>();
        idx++;
        TORCH_INTERNAL_ASSERT(
            idx == conv_params_packed.numel(),
            "Unexpected length of conv_params_packed, expected ",
            idx,
            " got ",
            conv_params_packed.numel());

        torch::List<c10::IValue> optional = elements[2].toList();
        bias = optional.get(0).toOptional<at::Tensor>();

        if (params_type == QuantizedParamsType::CONV1D) {
          unpacked_weight = unpacked_weight.squeeze_(2);
        }
        stride = stride_int;
        padding = padding_int;
        dilation = dilation_int;
        groups = groups_int;
        transpose = transpose_int;
        if (expect_output_padding) {
          output_padding = output_padding_int;
        }
      } else { // Legacy
        unpacked_weight = ser_tup->elements()[0].toTensor();
        bias = ser_tup->elements()[1].toOptional<at::Tensor>();
        // conv only parameters
        if (ser_tup->elements().size() > 2) {
          auto stride_ivalue = ser_tup->elements()[stride_idx].toListRef();
          auto padding_ivalue = ser_tup->elements()[padding_idx].toListRef();
          auto dilation_ivalue = ser_tup->elements()[dilation_idx].toListRef();
          auto groups_ivalue = ser_tup->elements()[groups_idx];

          for (const auto& s : stride_ivalue) {
            stride_int.emplace_back(s.toTensor()[0].item<int64_t>());
          }
          for (const auto& p : padding_ivalue) {
            padding_int.emplace_back(p.toTensor()[0].item<int64_t>());
          }
          for (const auto& d : dilation_ivalue) {
            dilation_int.emplace_back(d.toTensor()[0].item<int64_t>());
          }
          groups = groups_ivalue.toTensor()[0].item<int64_t>();
          stride = stride_int;
          padding = padding_int;
          dilation = dilation_int;

          if (expect_output_padding) {
            auto output_padding_ivalue =
                ser_tup->elements()[output_padding_idx].toListRef();
            for (const auto& d : output_padding_ivalue) {
              output_padding_int.emplace_back(d.toTensor()[0].item<int64_t>());
            }
            output_padding = output_padding_int;
          }
        }
      }
    } else {
      TORCH_INTERNAL_ASSERT(itr->second.isTensor());
      at::Tensor packed_weight = itr->second.toTensor();
      auto op = Dispatcher::singleton()
                    .findSchemaOrThrow(unpack_fn.c_str(), "")
                    .typed<std::tuple<at::Tensor, std::optional<at::Tensor>>(
                        at::Tensor)>();
      std::tie(unpacked_weight, bias) = op.call(packed_weight);
    }

    ConvertQuantizedWeight(graph, qlinear_node, unpacked_weight);

    // Add bias
    at::Tensor original_bias;
    if (bias.has_value()) {
      original_bias = bias.value();
      original_bias.set_requires_grad(false);
    } else {
      int64_t bias_size = unpacked_weight.size(0);
      original_bias =
          at::zeros(bias_size, unpacked_weight.options().dtype(at::kFloat));
    }

    auto input_val = match_vmap.at(vmap.at("r"))->node()->inputs()[0];
    TORCH_INTERNAL_ASSERT(
        input_val->type()->isSubtypeOf(*TensorType::get()),
        "Unsupported input type. Expected TensorType, got ",
        input_val->type()->str());

    std::vector<float> bias_values(original_bias.numel());
    auto bias_data = original_bias.const_data_ptr<float>();
    for (const auto i : c10::irange(original_bias.numel())) {
      bias_values[i] = bias_data[i];
    }
    Node* bias_node =
        CreateQuantizedBias(bias_values, graph, original_bias.sizes().vec());
    bias_node->insertBefore(qlinear_node);
    // For quantized_linear inputs, the order is input, weight, bias, ....
    // Therefore bias is at location 2.
    qlinear_node->insertInput(2, bias_node->output());

    // add conv arguments: stride, padding, dilation, groups, output_padding
    if (stride.has_value() && padding.has_value() && dilation.has_value() &&
        groups.has_value() &&
        (!expect_output_padding || output_padding.has_value())) {
      std::vector<std::optional<torch::List<int64_t>>> conv_ints_args;
      conv_ints_args.push_back(stride);
      conv_ints_args.push_back(padding);
      if (expect_output_padding) {
        conv_ints_args.push_back(output_padding);
      }
      conv_ints_args.push_back(dilation);
      // skip (input, weight, bias)
      const size_t arg_offset = 3;
      for (const auto i : c10::irange(conv_ints_args.size())) {
        Node* ints_node =
            createIntTuple(conv_ints_args[i].value().vec(), graph);
        ints_node->insertBefore(qlinear_node);
        qlinear_node->insertInput(arg_offset + i, ints_node->output());
      }
      Node* groups_node = createInt(groups.value(), graph);
      groups_node->insertBefore(qlinear_node);
      qlinear_node->insertInput(groups_idx + 1, groups_node->output());
    }
    auto b = graph->block();
    auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
    eraseUnusedValuesFromMap(valsToParamsMap);
  }
}

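// Maps each quantized scalar type to the plain integer type used for the
// unpacked value tensor of a quantized graph input.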
static std::
    unordered_map<c10::ScalarType, c10::ScalarType, ScalarTypeHashFunction>
        qTypeToValType = {
            {c10::ScalarType::QInt8, c10::ScalarType::Char},
            {c10::ScalarType::QUInt8, c10::ScalarType::Byte},
            {c10::ScalarType::QInt32, c10::ScalarType::Int},
            {c10::ScalarType::QUInt4x2, c10::ScalarType::Byte},
};

// Unpack quantized tensor inputs into {value, scale, zero_point},
// then create a prim::TupleConstruct node based on these three values.
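//
// For example, a quint8 graph input %x is replaced by three new inputs
// %x_value (Byte tensor), %x_scale (double scalar tensor) and %x_zero_point
// (long scalar tensor); the tuple of these three replaces every use of %x and
// the original quantized input is erased.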
void UnpackQuantizedTensorInputs(std::shared_ptr<Graph>& graph) {
  for (size_t index = 0; index < graph->inputs().size();) {
    auto g_input = graph->inputs()[index];
    TensorTypePtr shape_type = g_input->type()->cast<TensorType>();
    if (!shape_type || !shape_type->scalarType().has_value()) {
      index++;
      continue;
    }
    auto scalar_type = shape_type->scalarType().value();
    if (qTypeToValType.find(scalar_type) == qTypeToValType.end()) {
      index++;
      continue;
    }
    std::string input_name = g_input->debugName();
    auto input_value =
        graph->insertInput(index, input_name + "_value")
            ->setType(shape_type->withScalarType(qTypeToValType[scalar_type]));
    // scale and zero_point type can be found at torch/include/ATen/Operators.h
    auto input_scale =
        graph->insertInput(index + 1, input_name + "_scale")
            ->setType(TensorType::create(
                at::kDouble, at::kCPU, 0, /*requires_grad=*/std::nullopt));
    auto input_zero_point =
        graph->insertInput(index + 2, input_name + "_zero_point")
            ->setType(TensorType::create(
                at::kLong, at::kCPU, 0, /*requires_grad=*/std::nullopt));
    std::vector<Value*> converted{input_value, input_scale, input_zero_point};
    auto input_tuple =
        graph->prependNode(graph->createTuple(converted))->output();
    g_input->replaceAllUsesWith(input_tuple);
    // Erase the original quantized tensor input.
    graph->eraseInput(index + converted.size());
    index += 3;
  }
}

// https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
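//
// Top-level pass: for every supported quantized linear / conv / conv_transpose
// pattern, replace the packed weight argument with unpacked weight, scale,
// zero_point and axis constants plus an explicit bias (and, for convs, the
// stride/padding/dilation/groups attributes), then unpack the quantized tensor
// inputs of the graph itself.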
void UnpackQuantizedWeights(
    std::shared_ptr<Graph>& graph,
    std::map<std::string, IValue>& paramsDict) {
  std::string qlinear = R"(
  graph(%input, %packed_weight, %w_scale, %w_zero_point):
        %r = quantized::linear(%input, %packed_weight, %w_scale, %w_zero_point)
        return (%r) )";
  std::string qlinear_relu = R"(
  graph(%input, %packed_weight, %w_scale, %w_zero_point):
        %r = quantized::linear_relu(%input, %packed_weight, %w_scale, %w_zero_point)
        return (%r) )";
  std::string qconv1d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv1d(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv1d_relu = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv1d_relu(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv2d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv2d(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv2d_relu = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv2d_relu(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv3d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv3d(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv3d_relu = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv3d_relu(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv_transpose1d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv_transpose1d(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv_transpose2d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv_transpose2d(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv_transpose3d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv_transpose3d(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qlinear,
      "quantized::linear_unpack",
      QuantizedParamsType::LINEAR);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qlinear_relu,
      "quantized::linear_unpack",
      QuantizedParamsType::LINEAR);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv1d,
      "quantized::conv1d_unpack",
      QuantizedParamsType::CONV1D);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv2d,
      "quantized::conv2d_unpack",
      QuantizedParamsType::CONV);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv1d_relu,
      "quantized::conv1d_unpack",
      QuantizedParamsType::CONV1D);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv2d_relu,
      "quantized::conv2d_unpack",
      QuantizedParamsType::CONV);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv3d,
      "quantized::conv3d_unpack",
      QuantizedParamsType::CONV);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv3d_relu,
      "quantized::conv3d_unpack",
      QuantizedParamsType::CONV);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv_transpose1d,
      "quantized::conv_transpose1d_unpack",
      QuantizedParamsType::CONV1D,
      true);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv_transpose2d,
      "quantized::conv_transpose2d_unpack",
      QuantizedParamsType::CONV,
      true);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv_transpose3d,
      "quantized::conv_transpose3d_unpack",
      QuantizedParamsType::CONV,
      true);
  UnpackQuantizedTensorInputs(graph);
  GRAPH_DUMP("After UnpackQuantizedWeights: ", graph);
}

// Caffe2 expects quantized ops to be in NHWC format while PyTorch inputs are in
// NCHW. This pass inserts permutes to convert from NCHW to NHWC before each
// conv op and adds another permute from NHWC to NCHW after the conv op.
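//
// For example (value names are illustrative):
//   %y = quantized::conv2d(%x, ...)
// becomes
//   %x_nhwc = quantized::nchw2nhwc(%x)
//   %y_nhwc = quantized::conv2d(%x_nhwc, ...)
//   %y = quantized::nhwc2nchw(%y_nhwc)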
void insertPermutesHelper(
    std::shared_ptr<Graph>& graph,
    std::map<std::string, IValue>& paramsDict,
    const std::string& pattern) {
  Graph pattern_graph;
  std::unordered_map<std::string, Value*> vmap;
  parseIR(pattern, &pattern_graph, vmap);

  const auto& matches = findPatternMatches(pattern_graph, *graph);

  for (const auto& match : matches) {
    auto match_vmap = match.values_map;
    auto op_node = match_vmap.at(vmap.at("r"))->node();
    auto input_node = match_vmap.at(vmap.at("r"))->node()->inputs()[0]->node();

    Node* permute_node_before = graph->create(
        Symbol::fromQualString("quantized::nchw2nhwc"), {input_node->output()});
    permute_node_before->insertBefore(op_node);
    op_node->removeInput(0);
    op_node->insertInput(0, permute_node_before->output());

    Node* permute_node_after = graph->create(
        Symbol::fromQualString("quantized::nhwc2nchw"),
        {op_node->outputs()[0]});
    permute_node_after->insertAfter(op_node);
    auto v = op_node->outputs().at(0);
    v->replaceAllUsesWith(permute_node_after->outputs().at(0));
    permute_node_after->removeInput(0);
    permute_node_after->addInput(v);
  }
}

void insertPermutes(
    std::shared_ptr<Graph>& graph,
    std::map<std::string, IValue>& paramsDict) {
  std::string qconv = R"(
  graph(%input, %weight, %bias, %stride, %padding, %dilation, %groups, %w_scale, %w_zero_point):
        %r = quantized::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups, %w_scale, %w_zero_point)
        return (%r) )";
  std::string qconv_relu = R"(
  graph(%input, %weight, %bias, %stride, %padding, %dilation, %groups, %w_scale, %w_zero_point):
        %r = quantized::conv2d_relu(%input, %weight, %bias, %stride, %padding, %dilation, %groups, %w_scale, %w_zero_point)
        return (%r) )";
  std::string qconv_transpose = R"(
  graph(%input, %weight, %bias, %stride, %padding, %dilation, %output_padding, %groups, %w_scale, %w_zero_point):
        %r = quantized::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups, %w_scale, %w_zero_point)
        return (%r) )";

  insertPermutesHelper(graph, paramsDict, qconv);
  insertPermutesHelper(graph, paramsDict, qconv_relu);
  insertPermutesHelper(graph, paramsDict, qconv_transpose);
  GRAPH_DUMP("After insertPermutes: ", graph);
}

} // namespace torch::jit