#include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h>

#include <ATen/native/quantized/PackedParams.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

// TODO: Switch to per operator headers after
// https://github.com/pytorch/pytorch/pull/68693 is merged
#include <ATen/Functions.h>

using ::c10::Dispatcher;

namespace torch::jit {
namespace onnx {
using namespace ::c10::onnx;

}

// Get the scale of the input to a quantized op. There are two cases:
// 1. For ops with an output_scale in their signature, we return that output
//    scale directly.
// 2. For ops with no output scale in their signature (like quantized::relu),
//    we traverse up the graph until we hit a node where the scale is
//    explicitly specified.
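// For example, for %y = aten::relu(%x) where %x is produced by
// quantized::conv2d, the scale is read off the conv2d node's inputs.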
double getScaleFromInput(Node* input_node) {
  std::optional<IValue> scale;
  std::string input_name = input_node->kind().toQualString();
  std::unordered_set<std::string> noscale_ops = {
      "quantized::max_pool2d",
      "aten::max_pool2d",
      "aten::relu",
      "prim::ListUnpack",
      "aten::split_with_sizes",
      "quantized::nchw2nhwc",
      "quantized::nhwc2nchw",
      "aten::slice",
      "aten::avg_pool2d",
      "quantized::cat",
      "prim::ListConstruct",
      "aten::upsample_nearest2d",
      "aten::sigmoid",
      "aten::reshape"};
  if (input_name == "aten::quantize_per_tensor") {
    TORCH_CHECK(
        input_node->inputs().size() > 1,
        "aten::quantize_per_tensor expected scale to be 2nd input");
    scale = toIValue(input_node->inputs()[1]);
    return scale.value().toDouble();
  } else if (input_name == "quantized::linear") {
    // %r = quantized::linear(%input, %packed_weight, %w_scale, %w_zero_point)
    TORCH_CHECK(
        input_node->inputs().size() > 2,
        "quantized::linear expected scale to be 3rd input");
    scale = toIValue(input_node->inputs()[2]);
    return scale.value().toDouble();
  } else if (input_name == "quantized::conv2d") {
    // %r = quantized::conv2d(%input, %packed_weight, %w_scale, %w_zero_point)
    TORCH_CHECK(
        input_node->inputs().size() > 2,
        "quantized::conv2d expected scale to be 3rd input");
    auto num_inputs = input_node->inputs().size();
    scale = toIValue(input_node->inputs()[num_inputs - 2]);
    return scale.value().toDouble();
  } else if (input_name == "quantized::conv2d_relu") {
    // %r = quantized::conv2d_relu(%input, %packed_weight, %w_scale,
    // %w_zero_point)
    TORCH_CHECK(
        input_node->inputs().size() > 2,
        "quantized::conv2d_relu expected scale to be 3rd input");
    auto num_inputs = input_node->inputs().size();
    scale = toIValue(input_node->inputs()[num_inputs - 2]);
    return scale.value().toDouble();
  } else if (input_name == "quantized::add") {
    // %r = quantized::add(%input_a, %input_b, %w_scale, %w_zero_point)
    TORCH_CHECK(
        input_node->inputs().size() > 2,
        "quantized::add expected scale to be 3rd input");
    scale = toIValue(input_node->inputs()[2]);
    return scale.value().toDouble();
  } else if (input_name == "aten::sigmoid") {
    // For the _caffe2::Int8Sigmoid op the output scale is 1.0/256
    // and the output zero_point is set to 0 (quint8 type).
    return 1.0L / 256;
  }
  // For the ops below the scale is not part of the op signature, so we
  // traverse up the graph to get the scale from their input when defined in
  // the graph.
  else if (noscale_ops.find(input_name) != noscale_ops.end()) {
    return getScaleFromInput(input_node->inputs()[0]->node());
  }
  TORCH_INTERNAL_ASSERT(
      false,
      "Unrecognized quantized operator while trying to compute q_scale for operator ",
      input_name);
}

std::vector<Node*> CreateQuantizedWeights(
    std::shared_ptr<Graph>& graph,
    const at::Tensor& weight,
    int8_t* data,
    const std::vector<int64_t>& shapes,
    const std::vector<int64_t>& strides) {
  auto qscheme = weight.qscheme();
  std::vector<Node*> unpacked_wt;

  // Retrieve scales and zero_points. Their formats differ depending on the
  // weight's qscheme.
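  // e.g. per-tensor: a single scale and zero point; per-channel along axis 0
  // of a [C_out, ...] weight: 1-d arrays of C_out scales/zero points plus
  // axis_data = {0}.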
  std::vector<float> scale_data;
  std::vector<int64_t> scale_shapes;
  std::vector<int64_t> zero_point_data;
  std::vector<int64_t> zero_point_shapes;
  std::vector<int64_t> axis_data;
  switch (qscheme) {
    case c10::kPerTensorAffine: {
      // Cast to float since ONNX (De)QuantizeLinear only supports float scale.
      scale_data = {static_cast<float>(weight.q_scale())};
      scale_shapes = {1};
      zero_point_data = {weight.q_zero_point()};
      zero_point_shapes = {1};
      break;
    }
    case c10::kPerChannelAffine:
    case c10::kPerChannelAffineFloatQParams: {
      auto q_scales = weight.q_per_channel_scales();
      auto* scale_data_raw = q_scales.const_data_ptr<double>();
      scale_shapes = q_scales.sizes().vec();
      TORCH_INTERNAL_ASSERT(
          scale_shapes.size() == 1,
          "quantized per channel scales are expected as 1-d array.");
      scale_data.resize(scale_shapes[0]);
      // Cast to float since ONNX (De)QuantizeLinear only supports float scale.
      std::transform(
          scale_data_raw,
          scale_data_raw + scale_shapes[0],
          scale_data.begin(),
          [](double x) { return static_cast<float>(x); });

      auto q_zero_points = weight.q_per_channel_zero_points();
      auto* zero_point_data_raw = q_zero_points.const_data_ptr<int64_t>();
      zero_point_shapes = q_zero_points.sizes().vec();
      TORCH_INTERNAL_ASSERT(
          zero_point_shapes.size() == 1,
          "quantized per channel zero points are expected as 1-d array.");
      zero_point_data = std::vector<int64_t>(
          zero_point_data_raw, zero_point_data_raw + zero_point_shapes[0]);
      axis_data = {weight.q_per_channel_axis()};
      break;
    }
    default:
      TORCH_CHECK(
          false, "Unsupported qscheme for weight, got ", toString(qscheme));
  }
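
  // Materialize the unpacked weight as four prim::Constant nodes: the int8
  // data, the float scale(s), the zero point(s) and, for per-channel
  // quantization, the axis. ConvertQuantizedWeight bundles their outputs into
  // a single tuple.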
  Node* data_node = graph->create(prim::Constant);
  auto data_value =
      at::from_blob(
          data, c10::IntArrayRef(shapes), c10::IntArrayRef(strides), at::kChar)
          .to(at::kCPU);
  // Need clone because at::from_blob does not take ownership of data.
  data_node->t_(Symbol::attr("value"), data_value.clone());

  Node* scale_node = graph->create(prim::Constant);
  auto scale_value =
      at::from_blob(
          scale_data.data(), c10::IntArrayRef(scale_shapes), at::kFloat)
          .to(at::kCPU);
  scale_node->t_(Symbol::attr("value"), scale_value.clone());

  Node* zero_point_node = graph->create(prim::Constant);
  // The dtype passed to at::from_blob must match the int64_t buffer backing
  // zero_point_data.
  auto zero_point_value =
      at::from_blob(
          zero_point_data.data(),
          c10::IntArrayRef(zero_point_shapes),
          at::kLong)
          .to(at::kCPU);
  zero_point_node->t_(Symbol::attr("value"), zero_point_value.clone());

  Node* axis_node = graph->create(prim::Constant);
  if (!axis_data.empty()) {
    auto axis_value =
        at::from_blob(
            axis_data.data(), c10::IntArrayRef(axis_data.size()), at::kLong)
            .to(at::kCPU);
    axis_node->t_(attr::value, axis_value.clone());
  } else {
    axis_node->output()->setType(NoneType::get());
  }

  return {data_node, scale_node, zero_point_node, axis_node};
}

Node* CreateQuantizedBias(
    std::vector<float> data,
    std::shared_ptr<Graph>& graph,
    const std::vector<int64_t>& shapes) {
  Node* const_node_1 = graph->create(prim::Constant);
  auto const_bias =
      at::from_blob(data.data(), c10::IntArrayRef(shapes), at::kFloat)
          .to(at::kCPU);
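  // Copy into a freshly allocated tensor: at::from_blob merely aliases the
  // local `data` vector, which is destroyed when this function returns.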
  auto options = c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
  at::Tensor const_bias_copy = at::empty(c10::IntArrayRef(shapes), options);
  const_bias_copy.copy_(const_bias);
  const_node_1->t_(Symbol::attr("value"), const_bias_copy);
  return const_node_1;
}

Node* createIntTuple(
    const std::vector<int64_t>& is,
    std::shared_ptr<Graph>& graph) {
  Node* const_node = graph->create(Symbol::onnx("Constant"));
  const_node->is_(Symbol::attr("value"), is);
  return const_node;
}

Node* createInt(int64_t i, std::shared_ptr<Graph>& graph) {
  Node* const_node = graph->create(Symbol::onnx("Constant"));
  const_node->i_(Symbol::attr("value"), i);
  return const_node;
}

void ConvertQuantizedWeight(
    std::shared_ptr<Graph>& graph,
    Node* node,
    at::Tensor& weight) {
  std::vector<int64_t> wt_sizes = weight.sizes().vec();
  std::vector<int64_t> wt_strides = weight.strides().vec();
  // Remove packed_params
  node->removeInput(1);

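  // Reinterpret the qint8 weight's storage as raw int8 so it can be wrapped
  // in a plain constant tensor.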
  auto* wt_data =
      reinterpret_cast<int8_t*>(weight.mutable_data_ptr<c10::qint8>());

  std::vector<Node*> unpacked_wt =
      CreateQuantizedWeights(graph, weight, wt_data, wt_sizes, wt_strides);
  graph->setInsertPoint(node);
  Node* quant_node = graph->create(prim::TupleConstruct);
  for (auto* n : unpacked_wt) {
    n->insertBefore(node);
    quant_node->addInput(n->output());
  }
  quant_node->insertBefore(node);
  node->insertInput(1, quant_node->output());
}

// CONV1D needs a different unpacking from CONV, since it was intentionally
// packed as CONV2D in the first place.
// See: https://github.com/pytorch/pytorch/pull/38248
enum class QuantizedParamsType { CONV1D, CONV, LINEAR };

// This is called before the onnx pass. Using pattern matching we find the
// relevant nodes and extract the packed_params. The packed_params are passed
// to the appropriate unpack function using c10::Dispatcher. We insert the
// unpacked weights and bias into the graph as prim::Constant nodes, replacing
// the packed_params input.
void unpackQuantizedWeightsHelper(
    std::shared_ptr<Graph>& graph,
    std::map<std::string, IValue>& paramsDict,
    const std::string& pattern,
    const std::string& unpack_fn,
    QuantizedParamsType params_type,
    bool expect_output_padding = false) {
  Graph pattern_graph;
  std::unordered_map<std::string, Value*> vmap;
  parseIR(pattern, &pattern_graph, vmap);
  const auto& matches = findPatternMatches(pattern_graph, *graph);

  for (const auto& match : matches) {
    auto match_vmap = match.values_map;
    auto qlinear_node = match_vmap.at(vmap.at("r"))->node();
    std::string quantized_weight =
        match_vmap.at(vmap.at("r"))->node()->inputs()[1]->debugName();

    auto itr = paramsDict.find(quantized_weight);
    if (itr == paramsDict.end()) {
      throw std::runtime_error(
          "getValues: Quantized weight value not found amongst constant parameters.");
    }
    at::Tensor unpacked_weight;
    std::optional<at::Tensor> bias;
    constexpr int64_t stride_idx = 2;
    constexpr int64_t padding_idx = 3;
    int64_t output_padding_idx = 0;
    int64_t dilation_idx = 0;
    int64_t groups_idx = 0;
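    // These index the legacy serialized tuple (see the "Legacy" branch
    // below); conv_transpose tuples carry an extra output_padding entry,
    // which shifts dilation and groups by one.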
    if (expect_output_padding) {
      output_padding_idx = 4;
      dilation_idx = 5;
      groups_idx = 6;
    } else {
      dilation_idx = 4;
      groups_idx = 5;
    }
    std::optional<torch::List<int64_t>> stride, padding, dilation,
        output_padding;
    std::optional<int64_t> groups;
    std::optional<int64_t> transpose;

    torch::List<int64_t> stride_int, padding_int, dilation_int,
        output_padding_int;

    if (itr->second.isTuple()) {
      // Pre-unpacked weights, which come from Conv/Linear weights stored as
      // bound C++ classes.
      auto ser_tup = itr->second.toTuple();

      if (params_type == QuantizedParamsType::CONV &&
          ser_tup->elements()[0].isInt()) {
        const auto& elements = ser_tup->elements();
        auto version = elements[0].toInt();
        TORCH_INTERNAL_ASSERT(version == 3, "Unknown serialization version");
        TORCH_INTERNAL_ASSERT(elements.size() == 3, "Wrong tuple size.");

        auto config_vals = elements[1].to<std::vector<int64_t>>();
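        // config_vals layout (version 3):
        //   [kSpatialDim,
        //    stride... (kSpatialDim entries), padding..., dilation...,
        //    output_padding..., groups, flags]
        // where bit 0 of flags marks a transposed conv.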
        auto tensors = elements[2].to<std::vector<std::optional<at::Tensor>>>();

        std::optional<at::Tensor> weight = tensors[1];
        TORCH_INTERNAL_ASSERT(
            weight, "Weight should always be present in serialized qconv.");
        unpacked_weight = *weight;
        bias = tensors[2];

        const int64_t kSpatialDim = config_vals.at(0);
        // skip kSpatialDim
        unsigned idx = 1;
        for (const auto i : c10::irange(kSpatialDim)) {
          (void)i; // Suppress unused variable warning
          stride_int.emplace_back(config_vals.at(idx));
          idx++;
        }
        for (const auto i : c10::irange(kSpatialDim)) {
          (void)i; // Suppress unused variable warning
          padding_int.emplace_back(config_vals.at(idx));
          idx++;
        }
        for (const auto i : c10::irange(kSpatialDim)) {
          (void)i; // Suppress unused variable warning
          dilation_int.emplace_back(config_vals.at(idx));
          idx++;
        }
        for (const auto i : c10::irange(kSpatialDim)) {
          (void)i; // Suppress unused variable warning
          output_padding_int.emplace_back(config_vals.at(idx));
          idx++;
        }
        int64_t groups_int = config_vals.at(idx);
        idx++;
        int64_t flags = config_vals.at(idx);
        idx++;
        TORCH_INTERNAL_ASSERT(
            idx == config_vals.size(),
            "Unexpected length of config_vals, expected ",
            idx,
            " got ",
            config_vals.size());

        bool transpose_int = flags & (1 << 0);

        int64_t other_flags = flags & ~(1 << 0);
        TORCH_CHECK(other_flags == 0, "Unexpected flags set in ", flags, ".");

        stride = stride_int;
        padding = padding_int;
        dilation = dilation_int;
        groups = groups_int;
        transpose = transpose_int;
        if (expect_output_padding) {
          output_padding = output_padding_int;
        }
      } else if (
          (params_type == QuantizedParamsType::CONV ||
           params_type == QuantizedParamsType::CONV1D) &&
          ser_tup->elements()[0].isString()) {
        const auto& elements = ser_tup->elements();
        auto version = elements[0].toStringRef();
        TORCH_INTERNAL_ASSERT(version == "2", "Unknown serialization version");
        std::vector<at::Tensor> non_optional = elements[1].toTensorVector();

        at::Tensor conv_params_packed = non_optional[0];
        unpacked_weight = non_optional[1];

        const int64_t kSpatialDim = conv_params_packed[0].item<int64_t>();
        // skip kSpatialDim
        int64_t idx = 1;
        // kSpatialDim is 2 even for Conv1D, since torch packs Conv1D params
        // in Conv2D form, so Conv1D needs a special unpack that skips the
        // leading (dummy) spatial dim.
        // See: https://github.com/pytorch/pytorch/pull/38248
        for (const auto i : c10::irange(kSpatialDim)) {
          if (params_type != QuantizedParamsType::CONV1D || i != 0) {
            stride_int.emplace_back(conv_params_packed[idx].item<int64_t>());
          }
          idx++;
        }
        for (const auto i : c10::irange(kSpatialDim)) {
          if (params_type != QuantizedParamsType::CONV1D || i != 0) {
            padding_int.emplace_back(conv_params_packed[idx].item<int64_t>());
          }
          idx++;
        }
        for (const auto i : c10::irange(kSpatialDim)) {
          if (params_type != QuantizedParamsType::CONV1D || i != 0) {
            dilation_int.emplace_back(conv_params_packed[idx].item<int64_t>());
          }
          idx++;
        }
        for (const auto i : c10::irange(kSpatialDim)) {
          if (params_type != QuantizedParamsType::CONV1D || i != 0) {
            output_padding_int.emplace_back(
                conv_params_packed[idx].item<int64_t>());
          }
          idx++;
        }
        auto groups_int = conv_params_packed[idx].item<int64_t>();
        idx++;
        auto transpose_int = conv_params_packed[idx].item<int64_t>();
        idx++;
        TORCH_INTERNAL_ASSERT(
            idx == conv_params_packed.numel(),
            "Unexpected length of conv_params_packed, expected ",
            idx,
            " got ",
            conv_params_packed.numel());

        torch::List<c10::IValue> optional = elements[2].toList();
        bias = optional.get(0).toOptional<at::Tensor>();

        if (params_type == QuantizedParamsType::CONV1D) {
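          // Drop the dummy spatial dim added when the Conv1D weight was
          // packed in Conv2D form ([C_out, C_in, 1, L] -> [C_out, C_in, L]).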
          unpacked_weight = unpacked_weight.squeeze_(2);
        }
        stride = stride_int;
        padding = padding_int;
        dilation = dilation_int;
        groups = groups_int;
        transpose = transpose_int;
        if (expect_output_padding) {
          output_padding = output_padding_int;
        }
      } else { // Legacy
        unpacked_weight = ser_tup->elements()[0].toTensor();
        bias = ser_tup->elements()[1].toOptional<at::Tensor>();
        // conv only parameters
        if (ser_tup->elements().size() > 2) {
          auto stride_ivalue = ser_tup->elements()[stride_idx].toListRef();
          auto padding_ivalue = ser_tup->elements()[padding_idx].toListRef();
          auto dilation_ivalue = ser_tup->elements()[dilation_idx].toListRef();
          auto groups_ivalue = ser_tup->elements()[groups_idx];

          for (const auto& s : stride_ivalue) {
            stride_int.emplace_back(s.toTensor()[0].item<int64_t>());
          }
          for (const auto& p : padding_ivalue) {
            padding_int.emplace_back(p.toTensor()[0].item<int64_t>());
          }
          for (const auto& d : dilation_ivalue) {
            dilation_int.emplace_back(d.toTensor()[0].item<int64_t>());
          }
          groups = groups_ivalue.toTensor()[0].item<int64_t>();
          stride = stride_int;
          padding = padding_int;
          dilation = dilation_int;

          if (expect_output_padding) {
            auto output_padding_ivalue =
                ser_tup->elements()[output_padding_idx].toListRef();
            for (const auto& d : output_padding_ivalue) {
              output_padding_int.emplace_back(d.toTensor()[0].item<int64_t>());
            }
            output_padding = output_padding_int;
          }
        }
      }
    } else {
      TORCH_INTERNAL_ASSERT(itr->second.isTensor());
      at::Tensor packed_weight = itr->second.toTensor();
      auto op = Dispatcher::singleton()
                    .findSchemaOrThrow(unpack_fn.c_str(), "")
                    .typed<std::tuple<at::Tensor, std::optional<at::Tensor>>(
                        at::Tensor)>();
      std::tie(unpacked_weight, bias) = op.call(packed_weight);
    }

    ConvertQuantizedWeight(graph, qlinear_node, unpacked_weight);

    // Add bias
    at::Tensor original_bias;
    if (bias.has_value()) {
      original_bias = bias.value();
      original_bias.set_requires_grad(false);
    } else {
      int64_t bias_size = unpacked_weight.size(0);
      original_bias =
          at::zeros(bias_size, unpacked_weight.options().dtype(at::kFloat));
    }

    auto input_val = match_vmap.at(vmap.at("r"))->node()->inputs()[0];
    TORCH_INTERNAL_ASSERT(
        input_val->type()->isSubtypeOf(*TensorType::get()),
        "Unsupported input type. Expected TensorType, got ",
        input_val->type()->str());

    std::vector<float> bias_values(original_bias.numel());
    auto bias_data = original_bias.const_data_ptr<float>();
    for (const auto i : c10::irange(original_bias.numel())) {
      bias_values[i] = bias_data[i];
    }
    Node* bias_node =
        CreateQuantizedBias(bias_values, graph, original_bias.sizes().vec());
    bias_node->insertBefore(qlinear_node);
    // For quantized_linear inputs, the order is input, weight, bias, ...
    // Therefore bias is at location 2.
    qlinear_node->insertInput(2, bias_node->output());

    // add conv arguments: stride, padding, dilation, groups, output_padding
    if (stride.has_value() && padding.has_value() && dilation.has_value() &&
        groups.has_value() &&
        (!expect_output_padding || output_padding.has_value())) {
      std::vector<std::optional<torch::List<int64_t>>> conv_ints_args;
      conv_ints_args.push_back(stride);
      conv_ints_args.push_back(padding);
      if (expect_output_padding) {
        conv_ints_args.push_back(output_padding);
      }
      conv_ints_args.push_back(dilation);
      // skip (input, weight, bias)
      const size_t arg_offset = 3;
      for (const auto i : c10::irange(conv_ints_args.size())) {
        Node* ints_node =
            createIntTuple(conv_ints_args[i].value().vec(), graph);
        ints_node->insertBefore(qlinear_node);
        qlinear_node->insertInput(arg_offset + i, ints_node->output());
      }
      Node* groups_node = createInt(groups.value(), graph);
      groups_node->insertBefore(qlinear_node);
      qlinear_node->insertInput(groups_idx + 1, groups_node->output());
    }
    auto b = graph->block();
    auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
    eraseUnusedValuesFromMap(valsToParamsMap);
  }
}

static std::
    unordered_map<c10::ScalarType, c10::ScalarType, ScalarTypeHashFunction>
        qTypeToValType = {
            {c10::ScalarType::QInt8, c10::ScalarType::Char},
            {c10::ScalarType::QUInt8, c10::ScalarType::Byte},
            {c10::ScalarType::QInt32, c10::ScalarType::Int},
            {c10::ScalarType::QUInt4x2, c10::ScalarType::Byte},
};

// Unpack quantized tensor inputs into {value, scale, zero_point}, then
// create a prim::TupleConstruct node from these three values.
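// e.g. a graph input %x of quantized type QUInt8 becomes three new inputs
// (%x_value : Byte, %x_scale : double, %x_zero_point : long); the tuple of
// those three replaces all uses of %x, and %x itself is erased.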
void UnpackQuantizedTensorInputs(std::shared_ptr<Graph>& graph) {
  for (size_t index = 0; index < graph->inputs().size();) {
    auto g_input = graph->inputs()[index];
    TensorTypePtr shape_type = g_input->type()->cast<TensorType>();
    if (!shape_type || !shape_type->scalarType().has_value()) {
      index++;
      continue;
    }
    auto scalar_type = shape_type->scalarType().value();
    if (qTypeToValType.find(scalar_type) == qTypeToValType.end()) {
      index++;
      continue;
    }
    std::string input_name = g_input->debugName();
    auto input_value =
        graph->insertInput(index, input_name + "_value")
            ->setType(shape_type->withScalarType(qTypeToValType[scalar_type]));
    // scale and zero_point types can be found in
    // torch/include/ATen/Operators.h
    auto input_scale =
        graph->insertInput(index + 1, input_name + "_scale")
            ->setType(TensorType::create(
                at::kDouble, at::kCPU, 0, /*requires_grad=*/std::nullopt));
    auto input_zero_point =
        graph->insertInput(index + 2, input_name + "_zero_point")
            ->setType(TensorType::create(
                at::kLong, at::kCPU, 0, /*requires_grad=*/std::nullopt));
    std::vector<Value*> converted{input_value, input_scale, input_zero_point};
    auto input_tuple =
        graph->prependNode(graph->createTuple(converted))->output();
    g_input->replaceAllUsesWith(input_tuple);
    // Erase the original quantized tensor input.
    graph->eraseInput(index + converted.size());
    index += 3;
  }
}

// https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
void UnpackQuantizedWeights(
    std::shared_ptr<Graph>& graph,
    std::map<std::string, IValue>& paramsDict) {
  std::string qlinear = R"(
  graph(%input, %packed_weight, %w_scale, %w_zero_point):
        %r = quantized::linear(%input, %packed_weight, %w_scale, %w_zero_point)
        return (%r) )";
  std::string qlinear_relu = R"(
  graph(%input, %packed_weight, %w_scale, %w_zero_point):
        %r = quantized::linear_relu(%input, %packed_weight, %w_scale, %w_zero_point)
        return (%r) )";
  std::string qconv1d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv1d(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv1d_relu = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv1d_relu(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv2d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv2d(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv2d_relu = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv2d_relu(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv3d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv3d(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv3d_relu = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv3d_relu(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv_transpose1d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv_transpose1d(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv_transpose2d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv_transpose2d(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  std::string qconv_transpose3d = R"(
  graph(%input, %packed_params, %scale, %zero_point):
        %r = quantized::conv_transpose3d(%input, %packed_params, %scale, %zero_point)
        return (%r) )";
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qlinear,
      "quantized::linear_unpack",
      QuantizedParamsType::LINEAR);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qlinear_relu,
      "quantized::linear_unpack",
      QuantizedParamsType::LINEAR);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv1d,
      "quantized::conv1d_unpack",
      QuantizedParamsType::CONV1D);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv2d,
      "quantized::conv2d_unpack",
      QuantizedParamsType::CONV);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv1d_relu,
      "quantized::conv1d_unpack",
      QuantizedParamsType::CONV1D);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv2d_relu,
      "quantized::conv2d_unpack",
      QuantizedParamsType::CONV);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv3d,
      "quantized::conv3d_unpack",
      QuantizedParamsType::CONV);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv3d_relu,
      "quantized::conv3d_unpack",
      QuantizedParamsType::CONV);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv_transpose1d,
      "quantized::conv_transpose1d_unpack",
      QuantizedParamsType::CONV1D,
      /*expect_output_padding=*/true);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv_transpose2d,
      "quantized::conv_transpose2d_unpack",
      QuantizedParamsType::CONV,
      /*expect_output_padding=*/true);
  unpackQuantizedWeightsHelper(
      graph,
      paramsDict,
      qconv_transpose3d,
      "quantized::conv_transpose3d_unpack",
      QuantizedParamsType::CONV,
      /*expect_output_padding=*/true);
  UnpackQuantizedTensorInputs(graph);
  GRAPH_DUMP("After UnpackQuantizedWeights: ", graph);
}

// Caffe2 expects quantized ops to be in NHWC format while pytorch inputs are
// in NCHW. This pass inserts permutes to convert from NCHW to NHWC before
// each conv op and adds another permute from NHWC to NCHW after the conv op.
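// e.g. a matched %r = quantized::conv2d(%input, ...) is rewritten so that
// %input first flows through quantized::nchw2nhwc, and every use of %r is
// redirected through a trailing quantized::nhwc2nchw node.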
void insertPermutesHelper(
    std::shared_ptr<Graph>& graph,
    std::map<std::string, IValue>& paramsDict,
    const std::string& pattern) {
  Graph pattern_graph;
  std::unordered_map<std::string, Value*> vmap;
  parseIR(pattern, &pattern_graph, vmap);

  const auto& matches = findPatternMatches(pattern_graph, *graph);

  for (const auto& match : matches) {
    auto match_vmap = match.values_map;
    auto op_node = match_vmap.at(vmap.at("r"))->node();
    auto input_node = match_vmap.at(vmap.at("r"))->node()->inputs()[0]->node();

    Node* permute_node_before = graph->create(
        Symbol::fromQualString("quantized::nchw2nhwc"),
        {input_node->output()});
    permute_node_before->insertBefore(op_node);
    op_node->removeInput(0);
    op_node->insertInput(0, permute_node_before->output());

    Node* permute_node_after = graph->create(
        Symbol::fromQualString("quantized::nhwc2nchw"),
        {op_node->outputs()[0]});
    permute_node_after->insertAfter(op_node);
    auto v = op_node->outputs().at(0);
    v->replaceAllUsesWith(permute_node_after->outputs().at(0));
    permute_node_after->removeInput(0);
    permute_node_after->addInput(v);
  }
}

void insertPermutes(
    std::shared_ptr<Graph>& graph,
    std::map<std::string, IValue>& paramsDict) {
  std::string qconv = R"(
  graph(%input, %weight, %bias, %stride, %padding, %dilation, %groups, %w_scale, %w_zero_point):
        %r = quantized::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups, %w_scale, %w_zero_point)
        return (%r) )";
  std::string qconv_relu = R"(
  graph(%input, %weight, %bias, %stride, %padding, %dilation, %groups, %w_scale, %w_zero_point):
        %r = quantized::conv2d_relu(%input, %weight, %bias, %stride, %padding, %dilation, %groups, %w_scale, %w_zero_point)
        return (%r) )";
  std::string qconv_transpose = R"(
  graph(%input, %weight, %bias, %stride, %padding, %dilation, %output_padding, %groups, %w_scale, %w_zero_point):
        %r = quantized::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups, %w_scale, %w_zero_point)
        return (%r) )";

  insertPermutesHelper(graph, paramsDict, qconv);
  insertPermutesHelper(graph, paramsDict, qconv_relu);
  insertPermutesHelper(graph, paramsDict, qconv_transpose);
  GRAPH_DUMP("After insertPermutes: ", graph);
}

} // namespace torch::jit