#include <torch/csrc/jit/passes/quantization/insert_quant_dequant.h>

#include <c10/core/QScheme.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/quantization/helper.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include <stack>
#include <utility>
namespace torch {
namespace jit {

namespace {
using graph_rewrite_helper::PatternInfo;

// dynamic quantization ops for activation: choose_qparams, quant, dequant
using DynamicQuantOps = std::tuple<Node*, Node*, Node*>;

std::string kScalarType = "_scalar_type";

struct QuantOpParams {
  c10::QScheme qscheme{c10::kPerTensorAffine};
  std::vector<Value*> qparams;
  // This is only so that insertQuantizationOps can be templatized
  // and subsequently a significant portion of that code can be reused.
  std::string back() const {
    return "AttributeDoesNotExist";
  }
};

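// Map a symmetric qscheme to its affine counterpart; all other qschemes are
// returned unchanged.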
c10::QScheme toAffine(c10::QScheme qscheme) {
  switch (qscheme) {
    case c10::kPerTensorAffine:
    case c10::kPerTensorSymmetric:
      return c10::kPerTensorAffine;
    case c10::kPerChannelAffine:
    case c10::kPerChannelSymmetric:
      return c10::kPerChannelAffine;
    default:
      return qscheme;
  }
}

bool isPerChannel(at::QScheme qscheme) {
  return qscheme == c10::kPerChannelAffine ||
      qscheme == c10::kPerChannelSymmetric;
}

// Go through the CallMethod graph to check if the value is a weight.
bool isWeight(Module& module, Value* v) {
  if (isWeight(v)) {
    return true;
  }
  std::optional<bool> result;
  auto* self = v->owningGraph()->inputs()[0];
  for (const Use& u : v->uses()) {
    Node* n = u.user;
    if (n->kind() == prim::CallMethod) {
      auto m_opt = getInvokedModuleOpt(module, n, self);
      if (!m_opt.has_value()) {
        return false;
      }
      auto m = *m_opt;
      auto g = m.get_method(n->s(attr::name)).graph();
      auto call_method_result = isWeight(m, g->inputs()[u.offset]);
      if (result.has_value()) {
        // Check to make sure all the CallMethods in the graph produce the same
        // output.
        TORCH_CHECK(
            call_method_result == result.value(),
            "Expected all CallMethods to use either weight "
            "or non-weight value.",
            v->debugName());
      } else {
        result = call_method_result;
      }
    }
  }
  return result.has_value() ? result.value() : false;
}

Node* insertChooseQParams(Graph* graph, Value* original_val) {
  std::string choose_qparams_func = "_choose_qparams_per_tensor";
  // Set reduce_range to true by default, since the qnnpack backend ignores
  // this argument.
  bool reduce_range_param = true;
  auto reduce_range = graph->insertConstant(reduce_range_param);
  // choose_qparams_per_tensor has 2 outputs, (scale, zero_point).
  Node* choose_qparams = graph->create(
      at::Symbol::aten(choose_qparams_func),
      {original_val, reduce_range},
      /* num_outputs = */ 2);
  choose_qparams->output(0)->setDebugName(original_val->debugName() + ".scale");
  choose_qparams->output(0)->setType(FloatType::get());
  choose_qparams->output(1)->setDebugName(
      original_val->debugName() + ".zero_point");
  choose_qparams->output(1)->setType(IntType::get());
  graph->insertNode(choose_qparams);
  return choose_qparams;
}

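// Create a quantize node of `quant_kind` with the given inputs and insert it
// at the current insertion point.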
Node* insertQuant(
    Graph* graph,
    const std::vector<Value*>& inputs,
    NodeKind quant_kind,
    const std::string& debugName) {
  Node* quant = graph->create(quant_kind, inputs);
  quant->output()->setDebugName(debugName);
  graph->insertNode(quant);
  return quant;
}

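// Insert an aten::dequantize node for `quantized_val`; the dequantized output
// takes over the type of `original_val`.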
Node* insertDeQuant(
    Graph* graph,
    Value* quantized_val,
    Value* original_val,
    size_t id = 0) {
  Node* dequant = graph->create(Symbol::aten("dequantize"), {quantized_val});
  dequant->output()
      ->setDebugName(
          original_val->debugName() + ".dequant." + std::to_string(id))
      ->setType(original_val->type());
  graph->insertNode(dequant);
  return dequant;
}

std::vector<Value*> insertDeQuantForAllUse(
    Graph* graph,
    Value* quantized_val,
    Value* original_val) {
  // copy uses to vector since value->uses() is a reference
  // and changing the graph will also change the uses() list
  const std::vector<Use> uses = original_val->uses();
  std::vector<Value*> outputs;
  for (const auto i : c10::irange(uses.size())) {
    auto* user = uses[i].user;
    // Insert the dequantize node right before the use node, to make sure the
    // use node and the dequantize node reside in the same block so that quant
    // fusion can happen
    WithInsertPoint ins(user);
    Node* dequant = insertDeQuant(graph, quantized_val, original_val, i);
    user->replaceInput(uses[i].offset, dequant->output());
    outputs.push_back(dequant->output());
  }
  return outputs;
}

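// Insert a node of `node_kind` (e.g. aten::q_scale) that extracts a
// quantization parameter from a quantized Tensor.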
Node* insertQParam(
    Graph* graph,
    Value* quantized_input,
    NodeKind node_kind,
    const TypePtr& output_type,
    const std::string& param_name) {
  Node* qparam = graph->create(node_kind, {quantized_input});
  qparam->output()
      ->setDebugName(quantized_input->debugName() + "." + param_name)
      ->setType(output_type);
  graph->insertNode(qparam);
  return qparam;
}

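// Convert a Scalar value to a float Tensor with aten::scalar_tensor and
// redirect all downstream uses of the Scalar to the new Tensor.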
Node* insertScalarToTensor(Graph* graph, Value* scalar_value) {
  Node* n = scalar_value->node();
  WithInsertPoint ins(n->next());
  Value* float_scalar_type = graph->insertConstant(IValue(c10::kFloat));
  Value* none = graph->insertConstant(IValue());
  Node* tensor_node = graph->create(
      Symbol::aten("scalar_tensor"),
      {scalar_value, float_scalar_type, none, none, none});
  Value* tensor_output = tensor_node->output();
  tensor_output->setDebugName(scalar_value->debugName() + ".tensor");
  graph->insertNode(tensor_node);
  // replace original_output with tensor
  scalar_value->replaceAllUsesAfterNodeWith(tensor_node, tensor_output);
  return tensor_node;
}

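// Insert an aten::item node that converts a Tensor back to a Scalar of
// `output_type`.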
Node* insertItem(Graph* graph, Value* tensor, const TypePtr& output_type) {
  WithInsertPoint ins(tensor->node()->next());
  Node* n = graph->create(Symbol::aten("item"), {tensor});
  Value* scalar = n->output();
  scalar->setDebugName(tensor->debugName() + ".scalar")->setType(output_type);
  graph->insertNode(n);
  return n;
}

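// Insert the choose_qparams -> quantize -> dequantize sequence used to
// dynamically quantize an activation, and return the three created nodes.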
DynamicQuantOps insertChooseQParamQuantDequant(
    Graph* graph,
    Value* original_val,
    Value* dtype,
    NodeKind quant_kind) {
  Node* choose_qparams = insertChooseQParams(graph, original_val);
  std::vector<Value*> quant_inputs = {original_val};
  for (auto& out : choose_qparams->outputs()) {
    quant_inputs.push_back(out);
  }
  quant_inputs.push_back(dtype);
  Node* quant = insertQuant(
      graph, quant_inputs, quant_kind, original_val->debugName() + ".quant");
  Node* dequant = insertDeQuant(graph, quant->output(), original_val);
  return std::make_tuple(choose_qparams, quant, dequant);
}

Node* insertFP16CastOps(Graph* graph, Value* observer_out) {
  // If the weight value is outside of the FP16 range, i.e. [5.96e-8, 65504],
  // we saturate the values to the min/max of this range.
  Node* saturated_weight =
      graph->create(Symbol::aten("_saturate_weight_to_fp16"), {observer_out});
  graph->insertNode(saturated_weight);
  graph->lint();

  return saturated_weight;
}

// find the observer for Value `v` and return the name of the observer
std::optional<std::string> findObserverName(Value* v) {
  // Note that here we just check for the name of the observer; ideally we
  // should be comparing the type of the observer. This is a temporary
  // workaround until a data-only clone in module.clone is supported.
  Node* n = v->node();
  if (n->kind() == prim::CallMethod && n->s(attr::name) == "forward") {
    auto module_instance = n->inputs().at(0);
    if (module_instance->node()->kind() == prim::GetAttr &&
        module_instance->node()->s(attr::name).find("_observer_") !=
            std::string::npos) {
      return module_instance->node()->s(attr::name);
    }
  }
  return std::nullopt;
}

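// Check whether `observer` comes from a PlaceholderObserver module, i.e. an
// observer that does not compute quantization parameters.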
bool isPlaceholderObserver(Value* observer) {
  if (getModuleName(observer).has_value()) {
    auto name = getModuleName(observer).value();
    // if PlaceholderObserver is (anywhere) in name
    if (name.find("PlaceholderObserver") != std::string::npos) {
      return true;
    }
  }
  return false;
}

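// Return the dtype recorded by the observer attached to `v`, or Undefined if
// `v` has no observer.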
at::ScalarType getObserverDtype(Module& module, Value* v) {
  auto observer_name = findObserverName(v);
  if (observer_name.has_value()) {
    auto observer_module = module.attr(observer_name.value()).toModule();
    at::ScalarType scalar_type = observer_module.attr("dtype").toScalarType();
    return scalar_type;
  }
  return at::ScalarType::Undefined;
}

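// For embedding_bag observers, return the custom op name recorded on the
// observer module when it is a placeholder observer, an empty string
// otherwise; std::nullopt if the observer has no custom_op attribute.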
std::optional<std::string> getEmbeddingBagObsName(
    script::Module& module,
    Node* n) {
  Value* v = n->output();
  auto observer = n->input(0);
  auto observer_module = module.attr(findObserverName(v).value()).toModule();
  if (observer_module.hasattr("custom_op")) {
    auto op_name = observer_module.attr("custom_op").toStringRef();
    return isPlaceholderObserver(observer) ? std::move(op_name) : "";
  }
  return std::nullopt;
}

bool isEmbeddingBagOp(
    Node* observer,
    std::optional<std::string> embedding_bag_name) {
  return embedding_bag_name &&
      embedding_bag_name.value().find("embedding_bag_") != std::string::npos;
}

template <typename T>
Node* insertQuantDequantNodes(
    Value* self,
    Node* observer,
    T& qparams,
    const std::string& quantize_func);

// Insert quant and dequant nodes into the graph for both static and dynamic
// quant.
template <>
Node* insertQuantDequantNodes<std::vector<std::string>>(
    Value* self,
    Node* observer,
    std::vector<std::string>& qparam_names,
    const std::string& quantize_func) {
  Graph* g = observer->owningGraph();
  Value* observer_out = observer->output();
  Value* original_val = observer->input(1);
  std::vector<Value*> inputs = {observer_out};
  // Insert GetAttr nodes for quantization parameters
  for (const auto& qparam_name : qparam_names) {
    inputs.push_back(g->insertGetAttr(self, qparam_name));
  }
  Node* quant = insertQuant(
      g,
      inputs,
      at::Symbol::aten(quantize_func),
      original_val->debugName() + ".quant");
  Node* dequant = insertDeQuant(g, quant->output(), original_val);
  return dequant;
}

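// Replace an observed float embedding_bag op with the corresponding
// quantized prepack + rowwise_offsets ops, and return the quantized node.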
Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) {
  Graph* g = observer->owningGraph();
  auto observer_out = observer->output();

  std::string prepack_fn, quant_fn;
  std::vector<Value*> prepack_inputs = {observer_out};
  if (op_name == "embedding_bag_4bit") {
    bool optimized_qparams = false;
    constexpr int NBINS = 200;
    constexpr float RATIO = 0.16;
    Value* optimized_qparams_false = g->insertConstant(optimized_qparams);
    Value* nbins_200 = g->insertConstant(NBINS);
    Value* ratio_0_16 = g->insertConstant(RATIO);
    prepack_fn = "quantized::embedding_bag_4bit_prepack";
    quant_fn = "quantized::embedding_bag_4bit_rowwise_offsets";
    prepack_inputs.push_back(optimized_qparams_false);
    prepack_inputs.push_back(nbins_200);
    prepack_inputs.push_back(ratio_0_16);
  } else if (op_name == "embedding_bag_byte") {
    prepack_fn = "quantized::embedding_bag_byte_prepack";
    quant_fn = "quantized::embedding_bag_byte_rowwise_offsets";
  } else {
    TORCH_INTERNAL_ASSERT(
        false,
        "Graph Mode Quantization currently supports 4-bit and 8-bit embedding bag quantization.");
  }

  std::vector<Use> uses = observer_out->uses();
  Node* embedding_bag_float_op = nullptr;
  // We expect that the output of the weight observer will be consumed by the
  // embedding_bag operator.
  for (const Use& use : uses) {
    if (matchCallFuncToUse(use, "embedding_bag", 2) ||
        matchAtenFuncToUse(use, "embedding_bag", 0)) {
      embedding_bag_float_op = use.user;
    }
  }

  // Insert prepack op
  Node* prepack = g->create(Symbol::fromQualString(prepack_fn), prepack_inputs);
  g->insertNode(prepack);

  std::vector<Value*> embedding_bag_inputs =
      embedding_bag_float_op->inputs().vec();
  std::vector<Value*> qembedding_bag_inputs = {prepack->output()};
  const auto inputs_size = embedding_bag_float_op->inputs().size();
  const bool is_aten_op =
      embedding_bag_float_op->kind() == Symbol::aten("embedding_bag");
  // Create and insert quantized embedding op.
  Value* none = g->insertConstant(IValue());
  Value* zero = g->insertConstant(IValue(0));
  bool pruned_wt = false;
  auto pruned_const = g->insertConstant(pruned_wt);

  if (is_aten_op) {
    TORCH_CHECK(
        inputs_size == 9,
        "Expecting FP aten::embedding_bag operator to have 9 inputs");
    // Input 0 is the output of the prepack op.
    // The last input is added after we account for the extra input in the
    // 4-bit case.
    for (unsigned long i = 1; i < inputs_size - 2; ++i) {
      qembedding_bag_inputs.push_back(embedding_bag_inputs[i]);
    }
    // The sparse field in the float operator denotes sparse gradients.
    // For inference it stands for pruned weights. We currently don't support
    // pruning in the graph mode API, so we set the field to 0 for inference.
    qembedding_bag_inputs[5] = pruned_const;
  } else {
    TORCH_CHECK(
        inputs_size == 12,
        "Expecting F.embedding_bag operator to have 12 inputs");
    qembedding_bag_inputs.push_back(embedding_bag_inputs[1]); // indices
    qembedding_bag_inputs.push_back(embedding_bag_inputs[3]); // offsets
    qembedding_bag_inputs.push_back(
        embedding_bag_inputs[6]); // scale_grad_by_freq
    qembedding_bag_inputs.push_back(zero); // mode
    qembedding_bag_inputs.push_back(pruned_const); // pruned_weights
    qembedding_bag_inputs.push_back(
        embedding_bag_inputs[9]); // per_sample_weights
  }

  qembedding_bag_inputs.push_back(none); // compressed_indices_mapping
  qembedding_bag_inputs.push_back(embedding_bag_inputs[inputs_size - 2]);

  TORCH_CHECK(
      embedding_bag_inputs[inputs_size - 1]->mustBeNone(),
      "Expected aten::embedding_bag padding_idx input to be None");

  Node* qembedding_bag =
      g->create(Symbol::fromQualString(quant_fn), qembedding_bag_inputs);
  if (is_aten_op) {
    WithInsertPoint ins(embedding_bag_float_op);
    g->insertNode(qembedding_bag);
    // Verify that the outputs (apart from index 0) have no uses in the graph.
    for (const auto i :
         c10::irange(1, embedding_bag_float_op->outputs().size())) {
      TORCH_CHECK(
          !embedding_bag_float_op->output(i)->hasUses(),
          "Expected aten::embedding_bag to only have use for its first output.");
    }
  } else {
    g->insertNode(qembedding_bag);
  }
  embedding_bag_float_op->output(0)->replaceAllUsesWith(
      qembedding_bag->output());
  embedding_bag_float_op->removeAllInputs();
  embedding_bag_float_op->destroy();
  g->lint();
  return qembedding_bag;
}

template <typename T>
void insertQuantizationOps(
    Module& module,
    Value* self,
    Node* observer,
    bool is_per_channel,
    T& qparams,
    QuantType quant_type = QuantType::STATIC) {
  Graph* g = observer->owningGraph();
  // Observer output
  Value* observer_out = observer->output();
  // Inserting before insert point
  WithInsertPoint ins(observer_out->node()->next());

  std::string quantize_func;
  if (is_per_channel) {
    quantize_func = "quantize_per_channel";
  } else {
    quantize_func = "quantize_per_tensor";
  }
  Value* original_val = observer->input(1);
  // Temporary solution to quantize embedding_bag operators. Will be re-written
  // once we support quantization of embedding_bag weights.
  auto embedding_bag_name = getEmbeddingBagObsName(module, observer);
  if (isEmbeddingBagOp(observer, embedding_bag_name)) {
    if (isWeight(module, observer_out)) {
      auto op_name = embedding_bag_name.value();
      Node* dequant = insertEmbeddingBagOps(observer, op_name);
      observer_out->replaceAllUsesWith(original_val);
      original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output());
    } else {
      // Special case for the embedding_bag operator's indices input - we
      // don't quantize the input, but we still need to insert observers for
      // it because the order of input and weight can be changed in the
      // module code.
      observer_out->replaceAllUsesWith(original_val);
    }
    return;
  }
  Node* dequant = nullptr;
  if (quant_type == QuantType::DYNAMIC) {
    if (getObserverDtype(module, observer_out) == at::ScalarType::Half) {
      dequant = insertFP16CastOps(g, observer_out);
    } else if (!isWeight(module, observer_out)) {
      auto observer_dtype = getObserverDtype(module, observer_out);
      if (observer_dtype == at::ScalarType::QUInt8 ||
          observer_dtype == at::ScalarType::QInt8) {
        // For activation tensors we insert choose_qparams, quant, dequant ops.
        Value* dtype = g->insertGetAttr(self, qparams.back());
        dequant = std::get<2>(insertChooseQParamQuantDequant(
            g, observer_out, dtype, at::Symbol::aten(quantize_func)));
      } else {
        // The dtype does not require quantization (e.g. float32), so we just
        // remove the observer call.
        observer_out->replaceAllUsesWith(original_val);
        return;
      }
    } else {
      // For weight tensors we insert quant-dequant ops.
      dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func);
    }
  } else { // Static quant
    dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func);
  }
  observer_out->replaceAllUsesWith(original_val);

  original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output());
  GRAPH_DUMP("insert nodes:", original_val->owningGraph());
}

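// Find the dynamic quant pattern (choose_qparams -> quantize_per_tensor ->
// dequantize); when the dequantize output has more than one use, insert a
// fresh copy of the sequence before each use (so later quant fusion can match
// each use independently) and destroy the original nodes.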
void ReplicateChooseQParamsQuantDequant(std::shared_ptr<Graph>& graph) {
  const PatternInfo& dynamic_quant_pattern = PatternInfo::parse_from_str(R"(
graph(%a, %reduce_range, %a_dtype):
    %a_scale : float, %a_zero_point : int = aten::_choose_qparams_per_tensor(%a, %reduce_range)
    %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
    %a_dequant = aten::dequantize(%a_quant)
    return (%a_dequant) )");
  const Graph& dynamic_quant_graph = *dynamic_quant_pattern.pattern_graph;

  const auto& matches = findPatternMatches(dynamic_quant_graph, *graph);
  if (matches.empty()) {
    return;
  }

  const auto& vmap = dynamic_quant_pattern.vmap;
  Value* dequant_val = vmap.at("a_dequant");
  Node* pattern_dequant = dequant_val->node();
  Value* quant_val = vmap.at("a_quant");
  Node* pattern_quant = quant_val->node();
  Value* choose_qparam_val = vmap.at("a_scale");
  Node* pattern_choose_qparam = choose_qparam_val->node();

  std::vector<DynamicQuantOps> nodes_to_rewrite;
  std::vector<Node*> choose_qparam_nodes_to_rewrite;
  for (const Match& match : matches) {
    Node* matched_dequantize = match.nodes_map.at(pattern_dequant);
    Node* matched_quantize = match.nodes_map.at(pattern_quant);
    Node* matched_choose_qparam = match.nodes_map.at(pattern_choose_qparam);
    if (matched_dequantize->output()->uses().size() > 1) {
      nodes_to_rewrite.emplace_back(
          matched_choose_qparam, matched_quantize, matched_dequantize);
    }
  }
  for (const auto& nodes : nodes_to_rewrite) {
    auto quant_node = std::get<1>(nodes);
    auto dequant_node = std::get<2>(nodes);
    // get input of quantize call.
    Value* original_val = quant_node->inputs()[0];
    Value* dequant_out = dequant_node->output();
    Value* dtype = quant_node->inputs()[3];
    std::vector<Use> uses = dequant_out->uses();
    for (const Use& use : uses) {
      auto* user = use.user;
      WithInsertPoint ins(user);
      auto quant_ops = insertChooseQParamQuantDequant(
          graph.get(), original_val, dtype, quant_node->kind());
      user->replaceInputWith(dequant_out, std::get<2>(quant_ops)->output());
    }
  }
  for (const auto& n : nodes_to_rewrite) {
    auto [choose_qparams, quant, dequant] = n;
    dequant->removeAllInputs();
    quant->removeAllInputs();
    choose_qparams->removeAllInputs();
  }
  for (const auto& n : nodes_to_rewrite) {
    auto [choose_qparams, quant, dequant] = n;
    dequant->destroy();
    quant->destroy();
    choose_qparams->destroy();
  }
}

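// Remove a dequantize node when its only use is a node that merely queries
// tensor metadata (isTensorInfoNode), since such an op can consume the
// quantized Tensor directly.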
void RemoveRedundantDequantize(std::shared_ptr<Graph>& graph) {
  const std::string dequantize = R"(
graph(%a_quant):
    %a_dequant = aten::dequantize(%a_quant)
    return (%a_dequant) )";
  const std::string dequantize_replacement = R"(
graph(%a):
    return (%a) )";
  auto filter = [&](const Match& match,
                    const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto dequant_node = match_vmap.at(vmap.at("a_dequant"))->node();
    Value* dequant_out = dequant_node->output();
    // Values can be used multiple times in a single node
    if (dequant_out->uses().size() != 1) {
      return false;
    }
    Node* user = dequant_out->uses()[0].user;
    return isTensorInfoNode(user);
  };
  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(dequantize, dequantize_replacement);
  rewriter.runOnGraph(graph, filter);
}

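// Remove the dynamic choose_qparams/quant/dequant sequence when its only
// consumer is not dynamically quantizable, leaving the float value in place.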
void RemoveRedundantQuantizationOps(std::shared_ptr<Graph>& graph) {
  const std::string dynamic_quant_ops = R"(
graph(%a, %reduce_range, %a_dtype):
    %a_scale : float, %a_zero_point : int = aten::_choose_qparams_per_tensor(%a, %reduce_range)
    %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
    %a_dequant = aten::dequantize(%a_quant)
    return (%a_dequant) )";
  const std::string dynamic_quant_replacement = R"(
graph(%a, %reduce_range, %a_dtype):
    return (%a) )";
  auto filter = [&](const Match& match,
                    const std::unordered_map<std::string, Value*>& vmap) {
    const auto& match_vmap = match.values_map;
    auto dequant_node = match_vmap.at(vmap.at("a_dequant"))->node();
    Value* dequant_out = dequant_node->output();
    // Values can be used multiple times in a single node
    if (dequant_out->uses().size() != 1) {
      return false;
    }
    Node* user = dequant_out->uses()[0].user;
    return !nodeQuantizable(user, QuantType::DYNAMIC);
  };
  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(dynamic_quant_ops, dynamic_quant_replacement);
  rewriter.runOnGraph(graph, filter);
}

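// If a scalar value feeding the min/max input of clamp-like ops has multiple
// uses, clone the producing node once per use so each use gets its own copy.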
void ReplicateClampScalarArgs(std::shared_ptr<Graph>& graph) {
  std::stack<Block*> blocks_to_visit;
  std::unordered_set<Node*> scalar_nodes_to_rewrite;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      for (Value* output : n->outputs()) {
        if (getClampScalarInputUse(output) && output->uses().size() > 1) {
          scalar_nodes_to_rewrite.insert(n);
        }
      }
      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }

  for (Node* n : scalar_nodes_to_rewrite) {
    const std::vector<Use> uses = n->output()->uses();
    for (const auto& use : uses) {
      Node* user = use.user;
      WithInsertPoint ins(user);
      Node* cloned_node = graph->createClone(n, [](Value* v) { return v; });
      graph->insertNode(cloned_node);
      user->replaceInput(use.offset, cloned_node->output());
    }
  }

  for (Node* n : scalar_nodes_to_rewrite) {
    n->removeAllInputs();
  }

  for (Node* n : scalar_nodes_to_rewrite) {
    n->destroy();
  }
}

void checkCalculateQParamsResult(const IValue& qparams) {
  TORCH_CHECK(
      qparams.isTuple(),
      "`calculate_qparams` function is expected to return a "
      "Tuple, but got:",
      qparams.tagKind());
  auto tp = qparams.toTuple();
  TORCH_CHECK(
      tp->elements().size() == 2,
      "`calculate_qparams` function is expected to return a "
      "Tuple of size 2, got Tuple of size ",
      tp->elements().size());
  // Expect first two elements of the tuple to be Tensor
  for (const auto i : c10::irange(2)) {
    TORCH_CHECK(
        tp->elements()[i].isTensor(),
        "Element of Tuple is expected to be Tensor, but element ",
        i,
        " has type: ",
        tp->elements()[i].tagKind());
  }
}

class SubGraphCloneHelper {
 public:
  // Given a list of nodes, build a graph corresponding to these nodes.
  // The user should make sure to run this graph with the expected input.
  std::unique_ptr<GraphFunction> buildGraphFromNodes(
      const std::vector<Node*>& nodes,
      const std::string& name);

  // Given a list of nodes in src, produce a Graph with these nodes.
  void buildObserverSubgraph(
      const std::vector<Node*>& src,
      std::shared_ptr<Graph> dest);

 private:
  // Clone the given node into the destination Graph g.
  void cloneNodeInGraph(
      Node* node,
      std::shared_ptr<Graph>& g,
      std::unordered_map<Value*, Value*>& remap_values);
};

class InsertQuantDeQuantHelper {
 public:
  InsertQuantDeQuantHelper(QuantType quant_type, bool debug)
      : quant_type_(quant_type), debug_(debug) {}

  void run(Module& module, const std::string& method_name);

  void runForOnDevicePTQ(Module& module, const std::string& method_name);

  // Cleanup observer nodes from graph and observer modules
  // from module object and ClassType
  void cleanup(Module& module);

  // Cleanup observer nodes only, but not observer modules.
  // This is for on-device PTQ.
  void removeObserverNodes(Module& m);

  // In order to propagate quantization ops through ops that don't require
  // observation, we'll first inline the graph, and then call the
  // PropagateQuantizationOps pass
  void propagateQuantizationOps(Module& module);

  // Used for dynamic quantization to selectively run the weight observers.
  // It extracts the subgraph corresponding to the weight and runs it with
  // the module instance.
  void runWeightObserver(Module& module, const std::string& method_name);

 private:
  ModuleMethodVector getInvokedMethods(
      Module& module,
      const std::string& method_name);

  // Get the quantization parameter map of the given Value in the Graph
  // by searching for the observer module of the value and extracting the
  // quantization parameters from the observer module
  std::tuple<c10::QScheme, QParamVector> getQSchemeAndQParamVector(
      script::Module& module,
      Node* n);
  QuantOpParams insertCalculateQParams(
      script::Module& module,
      Graph* g,
      Node* n);

  void checkQScheme(Graph* g, c10::QScheme qscheme) {
    if (qscheme_for_graph_.count(g)) {
      // FIXME[T110786721]: This check was broken before, never failing.
      // Once fixed, this check triggers and fails tests.
      // Fix the test failures that enabling this check produces!
      /*
      TORCH_CHECK(
          qscheme_for_graph_.at(g) == qscheme,
          "Quantizing same graph with different types of "
          "QSchemes is not supported.\n",
          " Expecting:",
          c10::toString(qscheme_for_graph_.at(g)),
          " Got:",
          c10::toString(qscheme));
      */
    } else {
      qscheme_for_graph_[g] = toAffine(qscheme);
    }
  }

  void collectObserverNodesAndValueToQuantize(Module& module, Value*);
  void cleanup(Module& module, Graph* g);
  void removeObserverNodes(Graph* g);
  void quantizeTensors(Module& module, Graph* g, Value* self);
  void insertCalculateQParamsAndQuantizationOps(
      Module& module,
      Graph* g,
      Value* self);

  // Function that extracts and runs the weight observer in a separate
  // subgraph.
  void extractAndRunWeightObserver(
      Module& module,
      Value* self,
      Value* weight_value);

  // Recursively find the nodes that produce the value and add to subgraph.
  void findSubgraph(Value* self, Value* v, std::vector<Node*>& weight_subgraph);

  // Quantizes two types of general ops (ops that work both for floating
  // point and quantized Tensors) in this pass:
  // for ops that only manipulate shape, e.g. flatten, quantization
  // is done by swapping with the previous dequantize op;
  // for ops that manipulate the values of the Tensor, e.g. average pool,
  // quantization is done by inserting quant/dequant ops after the op.
  // Also has special handling of clamp/hardtanh.
  void propagateQuantizationOps(Block* block);

  // Propagate quantization parameters from other quantized tensors
  void propagateQParams(
      Value* original_output,
      const std::vector<Value*>& inputs,
      bool is_scalar = false,
      const std::optional<std::tuple<c10::QScheme, QParamVector>>& qparams_opt =
          std::nullopt);

  bool isQuantized(Value* v) {
    return quantized_values_.count(v) != 0;
  }

  std::unordered_map<Graph*, std::vector<std::string>>
      observer_modules_to_remove_;
  // We only remove observer module attributes from the type on the first
  // encounter of the graph; after that, since the attributes are already
  // removed from the ClassType, we use the list of slot indices to replay
  // this removal
  std::unordered_map<Graph*, std::vector<int>> removed_observer_slots_;
  std::unordered_map<Graph*, std::vector<Node*>> nodes_to_destroy_;
  // Map from Graph to observer node; we can use the observer node to
  // get the information of the original value that's been observed and
  // the quantization parameters
  std::unordered_map<Graph*, std::vector<Node*>> observer_nodes_for_graph_;
  // A map from qparam name (e.g. _scale) to the attribute name in
  // the module (e.g. weight_scale_0)
  std::unordered_map<Node*, std::unordered_map<std::string, std::string>>
      qparam_name_map_for_node_;
  // Record the qscheme for every graph; this is for checking that
  // each graph is only quantized with one type of QScheme
  std::unordered_map<Graph*, c10::QScheme> qscheme_for_graph_;

  // Set of quantized values, so that we quantize each value only
  // once
  std::unordered_set<Value*> quantized_values_;

  // Map from original weight value to the GraphFunction corresponding to the
  // subgraph that includes the weight observer and dependent nodes.
  std::unordered_map<Value*, std::unique_ptr<GraphFunction>>
      weight_to_graph_fn_;

  QuantType quant_type_ = QuantType::STATIC;
  bool debug_ = false;
};

void InsertQuantDeQuantHelper::collectObserverNodesAndValueToQuantize(
    Module& module,
    Value* v) {
  auto* g = v->owningGraph();
  auto observer_name = findObserverName(v);
  if (!observer_name) {
    return;
  }
  observer_modules_to_remove_[g].push_back(observer_name.value());

  Node* observer = v->node();
  TORCH_INTERNAL_ASSERT(
      observer->kind() == prim::CallMethod &&
      observer->s(attr::name) == "forward" &&
      observer->inputs()[0]->node()->kind() == prim::GetAttr &&
      observer->inputs()[0]->node()->s(attr::name) == observer_name);

  // Observer forward call node
  nodes_to_destroy_[g].push_back(observer);
  // GetAttr node for observer module
  nodes_to_destroy_[g].push_back(observer->inputs()[0]->node());
  observer_nodes_for_graph_[g].push_back(observer);
}

void InsertQuantDeQuantHelper::removeObserverNodes(Module& module) {
  for (auto& method : module.get_methods()) {
    removeObserverNodes(method.graph().get());
  }
  for (Module m : module.children()) {
    removeObserverNodes(m);
  }
}

void InsertQuantDeQuantHelper::removeObserverNodes(Graph* g) {
  if (nodes_to_destroy_.count(g)) {
    for (auto& n : nodes_to_destroy_.at(g)) {
      n->removeAllInputs();
    }
    for (auto& n : nodes_to_destroy_.at(g)) {
      n->destroy();
    }
    nodes_to_destroy_.at(g).clear();
  }
}

void InsertQuantDeQuantHelper::cleanup(Module& module) {
  for (auto& method : module.get_methods()) {
    cleanup(module, method.graph().get());
  }
  for (Module m : module.children()) {
    cleanup(m);
  }
}

void InsertQuantDeQuantHelper::cleanup(Module& module, Graph* g) {
  GRAPH_DUMP("Before Remove Observers:", g);
  removeObserverNodes(g);

  // 1. If we have seen this graph before, the observer attributes have
  // already been removed from the type (see step 2), but the slot indices of
  // these attributes are kept in the list; we'll replay the observer slot
  // removal using these slot indices.
  if (removed_observer_slots_.count(g)) {
    for (auto slot : removed_observer_slots_.at(g)) {
      module._ivalue()->unsafeRemoveSlot(slot);
    }
  }

  // 2. Remove observer modules from last one to first one in order to
  // reduce the time complexity: assuming all the observer modules
  // are added after the existing modules, this optimization gives us
  // O(N) complexity, where N is the number of observer modules.
  if (observer_modules_to_remove_.count(g)) {
    auto& observers = observer_modules_to_remove_.at(g);
    for (int64_t i = observers.size() - 1; i >= 0; --i) {
      auto observer_name = observers[i];
      GRAPH_DEBUG("Trying to remove: ", observer_name);
      if (module.type()->hasAttribute(observer_name)) {
        // We record the slot index here in order to replay the
        // slot removal in other objects that share the ClassType,
        // since we're going to remove the attribute from the ClassType here
        removed_observer_slots_[g].push_back(
            module.type()->getAttributeSlot(observer_name));
        module._ivalue()->unsafeRemoveAttr(observer_name);
        module.type()->unsafeRemoveAttribute(observer_name);
      }
    }
    observers.clear();
  }
  GRAPH_DUMP("After remove observers :", g);
}

void SubGraphCloneHelper::cloneNodeInGraph(
    Node* node,
    std::shared_ptr<Graph>& g,
    std::unordered_map<Value*, Value*>& remap_old_to_new) {
  auto* block = g->block();
  auto value_fn = [&](Value* v) {
    if (remap_old_to_new.count(v) == 0) {
      auto new_value = g->block()->addInput();
      remap_old_to_new[v] = new_value;
      new_value->copyMetadata(v);
      return new_value;
    } else {
      return remap_old_to_new[v];
    }
  };

  auto new_node = block->appendNode(g->createClone(node, value_fn));
  for (size_t i = 0; i < node->outputs().size(); ++i) {
    auto oo = node->outputs()[i];
    auto no = new_node->outputs()[i];
    remap_old_to_new[oo] = no;
  }
}

void SubGraphCloneHelper::buildObserverSubgraph(
    const std::vector<Node*>& weight_subgraph,
    std::shared_ptr<Graph> dest_graph) {
  std::unordered_map<Value*, Value*> remap_old_to_new;
  // Build weight subgraph
  for (auto n : weight_subgraph) {
    cloneNodeInGraph(n, dest_graph, remap_old_to_new);
  }
  LintGraph(dest_graph);

  // Add last node output value as subgraph output.
  for (auto out : weight_subgraph.back()->outputs()) {
    dest_graph->registerOutput(remap_old_to_new[out]);
  }
  GRAPH_DUMP("New weight observer subgraph: ", dest_graph);
}

std::unique_ptr<GraphFunction> SubGraphCloneHelper::buildGraphFromNodes(
    const std::vector<Node*>& nodes,
    const std::string& name) {
  auto observer_subgraph = std::make_shared<Graph>();
  auto build_observer_graph = [&](GraphFunction& func) {
    buildObserverSubgraph(nodes, func.graph());
  };
  return std::make_unique<GraphFunction>(
      name, observer_subgraph, build_observer_graph);
}

void InsertQuantDeQuantHelper::findSubgraph(
    Value* self,
    Value* input_val,
    std::vector<Node*>& weight_subgraph) {
  Node* node = input_val->node();
  weight_subgraph.push_back(node);
  const auto& inputs = node->inputs().vec();
  for (auto v : inputs) {
    if (!hitGraphInput(v)) {
      findSubgraph(self, v, weight_subgraph);
    } else {
      TORCH_CHECK(
          v == self,
          "Unexpected value found when handling weight value "
          "in findSubgraph, traced back to:",
          v->debugName(),
          " which is not self:",
          self->debugName());
    }
  }
}

void InsertQuantDeQuantHelper::extractAndRunWeightObserver(
    Module& module,
    Value* self,
    Value* weight_value) {
  std::vector<Node*> weight_subgraph;
  // If the graph was already visited, use the cached GraphFunction directly.
  // Multiple module instances can share the same graph code, so we don't need
  // to re-run the extraction process.
  if (weight_to_graph_fn_.count(weight_value) == 0) {
    // Extract the subgraph nodes.
    findSubgraph(self, weight_value, weight_subgraph);

    // Reverse to traverse subgraph in correct direction
    std::reverse(weight_subgraph.begin(), weight_subgraph.end());

    // Build the graph using the nodes found from the weight observer.
    SubGraphCloneHelper o;
    std::unique_ptr<GraphFunction> func =
        o.buildGraphFromNodes(weight_subgraph, "observer_subgraph");
    weight_to_graph_fn_[weight_value] = std::move(func);
  }
  Stack module_inp = {module._ivalue()};
  // Run the graph with the module input.
  weight_to_graph_fn_[weight_value]->run(module_inp);
}

void InsertQuantDeQuantHelper::quantizeTensors(
    Module& module,
    Graph* g,
    Value* self) {
  if (!observer_nodes_for_graph_.count(g)) {
    return;
  }
  for (auto* n : observer_nodes_for_graph_.at(g)) {
    auto* original_value = n->input(1);
    auto tp = getQSchemeAndQParamVector(module, n);
    auto qscheme = std::get<0>(tp);
    auto qparam_map = std::get<1>(tp);
    checkQScheme(g, qscheme);
    std::vector<std::string> qparam_names;
    for (auto& pr : qparam_map) {
      const auto& name = pr.first;
      const auto& qparam = pr.second;
      size_t uid = 0;
      auto qparam_name =
          original_value->debugName() + name + "_" + std::to_string(uid++);
      while (module.hasattr(qparam_name)) {
        qparam_name =
            original_value->debugName() + name + "_" + std::to_string(uid++);
      }
      qparam_name_map_for_node_[n][name] = qparam_name;
      module.register_attribute(qparam_name, qparam.type(), qparam);
      qparam_names.push_back(qparam_name);
    }
    insertQuantizationOps(
        module, self, n, isPerChannel(qscheme), qparam_names, quant_type_);
  }
}

std::tuple<c10::QScheme, QParamVector> InsertQuantDeQuantHelper::
    getQSchemeAndQParamVector(script::Module& module, Node* n) {
  // TODO: refactor findObserverName to take Node* as input
  Value* v = n->output();
  TORCH_INTERNAL_ASSERT(
      v->type()->isSubtypeOf(*TensorType::get()),
      "Expected output of observer node to be Tensor");
  auto observer_name = findObserverName(v);
  TORCH_INTERNAL_ASSERT(
      observer_name,
      "getQSchemeAndParamMap expects the corresponding observer for ",
      v->debugName(),
      " exists.");
  QParamVector qparams;
  c10::QScheme qscheme = c10::kPerTensorAffine;

  auto observer_module = module.attr(observer_name.value()).toModule();
  auto scalar_type = observer_module.attr("dtype");
  if (isPlaceholderObserver(n->input(0))) {
    // get compute_dtype for dynamic quantization
    if (observer_module.hasattr("is_dynamic") &&
        observer_module.attr("is_dynamic").toBool()) {
      qparams.emplace_back(kScalarType, observer_module.attr("dtype"));
    }
    return std::make_tuple(qscheme, std::move(qparams));
  } else if (scalar_type == at::ScalarType::Half) {
    return std::make_tuple(qscheme, std::move(qparams));
  }
  auto calculate_qparams = observer_module.get_method("calculate_qparams");
  IValue result = calculate_qparams(std::vector<IValue>());
  checkCalculateQParamsResult(result);
  TORCH_CHECK(
      scalar_type.toScalarType() != at::ScalarType::Undefined,
      "dtype of observer can't be undefined");
  auto tp = result.toTuple();
  at::Tensor scale = tp->elements()[0].toTensor().to(at::kFloat);
  at::Tensor zero_point = tp->elements()[1].toTensor().to(at::kInt);
  // Quantization parameters should appear in the same order as the arguments
  // of the quantize_per_tensor/quantize_per_channel functions.

  qscheme = observer_module.attr("qscheme").toQScheme();
  if (isPerChannel(qscheme)) {
    auto axis = observer_module.attr("ch_axis");
    qparams.emplace_back("_scale", scale);
    qparams.emplace_back("_zero_point", zero_point);
    qparams.emplace_back("_axis", axis.toInt());
  } else {
    qparams.emplace_back("_scale", scale.item<double>());
    qparams.emplace_back("_zero_point", zero_point.item<int64_t>());
  }
  qparams.emplace_back(kScalarType, scalar_type);
  return std::make_tuple(qscheme, std::move(qparams));
}

ModuleMethodVector InsertQuantDeQuantHelper::getInvokedMethods(
    Module& module,
    const std::string& method_name) {
  auto graph = module.get_method(method_name).graph();

  ModuleMethodVector invoked_methods;
  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      if (n->kind() == prim::CallMethod) {
        auto module_instance = n->inputs()[0];
        auto module_method_name = n->s(attr::name);
        std::optional<Module> m;
        // calling method on self
        if (module_instance == graph->inputs()[0]) {
          m = module;
        } else if (
            module_instance->node()->kind() == prim::GetAttr &&
            module_instance->node()->s(attr::name).find("_observer_") ==
                std::string::npos) {
          m = getInvokedModuleOpt(module, n, graph->inputs()[0]);
        }
        if (m) {
          invoked_methods.emplace_back(*m, module_method_name);
        }
      }

      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
  return invoked_methods;
}

void InsertQuantDeQuantHelper::propagateQParams(
    Value* original_output,
    const std::vector<Value*>& inputs,
    bool is_scalar,
    const std::optional<std::tuple<c10::QScheme, QParamVector>>& qparams_opt) {
  Node* n = original_output->node();
  Graph* graph = n->owningGraph();
  if (is_scalar) {
    // convert Scalar to Tensor
    n = insertScalarToTensor(graph, original_output);
    original_output = n->output();
  }
  // for ops like average pool, we'll insert quant dequant after the op
  // We'll assume the tensor is a PerTensorAffine quantized Tensor for
  // now, and may generalize later if this becomes an issue
  TORCH_INTERNAL_ASSERT(
      inputs.size() == 1, "Expecting single input for the aten function");
  // input of the dequantize node
  Value* quantized_input = inputs[0]->node()->input(0);
  // insert ops after the general op
  Node* quantized_input_node = quantized_input->node();
  // Insert after the node that is later in topological order
  WithInsertPoint ins(
      quantized_input_node->isAfter(n) ? quantized_input_node->next()
                                       : n->next());
  std::vector<Value*> quant_inputs;
  auto quant_kind = Symbol::aten("quantize_per_tensor");
  if (qparams_opt.has_value()) {
    quant_inputs = {original_output};
    auto qscheme = std::get<0>(*qparams_opt);
    auto qparams = std::get<1>(*qparams_opt);
    if (isPerChannel(qscheme)) {
      quant_kind = Symbol::aten("quantize_per_channel");
    }
    for (const auto& qparam : qparams) {
      Value* qparam_val = graph->insertConstant(qparam.second);
      qparam_val->setDebugName(quantized_input->debugName() + qparam.first);
      quant_inputs.push_back(qparam_val);
    }
  } else {
    // Only per tensor affine quantized tensor is supported in this case
    // get quantization parameters from previous quantized op
    Node* scale = insertQParam(
        graph,
        quantized_input,
        at::Symbol::aten("q_scale"),
        FloatType::get(),
        "q_scale");
    Node* zero_point = insertQParam(
        graph,
        quantized_input,
        at::Symbol::aten("q_zero_point"),
        IntType::get(),
        "q_zero_point");
    Node* dtype = insertQParam(
        graph, quantized_input, prim::dtype, IntType::get(), "dtype");
    quant_inputs = {
        original_output,
        scale->output(),
        zero_point->output(),
        dtype->output()};
  }
  Node* quant = insertQuant(
      graph, quant_inputs, quant_kind, original_output->debugName() + ".quant");
  Value* quantized_output = quant->output();
  // replace uses of original output of the general op with quantized
  // output
  original_output->replaceAllUsesAfterNodeWith(quant, quantized_output);
  const auto& outputs =
      insertDeQuantForAllUse(graph, quantized_output, quantized_output);
  for (auto* output : outputs) {
    if (is_scalar) {
      // Convert the dequantized Tensor back to Scalar
      Node* item = insertItem(graph, output, FloatType::get());
      Value* scalar = item->output();
      output->replaceAllUsesAfterNodeWith(item, scalar);
      output = scalar;
    }
    quantized_values_.insert(output);
  }
}

void removeDequantizeFromInputs(const std::unordered_set<Value*>& inputs) {
  // Delete the dequantize nodes; we have one dequantize
  // for each use of the value
  for (auto* dequantized_val : inputs) {
    auto* dequantize_node = dequantized_val->node();
    TORCH_INTERNAL_ASSERT(
        dequantized_val->uses().size() == 1,
        "Expect to have one dequantize node for each use");
    // Replace uses of dequantized_val with the input of
    // the dequantize node
    dequantized_val->replaceAllUsesWith(dequantize_node->inputs()[0]);
    dequantize_node->removeAllInputs();
    dequantize_node->destroy();
  }
}

// Check if we need to propagate the quantization ops from input to
// output
std::optional<std::vector<Value*>> getDequantizedInputs(Value* output) {
  auto inputs = getPassThroughInputs(output);
  if (!inputs.empty()) {
    // note that we don't need to recursively check for prim::If
    // here because if all inputs of a prim::If are dequantized
    // the dequantize will be factored out before we get to this
    // point
    bool is_dequantized = true;
    for (auto* input : inputs) {
      GRAPH_DEBUG(
          "checking if input:",
          input->debugName(),
          " in node:",
          *input->node(),
          "is quantized");
      is_dequantized &= input->node()->kind() == Symbol::aten("dequantize");
    }
    if (is_dequantized) {
      return inputs;
    }
  }
  return std::nullopt;
}

void InsertQuantDeQuantHelper::propagateQuantizationOps(Block* block) {
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::If) {
      for (Block* subblock : n->blocks()) {
        propagateQuantizationOps(subblock);
      }
      if (n->outputs().empty()) {
        continue;
      }
      if (n->outputs().size() > 1) {
        // Factoring out dequantize for if blocks with multiple outputs
        // is not supported right now
        continue;
      }
    }
    if (isSingleInputGeneralValueAtenFunction(n)) {
      for (auto* output : n->outputs()) {
        if (isQuantized(output)) {
          continue;
        }
        if (auto inputs = getDequantizedInputs(output)) {
          propagateQParams(output, *inputs);
          if (isClamp(n)) {
            for (size_t i = 1; i <= 2; ++i) {
              // propagate qparams for min and max scalar arguments
              // for aten::clamp/aten::hardtanh
              propagateQParams(n->input(i), *inputs, /* is_scalar */ true);
            }
          }
        }
      }
    } else if (auto qparams_opt = getFixedQParams(n)) {
      for (auto* output : n->outputs()) {
        if (isQuantized(output)) {
          continue;
        }
        if (auto inputs = getDequantizedInputs(output)) {
          propagateQParams(output, *inputs, /* is_scalar */ false, qparams_opt);
        }
      }
    } else {
      // For ops that are quantized by propagating dequantize ops,
      // e.g. flatten, we need to
      // 1. check if we need to propagate the dequantize op
      // 2. remove the dequantize ops from inputs
      // 3. insert dequantize for all outputs
      // to make sure it works for ops with multiple outputs,
      // since removing dequantize from inputs mutates the graph
      // and would affect future checks for whether all the inputs
      // have been quantized or not (currently we just check if
      // the value is produced by a dequantize op to decide if the
      // value is quantized or not).
      // list of dequantized input values
      std::unordered_set<Value*> dequantized_inputs;
      std::vector<Value*> outputs_to_dequantize;
      // 1. collect dequantized inputs and outputs we need to dequantize
      for (auto* output : n->outputs()) {
        if (isQuantized(output)) {
          continue;
        }
        if (auto inputs = getDequantizedInputs(output)) {
          std::copy(
              inputs->begin(),
              inputs->end(),
              std::inserter(dequantized_inputs, dequantized_inputs.end()));
          outputs_to_dequantize.push_back(output);
        }
      }
      // 2. remove the dequantize ops from inputs
      removeDequantizeFromInputs(dequantized_inputs);
      // 3. insert dequantize op for outputs
      for (auto* output : outputs_to_dequantize) {
        insertDeQuantForAllUse(output->owningGraph(), output, output);
      }
    }

    if (isBinaryOpWithScalarInput(n)) {
      // Print a warning for add_scalar/mul_scalar when debug is enabled,
      // since the quantization parameters for these ops depend on the input
      // and it's too complicated to encode the equations in
      // the IR:
      // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp#L64-L74
      if (debug_) {
        TORCH_WARN_ONCE(
            "debug option for add_scalar and mul_scalar is not supported, "
            "please don't use debug option for models that use these ops.");
      }
    }
  }
}

void InsertQuantDeQuantHelper::runWeightObserver(
    Module& module,
    const std::string& method_name) {
  if (quant_type_ != QuantType::DYNAMIC) {
    return;
  }

  for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
    auto& invoked_module = std::get<0>(invoked_methods);
    const auto& invoked_method_name = std::get<1>(invoked_methods);
    runWeightObserver(invoked_module, invoked_method_name);
  }
  Method method = module.get_method(method_name);
  auto graph = method.graph();
  Value* self = graph->inputs()[0];

  std::vector<Value*> weight_values;
  // Visit all blocks in the current graph to find weight values.
  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (auto n : b->nodes()) {
      for (Value* v : n->outputs()) {
        if (!v->type()->isSubtypeOf(*TensorType::get())) {
          continue;
        }
        auto observer_name = findObserverName(v);
        if (observer_name && isWeight(module, v)) {
          weight_values.push_back(v);
        }
      }
      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
  // For all the observed weight values, find the corresponding subgraph that
  // contributes to the weight tensor, and run that subgraph to observe the
  // weight.
  for (const auto& v : weight_values) {
    extractAndRunWeightObserver(module, self, v);
  }
}

void InsertQuantDeQuantHelper::run(
    Module& module,
    const std::string& method_name) {
  for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
    auto& invoked_module = std::get<0>(invoked_methods);
    const auto& invoked_method_name = std::get<1>(invoked_methods);
    run(invoked_module, invoked_method_name);
  }

  Method method = module.get_method(method_name);
  auto graph = method.graph();
  // We only need to register new parameters if the graph has
  // been quantized before.
  // TODO: dedup this part with code in quantizeTensors
  if (observer_nodes_for_graph_.count(graph.get())) {
    for (auto* n : observer_nodes_for_graph_.at(graph.get())) {
      auto tp = getQSchemeAndQParamVector(module, n);
      checkQScheme(graph.get(), std::get<0>(tp));
      auto qparam_map = std::get<1>(tp);
      // We check the size here because for some observers (like
      // PlaceholderObserver) the qparams might be empty.
      if (!qparam_map.empty()) {
        TORCH_INTERNAL_ASSERT(
            qparam_name_map_for_node_.count(n),
            "Expected to have a qparam_name_map for node:",
            *n);
        auto qparam_name_map = qparam_name_map_for_node_.at(n);
        for (auto& pr : qparam_map) {
          const auto& name = pr.first;
          const auto& qparam = pr.second;
          module._ivalue()->setAttr(qparam_name_map.at(name), qparam);
        }
      }
    }
    return;
  }

  // prim::Param nodes do not belong to the graph, hence the insert point
  // is the beginning of the graph. This also safeguards against observing
  // a potentially mutated value due to some in-place operation.
  std::vector<Value*> input_values;
  for (const auto idx : c10::irange(1, method.num_inputs())) {
    auto& v = graph->inputs()[idx];
    if (v->type()->isSubtypeOf(*TensorType::get())) {
      input_values.push_back(v);
    }
  }

  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
      Node* n = *it++;
      for (Value* v : n->outputs()) {
        if (!v->type()->isSubtypeOf(*TensorType::get())) {
          continue;
        }
        collectObserverNodesAndValueToQuantize(module, v);
      }

      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }

  for (Value* v : input_values) {
    collectObserverNodesAndValueToQuantize(module, v);
  }
  GRAPH_DUMP("Before Quantize Tensors:", graph);
  Value* self = graph->inputs()[0];
  quantizeTensors(module, graph.get(), self);
  GRAPH_DUMP("After Quantize Tensors:", graph);
}

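// Post-insertion cleanup pipeline run on the module's forward graph:
// inline calls, fold constants, replicate/deduplicate quant-dequant ops,
// and propagate them through the graph.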
void InsertQuantDeQuantHelper::propagateQuantizationOps(Module& module) {
  SwapFunctionalLinear(module);
  auto graph = module.get_method("forward").graph();
  Inline(*graph);
  ConstantPropagation(graph);
  ReplicateChooseQParamsQuantDequant(graph);
  RemoveRedundantQuantizationOps(graph);
  ReplicateQuant(graph);
  ReplicateDeQuant(graph);
  // TODO: add filter to the clamp patterns and remove this pass
  ReplicateClampScalarArgs(graph);
  propagateQuantizationOps(graph->block());
  RemoveRedundantDequantize(graph);
}

// Insert quant and dequant nodes into the graph for both static and dynamic
// quant.
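// Roughly, for an observed value %x this emits IR of the following shape
// (a sketch; value names are illustrative and the op depends on
// quantize_func):
//   %xq = aten::quantize_per_tensor(%x, %scale, %zero_point, %dtype)
//   %xdq = aten::dequantize(%xq)
// The caller then rewires users of the observed value to consume %xdq.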
template <>
Node* insertQuantDequantNodes<QuantOpParams>(
    Value* self,
    Node* observer,
    QuantOpParams& qparams,
    const std::string& quantize_func) {
  (void)self;
  Graph* g = observer->owningGraph();
  Value* observer_out = observer->output();
  Value* original_val = observer->input(1);
  std::vector<Value*> inputs;
  // + 1 for the tensor to be quantized
  inputs.reserve(qparams.qparams.size() + 1);
  inputs.push_back(observer_out);
  for (const auto& qparam_values : qparams.qparams) {
    inputs.push_back(qparam_values);
  }
  Node* quant = insertQuant(
      g,
      inputs,
      at::Symbol::aten(quantize_func),
      original_val->debugName() + ".quant");
  // Make sure that the quant node appears after the values it depends on.
  for (Value* v : inputs) {
    quant->moveAfter(v->node());
  }
  Node* dequant = insertDeQuant(g, quant->output(), original_val);
  dequant->moveAfter(quant);
  return dequant;
}

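// Sanity-check that the unpacked result of calculate_qparams is a
// (scale, zero_point) pair of Tensor values.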
void checkCalculateQParamsResultTypes(const Node* out) {
  TORCH_CHECK(
      out->outputs().size() == 2,
      "calculate_qparams should produce output of size 2 (scale, zero_point).");
  Value* scale = out->output(0);
  Value* zp = out->output(1);
  TORCH_CHECK(
      scale->type()->expect<TensorType>(),
      "Scale value should be of Tensor type.");
  TORCH_CHECK(
      zp->type()->expect<TensorType>(),
      "Zero-point value should be of Tensor type.");
}

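// Insert a call to the observer module's calculate_qparams method into the
// graph and collect its outputs (plus ch_axis/dtype attributes where
// applicable) as the quantization parameters for the quant op. Roughly
// (a sketch; attribute and value names are illustrative):
//   %observer = prim::GetAttr[name=<observer_name>](%self)
//   %qparams = prim::CallMethod[name="calculate_qparams"](%observer)
//   %scale, %zero_point = prim::TupleUnpack(%qparams)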
QuantOpParams InsertQuantDeQuantHelper::insertCalculateQParams(
    script::Module& module,
    Graph* g,
    Node* n) {
  // TODO: refactor findObserverName to take Node* as input
  Value* self = g->inputs()[0];
  Value* v = n->output();
  TORCH_INTERNAL_ASSERT(
      v->type()->isSubtypeOf(*TensorType::get()),
      "Expected output of observer node to be Tensor");
  auto observer_name = findObserverName(v);
  TORCH_INTERNAL_ASSERT(
      observer_name,
      "insertCalculateQParams expects the corresponding observer for ",
      v->debugName(),
      " to exist.");
  std::vector<Value*> qparams_graph_values;
  QuantOpParams quant_op_params;

  TORCH_CHECK(
      !isPlaceholderObserver(n->input(0)),
      "Placeholder observers are not supported in on-device PTQ.");
  auto observer_module = module.attr(observer_name.value()).toModule();
  Value* observer_module_value = g->insertGetAttr(self, observer_name.value());
  auto scalar_type = observer_module.attr("dtype");
  TORCH_CHECK(
      scalar_type.toScalarType() != at::ScalarType::Undefined,
      "dtype of observer can't be undefined");
  // Not sure if we need to support this for on-device PTQ.
  if (scalar_type == at::ScalarType::Half) {
    return quant_op_params;
  }
  auto calculate_qparams = observer_module.get_method("calculate_qparams");
  auto calculate_qparams_schema = calculate_qparams.function().getSchema();
  MatchedSchema matched_schema = matchSchema(
      calculate_qparams_schema,
      v->node()->sourceRange(),
      *g,
      {observer_module_value},
      {});
  Node* call = g->insertMethodCall("calculate_qparams", matched_schema)->node();
  Node* scale_zp_node = g->insertNode(g->createTupleUnpack(call->output(0)));
  checkCalculateQParamsResultTypes(scale_zp_node);
  auto qscheme = observer_module.attr("qscheme").toQScheme();
  quant_op_params.qscheme = qscheme;
  quant_op_params.qparams.push_back(scale_zp_node->output(0)); // scale Value*
  quant_op_params.qparams.push_back(
      scale_zp_node->output(1)); // zero_point Value*
  if (isPerChannel(qscheme)) {
    Value* ch_axis_value = g->insertGetAttr(observer_module_value, "ch_axis");
    quant_op_params.qparams.push_back(ch_axis_value);
  }
  Value* scalar_type_value = g->insertGetAttr(observer_module_value, "dtype");
  quant_op_params.qparams.push_back(scalar_type_value);
  return quant_op_params;
}

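// For each recorded observer node in the graph, insert the calculate_qparams
// call right after it, then insert the quant-dequant ops that consume its
// outputs.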
void InsertQuantDeQuantHelper::insertCalculateQParamsAndQuantizationOps(
    Module& module,
    Graph* graph,
    Value* self) {
  if (!observer_nodes_for_graph_.count(graph)) {
    return;
  }
  for (auto* n : observer_nodes_for_graph_.at(graph)) {
    Graph* g = n->owningGraph();
    // Observer output
    Value* observer_out = n->output();
    // Insert the calculate_qparams call right after the observer node.
    WithInsertPoint insert_qparams_calc(observer_out->node()->next());
    auto quant_op_params = insertCalculateQParams(module, g, n);
    insertQuantizationOps(
        module,
        self,
        n,
        isPerChannel(quant_op_params.qscheme),
        quant_op_params,
        quant_type_);
  }
}

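// On-device PTQ analog of run(): collect observed tensor values in
// `method_name` (and in the methods it invokes), then insert
// calculate_qparams calls plus quant-dequant ops directly into the graph
// instead of baking the qparams in as module attributes.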
void InsertQuantDeQuantHelper::runForOnDevicePTQ(
    Module& module,
    const std::string& method_name) {
  // In all likelihood this won't do anything, because we expect that the
  // input method for quantization's prepare step will be inlined. Thus the
  // only CallMethod nodes we will see will belong to observers' forward
  // calls.
  for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
    auto& invoked_module = std::get<0>(invoked_methods);
    const auto& invoked_method_name = std::get<1>(invoked_methods);
    runForOnDevicePTQ(invoked_module, invoked_method_name);
  }

  Method method = module.get_method(method_name);
  auto graph = method.graph();
  // Unlike the run method, we don't need to extract new qparam values for
  // the same graph used at different call sites.
  // The reason is that for on-device PTQ we don't:
  // 1. run calculate_qparams
  // 2. get the scale and zero point
  // 3. get axis and dtype
  // 4. register values from 2 and 3 as attributes on the parent module.
  // Instead we insert a call to calculate_qparams (1) via
  // insertCalculateQParams in the graph itself. Then instead of 2 and 3, we
  // use the output Value*s, and for 3 we insert GetAttr for axis and dtype
  // and use those Value*s with insertQuantizationOps.

  // prim::Param nodes do not belong to the graph, hence the insert point
  // is the beginning of the graph. This also safeguards against observing
  // a potentially mutated value due to some in-place operation.
  std::vector<Value*> input_values;
  for (const auto idx : c10::irange(1, method.num_inputs())) {
    auto& v = graph->inputs()[idx];
    if (v->type()->isSubtypeOf(*TensorType::get())) {
      input_values.push_back(v);
    }
  }

  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
      Node* n = *it++;
      for (Value* v : n->outputs()) {
        if (!v->type()->isSubtypeOf(*TensorType::get())) {
          continue;
        }
        collectObserverNodesAndValueToQuantize(module, v);
      }

      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }

  for (Value* v : input_values) {
    collectObserverNodesAndValueToQuantize(module, v);
  }

  GRAPH_DUMP("Before insertCalculateQParamsAndQuantizationOps:", graph);
  Value* self = graph->inputs()[0];
  insertCalculateQParamsAndQuantizationOps(module, graph.get(), self);
  GRAPH_DUMP("After insertCalculateQParamsAndQuantizationOps:", graph);
}

} // namespace

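// Replicate quantize ops that consume the output of a prim::If into each of
// the if blocks, so every branch yields an already-quantized value.
// Roughly (a sketch; value names are illustrative):
//
// Before:
//   %y = prim::If(%cond)
//     block0(): ... -> (%a)
//     block1(): ... -> (%b)
//   %yq = aten::quantize_per_tensor(%y, %scale, %zp, %dtype)
//
// After:
//   %yq = prim::If(%cond)
//     block0():
//       ...
//       %aq = aten::quantize_per_tensor(%a, %scale, %zp, %dtype)
//       -> (%aq)
//     block1():
//       ...
//       %bq = aten::quantize_per_tensor(%b, %scale, %zp, %dtype)
//       -> (%bq)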
void ReplicateQuant(std::shared_ptr<Graph>& graph) {
  std::stack<Block*> blocks_to_visit;
  std::vector<Node*> quant_nodes_to_rewrite;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      // Find quantize nodes that quantize the output of a prim::If.
      if ((n->kind() == Symbol::aten("quantize_per_tensor") ||
           n->kind() == Symbol::aten("quantize_per_channel")) &&
          n->input(0)->node()->kind() == prim::If) {
        quant_nodes_to_rewrite.push_back(n);
      }
      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
  for (Node* n : quant_nodes_to_rewrite) {
    Node* if_node = n->input(0)->node();
    // Move the nodes that produce the quantization parameters before
    // the prim::If.
    for (const auto i : c10::irange(1, n->inputs().size())) {
      n->input(i)->node()->moveBefore(if_node);
    }
    // Replace all uses of the quantize node with the output of the if node.
    n->output()->replaceAllUsesWith(if_node->output());
    // Add quantize nodes to the end of all blocks.
    for (Block* if_block : if_node->blocks()) {
      TORCH_CHECK(
          if_block->outputs().size() == 1,
          "replicate quantize only works for `if` node with one output right now");
      // the original return value of the block
      Value* ret_val = if_block->outputs()[0];
      std::vector<Value*> quantize_inputs = n->inputs().vec();
      quantize_inputs[0] = ret_val;
      WithInsertPoint ins(if_block->return_node());
      Node* quant = graph->create(n->kind(), quantize_inputs);
      if_block->replaceOutput(0, quant->output());
      quant->output()->copyMetadata(ret_val);
      graph->insertNode(quant);
    }
  }

  for (Node* n : quant_nodes_to_rewrite) {
    n->removeAllInputs();
  }
  for (Node* n : quant_nodes_to_rewrite) {
    n->destroy();
  }
}

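// Give each use of a multi-use dequantized value its own dequantize node, so
// downstream pattern matching can handle each use independently. Roughly
// (a sketch; value names are illustrative):
//   %dq = aten::dequantize(%q)   // used by both %u1 and %u2
// becomes
//   %dq1 = aten::dequantize(%q)  // used by %u1
//   %dq2 = aten::dequantize(%q)  // used by %u2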
void ReplicateDeQuant(std::shared_ptr<Graph>& graph) {
  std::stack<Block*> blocks_to_visit;
  std::vector<Node*> dequant_nodes_to_rewrite;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      if (n->kind() == Symbol::aten("dequantize") &&
          n->output()->uses().size() > 1) {
        dequant_nodes_to_rewrite.push_back(n);
      }
      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
  for (Node* n : dequant_nodes_to_rewrite) {
    auto* quantized_val = n->input(0);
    auto* dequantized_val = n->output();
    insertDeQuantForAllUse(graph.get(), quantized_val, dequantized_val);
  }

  for (Node* n : dequant_nodes_to_rewrite) {
    n->removeAllInputs();
  }

  for (Node* n : dequant_nodes_to_rewrite) {
    n->destroy();
  }
}

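// Entry point for the convert step: clones the module (unless inplace),
// observes weights for dynamic quantization, inserts quant-dequant ops,
// cleans up observers, and propagates the inserted ops through the graph.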
Module InsertQuantDeQuant(
    Module& input_module,
    const std::string& method_name,
    bool inplace,
    bool debug,
    QuantType quant_type) {
  Module module = input_module.clone(inplace);
  InsertQuantDeQuantHelper h(quant_type, debug);
  h.runWeightObserver(module, method_name);
  h.run(module, method_name);
  h.cleanup(module);
  h.propagateQuantizationOps(module);
  return module;
}

/*
 * Assumption: the method `method_name` has observers placed.
 * Objective: modify that method to insert calls to:
 * 1. calculate_qparams
 * 2. GetAttr for axis and dtype values
 * 3. quant + dequant, using the Values from the above two steps
 * Thus, after this step you have a graph of, e.g., observe_forward,
 * that has observer nodes, calculate_qparams run on those observer nodes,
 * whose output is used by quant-dequant nodes. The output of dequant is
 * used by the actual op.
 * Later on we will replace dequant + op (e.g. linear) with:
 * 1. prepacked_op context
 * 2. unpack
 * 3. dequantize
 * 4. linear
 *
 * Of the above pattern, 2, 3, and 4 can be replaced by the linear_run op.
 */
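// For example, given method_name == "observe_forward", this clones it into
// "quantize_forward" and inserts the calculate_qparams / quant / dequant
// calls there.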
Module InsertQuantDeQuantOnDevicePTQ(
    Module& input_module,
    const std::string& method_name,
    bool inplace,
    bool debug,
    QuantType quant_type) {
  Module module = input_module.clone(inplace);
  const std::string kObserveString = "observe_";
  const auto matched_pos = method_name.find(kObserveString);
  TORCH_CHECK(
      matched_pos == 0,
      "Quant dequant nodes can only be added to an observe_<method-name> "
      "method. Please make sure to run the prepare step for on-device PTQ.");
  const std::string orig_method_name =
      method_name.substr(kObserveString.length());

  std::string quantize_method_name = "quantize_" + orig_method_name;
  cloneMethod(module, method_name, quantize_method_name);
  InsertQuantDeQuantHelper h(quant_type, debug);
  h.runForOnDevicePTQ(module, quantize_method_name);
  h.removeObserverNodes(module);
  // We don't need:
  // - ReplicateChooseQParamsQuantDequant: propagates dynamic quant's
  //   quant-dequant nodes.
  // - RemoveRedundantQuantizationOps: removes activation observers for
  //   dynamic quant when the op related to them is not dynamically
  //   quantizable. In our case we won't have those anyway, since for
  //   dynamic quant activations won't be observed.
  // We can still use propagateQuantizationOps because the above two passes
  // should be no-ops here.
  h.propagateQuantizationOps(module);
  return module;
}
} // namespace jit
} // namespace torch