xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/quantization/insert_quant_dequant.h>
2 
3 #include <c10/core/QScheme.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/frontend/schema_matching.h>
6 #include <torch/csrc/jit/ir/subgraph_matcher.h>
7 #include <torch/csrc/jit/jit_log.h>
8 #include <torch/csrc/jit/passes/constant_propagation.h>
9 #include <torch/csrc/jit/passes/fuse_linear.h>
10 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
11 #include <torch/csrc/jit/passes/inliner.h>
12 #include <torch/csrc/jit/passes/quantization/helper.h>
13 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
14 
15 #include <stack>
16 #include <utility>
17 
18 namespace torch {
19 namespace jit {
20 
21 namespace {
22 using graph_rewrite_helper::PatternInfo;
23 
24 // dynamic quantization ops for activation: choose_qparams, quant, dequant
25 using DynamicQuantOps = std::tuple<Node*, Node*, Node*>;
26 
27 std::string kScalarType = "_scalar_type";
28 
29 struct QuantOpParams {
30   c10::QScheme qscheme{c10::kPerTensorAffine};
31   std::vector<Value*> qparams;
32   // This is only so that insertQuantizationOps can be templatized
33   // and subsequently significant portion of that code can be reused.
backtorch::jit::__anon97bd7f910111::QuantOpParams34   std::string back() const {
35     return "AttributeDoesNotExist";
36   }
37 };
38 
toAffine(c10::QScheme qscheme)39 c10::QScheme toAffine(c10::QScheme qscheme) {
40   switch (qscheme) {
41     case c10::kPerTensorAffine:
42     case c10::kPerTensorSymmetric:
43       return c10::kPerTensorAffine;
44     case c10::kPerChannelAffine:
45     case c10::kPerChannelSymmetric:
46       return c10::kPerChannelAffine;
47     default:
48       return qscheme;
49   }
50 }
51 
isPerChannel(at::QScheme qscheme)52 bool isPerChannel(at::QScheme qscheme) {
53   return qscheme == c10::kPerChannelAffine ||
54       qscheme == c10::kPerChannelSymmetric;
55 }
56 
57 // Go through the CallMethod graph to check if the value is Weight.
isWeight(Module & module,Value * v)58 bool isWeight(Module& module, Value* v) {
59   if (isWeight(v)) {
60     return true;
61   }
62   std::optional<bool> result;
63   auto* self = v->owningGraph()->inputs()[0];
64   for (const Use& u : v->uses()) {
65     Node* n = u.user;
66     if (n->kind() == prim::CallMethod) {
67       auto m_opt = getInvokedModuleOpt(module, n, self);
68       if (!m_opt.has_value()) {
69         return false;
70       }
71       auto m = *m_opt;
72       auto g = m.get_method(n->s(attr::name)).graph();
73       auto call_method_result = isWeight(m, g->inputs()[u.offset]);
74       if (result.has_value()) {
75         // Check to make sure all the CallMethods in the graph produce the same
76         // output.
77         TORCH_CHECK(
78             call_method_result == result.value(),
79             "Expected all CallMethods to use either weight "
80             "or non-weight value.",
81             v->debugName());
82       } else {
83         result = call_method_result;
84       }
85     }
86   }
87   return result.has_value() ? result.value() : false;
88 }
89 
insertChooseQParams(Graph * graph,Value * original_val)90 Node* insertChooseQParams(Graph* graph, Value* original_val) {
91   std::string choose_qparams_func = "_choose_qparams_per_tensor";
92   // Set the reduce range to default to true, since qnnpack backend ignores this
93   // argument.
94   bool reduce_range_param = true;
95   auto reduce_range = graph->insertConstant(reduce_range_param);
96   // choose_qparams_per_tensor has 2 outputs, (scale, zero_point).
97   Node* choose_qparams = graph->create(
98       at::Symbol::aten(choose_qparams_func),
99       {original_val, reduce_range},
100       /* num_outputs = */ 2);
101   choose_qparams->output(0)->setDebugName(original_val->debugName() + ".scale");
102   choose_qparams->output(0)->setType(FloatType::get());
103   choose_qparams->output(1)->setDebugName(
104       original_val->debugName() + ".zero_point");
105   choose_qparams->output(1)->setType(IntType::get());
106   graph->insertNode(choose_qparams);
107   return choose_qparams;
108 }
109 
insertQuant(Graph * graph,const std::vector<Value * > & inputs,NodeKind quant_kind,const std::string & debugName)110 Node* insertQuant(
111     Graph* graph,
112     const std::vector<Value*>& inputs,
113     NodeKind quant_kind,
114     const std::string& debugName) {
115   Node* quant = graph->create(quant_kind, inputs);
116   quant->output()->setDebugName(debugName);
117   graph->insertNode(quant);
118   return quant;
119 }
120 
insertDeQuant(Graph * graph,Value * quantized_val,Value * original_val,size_t id=0)121 Node* insertDeQuant(
122     Graph* graph,
123     Value* quantized_val,
124     Value* original_val,
125     size_t id = 0) {
126   Node* dequant = graph->create(Symbol::aten("dequantize"), {quantized_val});
127   dequant->output()
128       ->setDebugName(
129           original_val->debugName() + ".dequant." + std::to_string(id))
130       ->setType(original_val->type());
131   graph->insertNode(dequant);
132   return dequant;
133 }
134 
insertDeQuantForAllUse(Graph * graph,Value * quantized_val,Value * original_val)135 std::vector<Value*> insertDeQuantForAllUse(
136     Graph* graph,
137     Value* quantized_val,
138     Value* original_val) {
139   // copy uses to vector since value->uses() is a reference
140   // and changing the graph will also change the uses() list
141   const std::vector<Use> uses = original_val->uses();
142   std::vector<Value*> outputs;
143   for (const auto i : c10::irange(uses.size())) {
144     auto* user = uses[i].user;
145     // Insert dequantize node right before use node, because
146     // we want to make sure use node and dequantize node reside
147     // in the same block so that quant fusion can happen
148     WithInsertPoint ins(user);
149     Node* dequant = insertDeQuant(graph, quantized_val, original_val, i);
150     user->replaceInput(uses[i].offset, dequant->output());
151     outputs.push_back(dequant->output());
152   }
153   return outputs;
154 }
155 
insertQParam(Graph * graph,Value * quantized_input,NodeKind node_kind,const TypePtr & output_type,const std::string & param_name)156 Node* insertQParam(
157     Graph* graph,
158     Value* quantized_input,
159     NodeKind node_kind,
160     const TypePtr& output_type,
161     const std::string& param_name) {
162   Node* qparam = graph->create(node_kind, {quantized_input});
163   qparam->output()
164       ->setDebugName(quantized_input->debugName() + "." + param_name)
165       ->setType(output_type);
166   graph->insertNode(qparam);
167   return qparam;
168 }
169 
insertScalarToTensor(Graph * graph,Value * scalar_value)170 Node* insertScalarToTensor(Graph* graph, Value* scalar_value) {
171   Node* n = scalar_value->node();
172   WithInsertPoint ins(n->next());
173   Value* float_scalar_type = graph->insertConstant(IValue(c10::kFloat));
174   Value* none = graph->insertConstant(IValue());
175   Node* tensor_node = graph->create(
176       Symbol::aten("scalar_tensor"),
177       {scalar_value, float_scalar_type, none, none, none});
178   Value* tensor_output = tensor_node->output();
179   tensor_output->setDebugName(scalar_value->debugName() + ".tensor");
180   graph->insertNode(tensor_node);
181   // replace original_output with tensor
182   scalar_value->replaceAllUsesAfterNodeWith(tensor_node, tensor_output);
183   return tensor_node;
184 }
185 
insertItem(Graph * graph,Value * tensor,const TypePtr & output_type)186 Node* insertItem(Graph* graph, Value* tensor, const TypePtr& output_type) {
187   WithInsertPoint ins(tensor->node()->next());
188   Node* n = graph->create(Symbol::aten("item"), {tensor});
189   Value* scalar = n->output();
190   scalar->setDebugName(tensor->debugName() + ".scalar")->setType(output_type);
191   graph->insertNode(n);
192   return n;
193 }
194 
insertChooseQParamQuantDequant(Graph * graph,Value * original_val,Value * dtype,NodeKind quant_kind)195 DynamicQuantOps insertChooseQParamQuantDequant(
196     Graph* graph,
197     Value* original_val,
198     Value* dtype,
199     NodeKind quant_kind) {
200   Node* choose_qparams = insertChooseQParams(graph, original_val);
201   std::vector<Value*> quant_inputs = {original_val};
202   for (auto& out : choose_qparams->outputs()) {
203     quant_inputs.push_back(out);
204   }
205   quant_inputs.push_back(dtype);
206   Node* quant = insertQuant(
207       graph, quant_inputs, quant_kind, original_val->debugName() + ".quant");
208   Node* dequant = insertDeQuant(graph, quant->output(), original_val);
209   return std::make_tuple(choose_qparams, quant, dequant);
210 }
211 
insertFP16CastOps(Graph * graph,Value * observer_out)212 Node* insertFP16CastOps(Graph* graph, Value* observer_out) {
213   // If the weight value is outside of the range for FP16 range, i.e. [5.96e-8,
214   // 65504], we saturate the values to the min/max of this range.
215   Node* saturated_weight =
216       graph->create(Symbol::aten("_saturate_weight_to_fp16"), {observer_out});
217   graph->insertNode(saturated_weight);
218   graph->lint();
219 
220   return saturated_weight;
221 }
222 
223 // find the observer for Value `v` and return the name of the observer
findObserverName(Value * v)224 std::optional<std::string> findObserverName(Value* v) {
225   // Note that here we just check for the name of observer, but the ideally
226   // we should be comparing the type of observer, this is a temporary
227   // work around until data only clone of module.clone is supported.
228   Node* n = v->node();
229   if (n->kind() == prim::CallMethod && n->s(attr::name) == "forward") {
230     auto module_instance = n->inputs().at(0);
231     if (module_instance->node()->kind() == prim::GetAttr &&
232         module_instance->node()->s(attr::name).find("_observer_") !=
233             std::string::npos) {
234       return module_instance->node()->s(attr::name);
235     }
236   }
237   return std::nullopt;
238 }
239 
isPlaceholderObserver(Value * observer)240 bool isPlaceholderObserver(Value* observer) {
241   if (getModuleName(observer).has_value()) {
242     auto name = getModuleName(observer).value();
243     // if PlaceholderObserver is (anywhere) in name
244     if (name.find("PlaceholderObserver") != std::string::npos) {
245       return true;
246     }
247   }
248   return false;
249 }
250 
getObserverDtype(Module & module,Value * v)251 at::ScalarType getObserverDtype(Module& module, Value* v) {
252   auto observer_name = findObserverName(v);
253   if (observer_name.has_value()) {
254     auto observer_module = module.attr(observer_name.value()).toModule();
255     at::ScalarType scalar_type = observer_module.attr("dtype").toScalarType();
256     return scalar_type;
257   }
258   return at::ScalarType::Undefined;
259 }
260 
getEmbeddingBagObsName(script::Module & module,Node * n)261 std::optional<std::string> getEmbeddingBagObsName(
262     script::Module& module,
263     Node* n) {
264   Value* v = n->output();
265   auto observer = n->input(0);
266   auto observer_module = module.attr(findObserverName(v).value()).toModule();
267   if (observer_module.hasattr("custom_op")) {
268     auto op_name = observer_module.attr("custom_op").toStringRef();
269     return isPlaceholderObserver(observer) ? std::move(op_name) : "";
270   }
271   return std::nullopt;
272 }
273 
isEmbeddingBagOp(Node * observer,std::optional<std::string> embedding_bag_name)274 bool isEmbeddingBagOp(
275     Node* observer,
276     std::optional<std::string> embedding_bag_name) {
277   return embedding_bag_name &&
278       embedding_bag_name.value().find("embedding_bag_") != std::string::npos;
279 }
280 
281 template <typename T>
282 Node* insertQuantDequantNodes(
283     Value* self,
284     Node* observer,
285     T& qparams,
286     const std::string& quantize_func);
287 
288 // Insert quant and dequant nodes into the graph for both static and dynamic
289 // quant.
290 template <>
insertQuantDequantNodes(Value * self,Node * observer,std::vector<std::string> & qparam_names,const std::string & quantize_func)291 Node* insertQuantDequantNodes<std::vector<std::string>>(
292     Value* self,
293     Node* observer,
294     std::vector<std::string>& qparam_names,
295     const std::string& quantize_func) {
296   Graph* g = observer->owningGraph();
297   Value* observer_out = observer->output();
298   Value* original_val = observer->input(1);
299   std::vector<Value*> inputs = {observer_out};
300   // Insert GetAttr nodes for quantization parameters
301   for (const auto& qparam_name : qparam_names) {
302     inputs.push_back(g->insertGetAttr(self, qparam_name));
303   }
304   Node* quant = insertQuant(
305       g,
306       inputs,
307       at::Symbol::aten(quantize_func),
308       original_val->debugName() + ".quant");
309   Node* dequant = insertDeQuant(g, quant->output(), original_val);
310   return dequant;
311 }
312 
insertEmbeddingBagOps(Node * observer,const std::string & op_name)313 Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) {
314   Graph* g = observer->owningGraph();
315   auto observer_out = observer->output();
316 
317   std::string prepack_fn, quant_fn;
318   std::vector<Value*> prepack_inputs = {observer_out};
319   if (op_name == "embedding_bag_4bit") {
320     bool optimized_qparams = false;
321     constexpr int NBINS = 200;
322     constexpr float RATIO = 0.16;
323     Value* optimized_qparams_false = g->insertConstant(optimized_qparams);
324     Value* nbins_200 = g->insertConstant(NBINS);
325     Value* ratio_0_16 = g->insertConstant(RATIO);
326     prepack_fn = "quantized::embedding_bag_4bit_prepack";
327     quant_fn = "quantized::embedding_bag_4bit_rowwise_offsets";
328     prepack_inputs.push_back(optimized_qparams_false);
329     prepack_inputs.push_back(nbins_200);
330     prepack_inputs.push_back(ratio_0_16);
331   } else if (op_name == "embedding_bag_byte") {
332     prepack_fn = "quantized::embedding_bag_byte_prepack";
333     quant_fn = "quantized::embedding_bag_byte_rowwise_offsets";
334   } else {
335     TORCH_INTERNAL_ASSERT(
336         false,
337         "Graph Mode Quantization currently supports 4-bit and 8-bit embedding bag quantization.");
338   }
339 
340   std::vector<Use> uses = observer_out->uses();
341   Node* embedding_bag_float_op = nullptr;
342   // We expect that the output of the weight observer will be consumed by the
343   // embedding_bag operator.
344   for (const Use& use : uses) {
345     if (matchCallFuncToUse(use, "embedding_bag", 2) ||
346         matchAtenFuncToUse(use, "embedding_bag", 0)) {
347       embedding_bag_float_op = use.user;
348     }
349   }
350 
351   // Insert prepack op
352   Node* prepack = g->create(Symbol::fromQualString(prepack_fn), prepack_inputs);
353   g->insertNode(prepack);
354 
355   std::vector<Value*> embedding_bag_inputs =
356       embedding_bag_float_op->inputs().vec();
357   std::vector<Value*> qembedding_bag_inputs = {prepack->output()};
358   const auto inputs_size = embedding_bag_float_op->inputs().size();
359   const bool is_aten_op =
360       embedding_bag_float_op->kind() == Symbol::aten("embedding_bag");
361   // Create and insert quantized embedding op.
362   Value* none = g->insertConstant(IValue());
363   Value* zero = g->insertConstant(IValue(0));
364   bool pruned_wt = false;
365   auto pruned_const = g->insertConstant(pruned_wt);
366 
367   if (is_aten_op) {
368     TORCH_CHECK(
369         inputs_size == 9,
370         "Expecting FP aten::embedding_bag operator to have 9 inputs");
371     // input 0 is the output of prepack op.
372     // Last input is added after we account for extra input in 4-bit case.
373     for (unsigned long i = 1; i < inputs_size - 2; ++i) {
374       qembedding_bag_inputs.push_back(embedding_bag_inputs[i]);
375     }
376     // The sparse field in the float operator denotes sparse gradients.
377     // For inference this stands for pruned weights. We currently don't support
378     // pruning in graph mode API so we set the field to 0 for inference.
379     qembedding_bag_inputs[5] = pruned_const;
380   } else {
381     TORCH_CHECK(
382         inputs_size == 12,
383         "Expecting F.embedding_bag operator to have 12 inputs");
384     qembedding_bag_inputs.push_back(embedding_bag_inputs[1]); // indices
385     qembedding_bag_inputs.push_back(embedding_bag_inputs[3]); // offsets
386     qembedding_bag_inputs.push_back(
387         embedding_bag_inputs[6]); // scale_grad_by_freq
388     qembedding_bag_inputs.push_back(zero); // mode
389     qembedding_bag_inputs.push_back(pruned_const); // pruned_weights
390     qembedding_bag_inputs.push_back(
391         embedding_bag_inputs[9]); // per_sample_weights
392   }
393 
394   qembedding_bag_inputs.push_back(none); // compressed_indices_mapping
395   qembedding_bag_inputs.push_back(embedding_bag_inputs[inputs_size - 2]);
396 
397   TORCH_CHECK(
398       embedding_bag_inputs[inputs_size - 1]->mustBeNone(),
399       "Expected aten::embedding_bag padding_idx input to be None");
400 
401   Node* qembedding_bag =
402       g->create(Symbol::fromQualString(quant_fn), qembedding_bag_inputs);
403   if (is_aten_op) {
404     WithInsertPoint ins(embedding_bag_float_op);
405     g->insertNode(qembedding_bag);
406     // Verify that the outputs (apart from index 0) have no uses in the graph.
407     for (const auto i :
408          c10::irange(1, embedding_bag_float_op->outputs().size())) {
409       TORCH_CHECK(
410           !embedding_bag_float_op->output(i)->hasUses(),
411           "Expected aten::embedding_bag to only have use for its first output.");
412     }
413   } else {
414     g->insertNode(qembedding_bag);
415   }
416   embedding_bag_float_op->output(0)->replaceAllUsesWith(
417       qembedding_bag->output());
418   embedding_bag_float_op->removeAllInputs();
419   embedding_bag_float_op->destroy();
420   g->lint();
421   return qembedding_bag;
422 }
423 
424 template <typename T>
insertQuantizationOps(Module & module,Value * self,Node * observer,bool is_per_channel,T & qparams,QuantType quant_type=QuantType::STATIC)425 void insertQuantizationOps(
426     Module& module,
427     Value* self,
428     Node* observer,
429     bool is_per_channel,
430     T& qparams,
431     QuantType quant_type = QuantType::STATIC) {
432   Graph* g = observer->owningGraph();
433   // Observer output
434   Value* observer_out = observer->output();
435   // Inserting before insert point
436   WithInsertPoint ins(observer_out->node()->next());
437 
438   std::string quantize_func;
439   if (is_per_channel) {
440     quantize_func = "quantize_per_channel";
441   } else {
442     quantize_func = "quantize_per_tensor";
443   }
444   Value* original_val = observer->input(1);
445   // Temporary solution to quantize embedding_bag operators. Will be re-written
446   // once we support quantization of embedding_bag weights.
447   auto embedding_bag_name = getEmbeddingBagObsName(module, observer);
448   if (isEmbeddingBagOp(observer, embedding_bag_name)) {
449     if (isWeight(module, observer_out)) {
450       auto op_name = embedding_bag_name.value();
451       Node* dequant = insertEmbeddingBagOps(observer, op_name);
452       observer_out->replaceAllUsesWith(original_val);
453       original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output());
454     } else {
455       // Special case for embedding bag operators indices input - we don't
456       // quantize the input but we still need to insert observers for it because
457       // the order of input and weight can be changed in the module code.
458       observer_out->replaceAllUsesWith(original_val);
459     }
460     return;
461   }
462   Node* dequant = nullptr;
463   if (quant_type == QuantType::DYNAMIC) {
464     if (getObserverDtype(module, observer_out) == at::ScalarType::Half) {
465       dequant = insertFP16CastOps(g, observer_out);
466     } else if (!isWeight(module, observer_out)) {
467       auto observer_dtype = getObserverDtype(module, observer_out);
468       if (observer_dtype == at::ScalarType::QUInt8 ||
469           observer_dtype == at::ScalarType::QInt8) {
470         // For activation tensors we insert choose_qparams, quant, dequant ops.
471         Value* dtype = g->insertGetAttr(self, qparams.back());
472         dequant = std::get<2>(insertChooseQParamQuantDequant(
473             g, observer_out, dtype, at::Symbol::aten(quantize_func)));
474       } else {
475         // dtype does not require quantization, e.g. float32
476         // will just remove the observer call
477         observer_out->replaceAllUsesWith(original_val);
478         return;
479       }
480     } else {
481       // For weight tensors we insert quant-dequant ops.
482       dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func);
483     }
484   } else { // Static quant
485     dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func);
486   }
487   observer_out->replaceAllUsesWith(original_val);
488 
489   original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output());
490   GRAPH_DUMP("insert nodes:", original_val->owningGraph());
491 }
492 
ReplicateChooseQParamsQuantDequant(std::shared_ptr<Graph> & graph)493 void ReplicateChooseQParamsQuantDequant(std::shared_ptr<Graph>& graph) {
494   const PatternInfo& dynamic_quant_pattern = PatternInfo::parse_from_str(R"(
495     graph(%a, %reduce_range, %a_dtype):
496         %a_scale : float, %a_zero_point : int = aten::_choose_qparams_per_tensor(%a, %reduce_range)
497         %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
498         %a_dequant = aten::dequantize(%a_quant)
499         return (%a_dequant) )");
500   const Graph& dynamic_quant_graph = *dynamic_quant_pattern.pattern_graph;
501 
502   const auto& matches = findPatternMatches(dynamic_quant_graph, *graph);
503   if (matches.empty()) {
504     return;
505   }
506 
507   const auto& vmap = dynamic_quant_pattern.vmap;
508   Value* dequant_val = vmap.at("a_dequant");
509   Node* pattern_dequant = dequant_val->node();
510   Value* quant_val = vmap.at("a_quant");
511   Node* pattern_quant = quant_val->node();
512   Value* choose_qparam_val = vmap.at("a_scale");
513   Node* pattern_choose_qparam = choose_qparam_val->node();
514 
515   std::vector<DynamicQuantOps> nodes_to_rewrite;
516   std::vector<Node*> choose_qparam_nodes_to_rewrite;
517   for (const Match& match : matches) {
518     Node* matched_dequantize = match.nodes_map.at(pattern_dequant);
519     Node* matched_quantize = match.nodes_map.at(pattern_quant);
520     Node* matched_choose_qparam = match.nodes_map.at(pattern_choose_qparam);
521     if (matched_dequantize->output()->uses().size() > 1) {
522       nodes_to_rewrite.emplace_back(
523           matched_choose_qparam, matched_quantize, matched_dequantize);
524     }
525   }
526   for (const auto& nodes : nodes_to_rewrite) {
527     auto quant_node = std::get<1>(nodes);
528     auto dequant_node = std::get<2>(nodes);
529     // get input of quantize call.
530     Value* original_val = quant_node->inputs()[0];
531     Value* dequant_out = dequant_node->output();
532     Value* dtype = quant_node->inputs()[3];
533     std::vector<Use> uses = dequant_out->uses();
534     for (const Use& use : uses) {
535       auto* user = use.user;
536       WithInsertPoint ins(user);
537       auto quant_ops = insertChooseQParamQuantDequant(
538           graph.get(), original_val, dtype, quant_node->kind());
539       user->replaceInputWith(dequant_out, std::get<2>(quant_ops)->output());
540     }
541   }
542   for (const auto& n : nodes_to_rewrite) {
543     auto [choose_qparams, quant, dequant] = n;
544     dequant->removeAllInputs();
545     quant->removeAllInputs();
546     choose_qparams->removeAllInputs();
547   }
548   for (const auto& n : nodes_to_rewrite) {
549     auto [choose_qparams, quant, dequant] = n;
550     dequant->destroy();
551     quant->destroy();
552     choose_qparams->destroy();
553   }
554 }
555 
RemoveRedundantDequantize(std::shared_ptr<Graph> & graph)556 void RemoveRedundantDequantize(std::shared_ptr<Graph>& graph) {
557   const std::string dequantize = R"(
558     graph(%a_quant):
559         %a_dequant = aten::dequantize(%a_quant)
560         return (%a_dequant) )";
561   const std::string dequantize_replacement = R"(
562     graph(%a):
563         return (%a) )";
564   auto filter = [&](const Match& match,
565                     const std::unordered_map<std::string, Value*>& vmap) {
566     const auto& match_vmap = match.values_map;
567     auto dequant_node = match_vmap.at(vmap.at("a_dequant"))->node();
568     Value* dequant_out = dequant_node->output();
569     // Values can be used multiple times in a single node
570     if (dequant_out->uses().size() != 1) {
571       return false;
572     }
573     Node* user = dequant_out->uses()[0].user;
574     return isTensorInfoNode(user);
575   };
576   SubgraphRewriter rewriter;
577   rewriter.RegisterRewritePattern(dequantize, dequantize_replacement);
578   rewriter.runOnGraph(graph, filter);
579 }
580 
RemoveRedundantQuantizationOps(std::shared_ptr<Graph> & graph)581 void RemoveRedundantQuantizationOps(std::shared_ptr<Graph>& graph) {
582   const std::string dynamic_quant_ops = R"(
583     graph(%a, %reduce_range, %a_dtype):
584         %a_scale : float, %a_zero_point : int = aten::_choose_qparams_per_tensor(%a, %reduce_range)
585         %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
586         %a_dequant = aten::dequantize(%a_quant)
587         return (%a_dequant) )";
588   const std::string dynamic_quant_replacement = R"(
589     graph(%a, %reduce_range, %a_dtype):
590         return (%a) )";
591   auto filter = [&](const Match& match,
592                     const std::unordered_map<std::string, Value*>& vmap) {
593     const auto& match_vmap = match.values_map;
594     auto dequant_node = match_vmap.at(vmap.at("a_dequant"))->node();
595     Value* dequant_out = dequant_node->output();
596     // Values can be used multiple times in a single node
597     if (dequant_out->uses().size() != 1) {
598       return false;
599     }
600     Node* user = dequant_out->uses()[0].user;
601     return !nodeQuantizable(user, QuantType::DYNAMIC);
602   };
603   SubgraphRewriter rewriter;
604   rewriter.RegisterRewritePattern(dynamic_quant_ops, dynamic_quant_replacement);
605   rewriter.runOnGraph(graph, filter);
606 }
607 
ReplicateClampScalarArgs(std::shared_ptr<Graph> & graph)608 void ReplicateClampScalarArgs(std::shared_ptr<Graph>& graph) {
609   std::stack<Block*> blocks_to_visit;
610   std::unordered_set<Node*> scalar_nodes_to_rewrite;
611   ;
612   blocks_to_visit.push(graph->block());
613   while (!blocks_to_visit.empty()) {
614     Block* b = blocks_to_visit.top();
615     blocks_to_visit.pop();
616     for (Node* n : b->nodes()) {
617       for (Value* output : n->outputs()) {
618         if (getClampScalarInputUse(output) && output->uses().size() > 1) {
619           scalar_nodes_to_rewrite.insert(n);
620         }
621       }
622       for (Block* subblock : n->blocks()) {
623         blocks_to_visit.push(subblock);
624       }
625     }
626   }
627 
628   for (Node* n : scalar_nodes_to_rewrite) {
629     const std::vector<Use> uses = n->output()->uses();
630     for (const auto& use : uses) {
631       Node* user = use.user;
632       WithInsertPoint ins(user);
633       Node* cloned_node = graph->createClone(n, [](Value* v) { return v; });
634       graph->insertNode(cloned_node);
635       user->replaceInput(use.offset, cloned_node->output());
636     }
637   }
638 
639   for (Node* n : scalar_nodes_to_rewrite) {
640     n->removeAllInputs();
641   }
642 
643   for (Node* n : scalar_nodes_to_rewrite) {
644     n->destroy();
645   }
646 }
647 
checkCalculateQParamsResult(const IValue & qparams)648 void checkCalculateQParamsResult(const IValue& qparams) {
649   TORCH_CHECK(
650       qparams.isTuple(),
651       "`calculate_qparams` function is expected to return a "
652       "Tuple, but got:",
653       qparams.tagKind());
654   auto tp = qparams.toTuple();
655   TORCH_CHECK(
656       tp->elements().size() == 2,
657       "`calculate_qparams` function is expected to return a "
658       "Tuple of size 2, got Tuple of size ",
659       tp->elements().size());
660   // Expect first two elements of the tuple to be Tensor
661   for (const auto i : c10::irange(2)) {
662     TORCH_CHECK(
663         tp->elements()[i].isTensor(),
664         "Element of Tuple is expected to be Tensor, but element ",
665         i,
666         " has type: ",
667         tp->elements()[i].tagKind());
668   }
669 }
670 
671 class SubGraphCloneHelper {
672  public:
673   // Given a list of nodes, build a graph corresponding to these nodes.
674   // User should make sure to run this graph with expected input.
675   std::unique_ptr<GraphFunction> buildGraphFromNodes(
676       const std::vector<Node*>& nodes,
677       const std::string& name);
678 
679   // Given a list of nodes in src, produce a Graph with these nodes.
680   void buildObserverSubgraph(
681       const std::vector<Node*>& src,
682       std::shared_ptr<Graph> dest);
683 
684  private:
685   // Clone node in the destination Graph g.
686   void cloneNodeInGraph(
687       Node* node,
688       std::shared_ptr<Graph>& g,
689       std::unordered_map<Value*, Value*>& remap_values);
690 };
691 
692 class InsertQuantDeQuantHelper {
693  public:
InsertQuantDeQuantHelper(QuantType quant_type,bool debug)694   InsertQuantDeQuantHelper(QuantType quant_type, bool debug)
695       : quant_type_(quant_type), debug_(debug) {}
696 
697   void run(Module& module, const std::string& method_name);
698 
699   void runForOnDevicePTQ(Module& module, const std::string& method_name);
700 
701   // Cleanup observer nodes from graph and observer modules
702   // from module object and ClassType
703   void cleanup(Module& module);
704 
705   // Cleanup observer nodes only but not modules
706   // This is for ondevice PTQ
707   void removeObserverNodes(Module& m);
708 
709   // In order to propagate quantization ops through the ops that doesn't
710   // require observation, we'll first inline the graph, and call the
711   // PropagateQuantizationOps pass
712   void propagateQuantizationOps(Module& module);
713 
714   // Used for dynamic quantization to selectively run the weight observers.
715   // It extracts the subgraph corresponding to the weight and runs it with
716   // the module instance.
717   void runWeightObserver(Module& module, const std::string& method_name);
718 
719  private:
720   ModuleMethodVector getInvokedMethods(
721       Module& module,
722       const std::string& method_name);
723 
724   // Get quantization parameter map of the given Value in Graph
725   // by searching for observer module of the value and extract the
726   // quantization parameters from the observer module
727   std::tuple<c10::QScheme, QParamVector> getQSchemeAndQParamVector(
728       script::Module& module,
729       Node* n);
730   QuantOpParams insertCalculateQParams(
731       script::Module& module,
732       Graph* g,
733       Node* n);
734 
checkQScheme(Graph * g,c10::QScheme qscheme)735   void checkQScheme(Graph* g, c10::QScheme qscheme) {
736     if (qscheme_for_graph_.count(g)) {
737       // FIXME[T110786721]: This check was broken before nevery failing.
738       // Once fixed, this check triggers and fails tests.
739       // Fix the tests that enabling this check produce!
740       /*
741       TORCH_CHECK(
742           qscheme_for_graph_.at(g) == qscheme,
743           "Quantizing same graph with different types of "
744           "QSchemes is not supported.\n",
745           " Expecting:",
746           c10::toString(qscheme_for_graph_.at(g)),
747           " Got:",
748           c10::toString(qscheme));
749       */
750     } else {
751       qscheme_for_graph_[g] = toAffine(qscheme);
752     }
753   }
754 
755   void collectObserverNodesAndValueToQuantize(Module& module, Value*);
756   void cleanup(Module& module, Graph* g);
757   void removeObserverNodes(Graph* g);
758   void quantizeTensors(Module& module, Graph* g, Value* self);
759   void insertCalculateQParamsAndQuantizationOps(
760       Module& module,
761       Graph* g,
762       Value* self);
763 
764   // Function that extracts and runs the weight observer in a separate
765   // subgraph.
766   void extractAndRunWeightObserver(
767       Module& module,
768       Value* self,
769       Value* weight_value);
770 
771   // Recursively find the nodes that produce the value and add to subgraph.
772   void findSubgraph(Value* self, Value* v, std::vector<Node*>& weight_subgraph);
773 
774   // Quantizes two types of general ops (ops that work for both floating point
775   // and quantized Tensors) in this pass:
776   // for ops that only manipulate shape, e.g. flatten, quantization
777   // is done by swapping with the previous dequantize op;
778   // for ops that manipulate values of Tensor, e.g. average pool, quantization
779   // is done by inserting quant/dequant ops after the op
780   // also has a special handling of clamp/hardtanh
781   void propagateQuantizationOps(Block* block);
782 
783   // Propagate quantization parameters from other quantized tensors
784   void propagateQParams(
785       Value* original_output,
786       const std::vector<Value*>& inputs,
787       bool is_scalar = false,
788       const std::optional<std::tuple<c10::QScheme, QParamVector>>& qparams_opt =
789           std::nullopt);
790 
isQuantized(Value * v)791   bool isQuantized(Value* v) {
792     return quantized_values_.count(v) != 0;
793   }
794 
795   std::unordered_map<Graph*, std::vector<std::string>>
796       observer_modules_to_remove_;
797   // We only remove observer module attributes from type in the
798   // first encounter of the graph, after that since the attributes
799   // are already removed from the ClassType, we'll use the list of slot indices to
800   // replay this removal
801   std::unordered_map<Graph*, std::vector<int>> removed_observer_slots_;
802   std::unordered_map<Graph*, std::vector<Node*>> nodes_to_destroy_;
803   // Map from Graph to observer node, we can use observer node to
804   // get the information of original value that's been observed and
805   // the quantization parameters
806   std::unordered_map<Graph*, std::vector<Node*>> observer_nodes_for_graph_;
807   // A map from qparam name (e.g. _scale) to the attribute name in
808   // the module(e.g. weight_scale_0)
809   std::unordered_map<Node*, std::unordered_map<std::string, std::string>>
810       qparam_name_map_for_node_;
811   // Record qscheme for every graph, this is for checking
812   // each graph is only quantized with one type of QScheme
813   std::unordered_map<Graph*, c10::QScheme> qscheme_for_graph_;
814 
815   // Set of quantized values, so that we quantize each value only
816   // once
817   std::unordered_set<Value*> quantized_values_;
818 
819   // Map from original weight value to GraphFunction corresponding to the
820   // subgraph that includes the weight observer and dependent nodes.
821   std::unordered_map<Value*, std::unique_ptr<GraphFunction>>
822       weight_to_graph_fn_;
823 
824   QuantType quant_type_ = QuantType::STATIC;
825   bool debug_ = false;
826 };
827 
collectObserverNodesAndValueToQuantize(Module & module,Value * v)828 void InsertQuantDeQuantHelper::collectObserverNodesAndValueToQuantize(
829     Module& module,
830     Value* v) {
831   auto* g = v->owningGraph();
832   auto observer_name = findObserverName(v);
833   if (!observer_name) {
834     return;
835   }
836   observer_modules_to_remove_[g].push_back(observer_name.value());
837 
838   Node* observer = v->node();
839   TORCH_INTERNAL_ASSERT(
840       observer->kind() == prim::CallMethod &&
841       observer->s(attr::name) == "forward" &&
842       observer->inputs()[0]->node()->kind() == prim::GetAttr &&
843       observer->inputs()[0]->node()->s(attr::name) == observer_name);
844 
845   // Observer forward call node
846   nodes_to_destroy_[g].push_back(observer);
847   // GetAttr node for observer module
848   nodes_to_destroy_[g].push_back(observer->inputs()[0]->node());
849   observer_nodes_for_graph_[g].push_back(observer);
850 }
851 
removeObserverNodes(Module & module)852 void InsertQuantDeQuantHelper::removeObserverNodes(Module& module) {
853   for (auto& method : module.get_methods()) {
854     removeObserverNodes(method.graph().get());
855   }
856   for (Module m : module.children()) {
857     removeObserverNodes(m);
858   }
859 }
860 
removeObserverNodes(Graph * g)861 void InsertQuantDeQuantHelper::removeObserverNodes(Graph* g) {
862   if (nodes_to_destroy_.count(g)) {
863     for (auto& n : nodes_to_destroy_.at(g)) {
864       n->removeAllInputs();
865     }
866     for (auto& n : nodes_to_destroy_.at(g)) {
867       n->destroy();
868     }
869     nodes_to_destroy_.at(g).clear();
870   }
871 }
872 
cleanup(Module & module)873 void InsertQuantDeQuantHelper::cleanup(Module& module) {
874   for (auto& method : module.get_methods()) {
875     cleanup(module, method.graph().get());
876   }
877   for (Module m : module.children()) {
878     cleanup(m);
879   }
880 }
881 
cleanup(Module & module,Graph * g)882 void InsertQuantDeQuantHelper::cleanup(Module& module, Graph* g) {
883   GRAPH_DUMP("Before Remove Observers:", g);
884   removeObserverNodes(g);
885 
886   // 1. If we have seen this graph before, this means the observer
887   // attributes has been removed from the type(see step 2) but the slot
888   // index of these attributes are kept in the list, we'll replay the observer
889   // slots removal using these slot indexes
890   if (removed_observer_slots_.count(g)) {
891     for (auto slot : removed_observer_slots_.at(g)) {
892       module._ivalue()->unsafeRemoveSlot(slot);
893     }
894   }
895 
896   // 2. Remove observer modules from last one to first one in order to
897   // reduce the time complexity, assuming all the observer modules
898   // are added after the existing modules, we'll have complexity of
899   // O(N) where N is number of observer modules with this optimization
900   if (observer_modules_to_remove_.count(g)) {
901     auto& observers = observer_modules_to_remove_.at(g);
902     for (int64_t i = observers.size() - 1; i >= 0; --i) {
903       auto observer_name = observers[i];
904       GRAPH_DEBUG("Trying to remove: ", observer_name);
905       if (module.type()->hasAttribute(observer_name)) {
906         // We record the slot index here in order to replay the
907         // slot removal in other objects that's sharing the ClassType
908         // since we're going to remove attribute in the ClassType here
909         removed_observer_slots_[g].push_back(
910             module.type()->getAttributeSlot(observer_name));
911         module._ivalue()->unsafeRemoveAttr(observer_name);
912         module.type()->unsafeRemoveAttribute(observer_name);
913       }
914     }
915     observers.clear();
916   }
917   GRAPH_DUMP("After remove observers :", g);
918 }
919 
cloneNodeInGraph(Node * node,std::shared_ptr<Graph> & g,std::unordered_map<Value *,Value * > & remap_old_to_new)920 void SubGraphCloneHelper::cloneNodeInGraph(
921     Node* node,
922     std::shared_ptr<Graph>& g,
923     std::unordered_map<Value*, Value*>& remap_old_to_new) {
924   auto* block = g->block();
925   auto value_fn = [&](Value* v) {
926     if (remap_old_to_new.count(v) == 0) {
927       auto new_value = g->block()->addInput();
928       remap_old_to_new[v] = new_value;
929       new_value->copyMetadata(v);
930       return new_value;
931     } else {
932       return remap_old_to_new[v];
933     }
934   };
935 
936   auto new_node = block->appendNode(g->createClone(node, value_fn));
937   for (size_t i = 0; i < node->outputs().size(); ++i) {
938     auto oo = node->outputs()[i];
939     auto no = new_node->outputs()[i];
940     remap_old_to_new[oo] = no;
941   }
942 }
943 
buildObserverSubgraph(const std::vector<Node * > & weight_subgraph,std::shared_ptr<Graph> dest_graph)944 void SubGraphCloneHelper::buildObserverSubgraph(
945     const std::vector<Node*>& weight_subgraph,
946     std::shared_ptr<Graph> dest_graph) {
947   std::unordered_map<Value*, Value*> remap_old_to_new;
948   // Build weight subgraph
949   for (auto n : weight_subgraph) {
950     cloneNodeInGraph(n, dest_graph, remap_old_to_new);
951   }
952   LintGraph(dest_graph);
953 
954   // Add last node output value as subgraph output.
955   for (auto out : weight_subgraph.back()->outputs()) {
956     dest_graph->registerOutput(remap_old_to_new[out]);
957   }
958   GRAPH_DUMP("New weight observer subgraph: ", dest_graph);
959 }
960 
buildGraphFromNodes(const std::vector<Node * > & nodes,const std::string & name)961 std::unique_ptr<GraphFunction> SubGraphCloneHelper::buildGraphFromNodes(
962     const std::vector<Node*>& nodes,
963     const std::string& name) {
964   auto observer_subgraph = std::make_shared<Graph>();
965   auto build_observer_graph = [&](GraphFunction& func) {
966     buildObserverSubgraph(nodes, func.graph());
967   };
968   return std::make_unique<GraphFunction>(
969       name, observer_subgraph, build_observer_graph);
970 }
971 
findSubgraph(Value * self,Value * input_val,std::vector<Node * > & weight_subgraph)972 void InsertQuantDeQuantHelper::findSubgraph(
973     Value* self,
974     Value* input_val,
975     std::vector<Node*>& weight_subgraph) {
976   Node* node = input_val->node();
977   weight_subgraph.push_back(node);
978   const auto& inputs = node->inputs().vec();
979   for (auto v : inputs) {
980     if (!hitGraphInput(v)) {
981       findSubgraph(self, v, weight_subgraph);
982     } else {
983       TORCH_CHECK(
984           v == self,
985           "Unexpected value found when handling weight value "
986           " in findSubgraph, traced back to:",
987           v->debugName(),
988           " which is not self:",
989           self->debugName());
990     }
991   }
992 }
993 
extractAndRunWeightObserver(Module & module,Value * self,Value * weight_value)994 void InsertQuantDeQuantHelper::extractAndRunWeightObserver(
995     Module& module,
996     Value* self,
997     Value* weight_value) {
998   std::vector<Node*> weight_subgraph;
999   // If the graph was already visited, return the GraphFunction directly.
1000   // Multiple module instances can share the same graph code, so we don't need
1001   // to re-run the extraction process.
1002   if (weight_to_graph_fn_.count(weight_value) == 0) {
1003     // Extract the subgraph nodes.
1004     findSubgraph(self, weight_value, weight_subgraph);
1005 
1006     // Reverse to traverse subgraph in correct direction
1007     std::reverse(weight_subgraph.begin(), weight_subgraph.end());
1008 
1009     // Build the graph using the nodes found from the weight observer.
1010     SubGraphCloneHelper o;
1011     std::unique_ptr<GraphFunction> func =
1012         o.buildGraphFromNodes(weight_subgraph, "observer_subgraph");
1013     weight_to_graph_fn_[weight_value] = std::move(func);
1014   }
1015   Stack module_inp = {module._ivalue()};
1016   // Run the graph with the module input.
1017   weight_to_graph_fn_[weight_value]->run(module_inp);
1018 }
1019 
quantizeTensors(Module & module,Graph * g,Value * self)1020 void InsertQuantDeQuantHelper::quantizeTensors(
1021     Module& module,
1022     Graph* g,
1023     Value* self) {
1024   if (!observer_nodes_for_graph_.count(g)) {
1025     return;
1026   }
1027   for (auto* n : observer_nodes_for_graph_.at(g)) {
1028     auto* original_value = n->input(1);
1029     auto tp = getQSchemeAndQParamVector(module, n);
1030     auto qscheme = std::get<0>(tp);
1031     auto qparam_map = std::get<1>(tp);
1032     checkQScheme(g, qscheme);
1033     std::vector<std::string> qparam_names;
1034     for (auto& pr : qparam_map) {
1035       const auto& name = pr.first;
1036       const auto& qparam = pr.second;
1037       size_t uid = 0;
1038       auto qparam_name =
1039           original_value->debugName() + name + "_" + std::to_string(uid++);
1040       while (module.hasattr(qparam_name)) {
1041         qparam_name =
1042             original_value->debugName() + name + "_" + std::to_string(uid++);
1043       }
1044       qparam_name_map_for_node_[n][name] = qparam_name;
1045       module.register_attribute(qparam_name, qparam.type(), qparam);
1046       qparam_names.push_back(qparam_name);
1047     }
1048     insertQuantizationOps(
1049         module, self, n, isPerChannel(qscheme), qparam_names, quant_type_);
1050   }
1051 }
1052 
1053 std::tuple<c10::QScheme, QParamVector> InsertQuantDeQuantHelper::
getQSchemeAndQParamVector(script::Module & module,Node * n)1054     getQSchemeAndQParamVector(script::Module& module, Node* n) {
1055   // TODO: refactor findObserverName to take Node* as input
1056   Value* v = n->output();
1057   TORCH_INTERNAL_ASSERT(
1058       v->type()->isSubtypeOf(*TensorType::get()),
1059       "Expected output of observer node to be Tensor");
1060   auto observer_name = findObserverName(v);
1061   TORCH_INTERNAL_ASSERT(
1062       observer_name,
1063       "getQSchemeAndParamMap expects the corresponding observer for ",
1064       v->debugName(),
1065       " exists.");
1066   QParamVector qparams;
1067   c10::QScheme qscheme = c10::kPerTensorAffine;
1068 
1069   auto observer_module = module.attr(observer_name.value()).toModule();
1070   auto scalar_type = observer_module.attr("dtype");
1071   if (isPlaceholderObserver(n->input(0))) {
1072     // get compute_dtype for dynamic quantization
1073     if (observer_module.hasattr("is_dynamic") &&
1074         observer_module.attr("is_dynamic").toBool()) {
1075       qparams.emplace_back(kScalarType, observer_module.attr("dtype"));
1076     }
1077     return std::make_tuple(qscheme, std::move(qparams));
1078   } else if (scalar_type == at::ScalarType::Half) {
1079     return std::make_tuple(qscheme, std::move(qparams));
1080   }
1081   auto calculate_qparams = observer_module.get_method("calculate_qparams");
1082   IValue result = calculate_qparams(std::vector<IValue>());
1083   checkCalculateQParamsResult(result);
1084   TORCH_CHECK(
1085       scalar_type.toScalarType() != at::ScalarType::Undefined,
1086       "dtype of observer can't be undefined");
1087   auto tp = result.toTuple();
1088   at::Tensor scale = tp->elements()[0].toTensor().to(at::kFloat);
1089   at::Tensor zero_point = tp->elements()[1].toTensor().to(at::kInt);
1090   // quantization parameters should appear in the same order as
1091   // the argument for quantize_per_tensor/quantize_per_channel function
1092 
1093   qscheme = observer_module.attr("qscheme").toQScheme();
1094   if (isPerChannel(qscheme)) {
1095     auto axis = observer_module.attr("ch_axis");
1096     qparams.emplace_back("_scale", scale);
1097     qparams.emplace_back("_zero_point", zero_point);
1098     qparams.emplace_back("_axis", axis.toInt());
1099   } else {
1100     qparams.emplace_back("_scale", scale.item<double>());
1101     qparams.emplace_back("_zero_point", zero_point.item<int64_t>());
1102   }
1103   qparams.emplace_back(kScalarType, scalar_type);
1104   return std::make_tuple(qscheme, std::move(qparams));
1105 }
1106 
getInvokedMethods(Module & module,const std::string & method_name)1107 ModuleMethodVector InsertQuantDeQuantHelper::getInvokedMethods(
1108     Module& module,
1109     const std::string& method_name) {
1110   auto graph = module.get_method(method_name).graph();
1111 
1112   ModuleMethodVector invoked_methods;
1113   std::stack<Block*> blocks_to_visit;
1114   blocks_to_visit.push(graph->block());
1115   while (!blocks_to_visit.empty()) {
1116     Block* b = blocks_to_visit.top();
1117     blocks_to_visit.pop();
1118     for (Node* n : b->nodes()) {
1119       if (n->kind() == prim::CallMethod) {
1120         auto module_instance = n->inputs()[0];
1121         auto module_method_name = n->s(attr::name);
1122         std::optional<Module> m;
1123         // calling method on self
1124         if (module_instance == graph->inputs()[0]) {
1125           m = module;
1126         } else if (
1127             module_instance->node()->kind() == prim::GetAttr &&
1128             module_instance->node()->s(attr::name).find("_observer_") ==
1129                 std::string::npos) {
1130           m = getInvokedModuleOpt(module, n, graph->inputs()[0]);
1131         }
1132         if (m) {
1133           invoked_methods.emplace_back(*m, module_method_name);
1134         }
1135       }
1136 
1137       for (Block* subblock : n->blocks()) {
1138         blocks_to_visit.push(subblock);
1139       }
1140     }
1141   }
1142   return invoked_methods;
1143 }
1144 
propagateQParams(Value * original_output,const std::vector<Value * > & inputs,bool is_scalar,const std::optional<std::tuple<c10::QScheme,QParamVector>> & qparams_opt)1145 void InsertQuantDeQuantHelper::propagateQParams(
1146     Value* original_output,
1147     const std::vector<Value*>& inputs,
1148     bool is_scalar,
1149     const std::optional<std::tuple<c10::QScheme, QParamVector>>& qparams_opt) {
1150   Node* n = original_output->node();
1151   Graph* graph = n->owningGraph();
1152   if (is_scalar) {
1153     // convert Scalar to Tensor
1154     n = insertScalarToTensor(graph, original_output);
1155     original_output = n->output();
1156   }
1157   // for ops like average pool, we'll insert quant dequant after the op
1158   // We'll assume the tensor is a PerTensorAffine quantized Tensor for
1159   // now, and may generalize later if this becomes an issue
1160   TORCH_INTERNAL_ASSERT(
1161       inputs.size() == 1, "Expecting single input for the aten function");
1162   // input of the dequantize node
1163   Value* quantized_input = inputs[0]->node()->input(0);
1164   // insert ops after the general op
1165   Node* quantized_input_node = quantized_input->node();
1166   // Insert after the node that is later in topological order
1167   WithInsertPoint ins(
1168       quantized_input_node->isAfter(n) ? quantized_input_node->next()
1169                                        : n->next());
1170   std::vector<Value*> quant_inputs;
1171   auto quant_kind = Symbol::aten("quantize_per_tensor");
1172   if (qparams_opt.has_value()) {
1173     quant_inputs = {original_output};
1174     auto qscheme = std::get<0>(*qparams_opt);
1175     auto qparams = std::get<1>(*qparams_opt);
1176     if (isPerChannel(qscheme)) {
1177       quant_kind = Symbol::aten("quantize_per_channel");
1178     }
1179     for (const auto& qparam : qparams) {
1180       Value* qparam_val = graph->insertConstant(qparam.second);
1181       qparam_val->setDebugName(quantized_input->debugName() + qparam.first);
1182       quant_inputs.push_back(qparam_val);
1183     }
1184   } else {
1185     // Only per tensor affine quantized tensor is supported in this case
1186     // get quantization parameters from previous quantized op
1187     Node* scale = insertQParam(
1188         graph,
1189         quantized_input,
1190         at::Symbol::aten("q_scale"),
1191         FloatType::get(),
1192         "q_scale");
1193     Node* zero_point = insertQParam(
1194         graph,
1195         quantized_input,
1196         at::Symbol::aten("q_zero_point"),
1197         IntType::get(),
1198         "q_zero_point");
1199     Node* dtype = insertQParam(
1200         graph, quantized_input, prim::dtype, IntType::get(), "dtype");
1201     quant_inputs = {
1202         original_output,
1203         scale->output(),
1204         zero_point->output(),
1205         dtype->output()};
1206   }
1207   Node* quant = insertQuant(
1208       graph, quant_inputs, quant_kind, original_output->debugName() + ".quant");
1209   Value* quantized_output = quant->output();
1210   // replace uses of original output of the general op with quantized
1211   // output
1212   original_output->replaceAllUsesAfterNodeWith(quant, quantized_output);
1213   const auto& outputs =
1214       insertDeQuantForAllUse(graph, quantized_output, quantized_output);
1215   for (auto* output : outputs) {
1216     if (is_scalar) {
1217       // Convert the dequantized Tensor back to Scalar
1218       Node* item = insertItem(graph, output, FloatType::get());
1219       Value* scalar = item->output();
1220       output->replaceAllUsesAfterNodeWith(item, scalar);
1221       output = scalar;
1222     }
1223     quantized_values_.insert(output);
1224   }
1225 }
1226 
removeDequantizeFromInputs(const std::unordered_set<Value * > & inputs)1227 void removeDequantizeFromInputs(const std::unordered_set<Value*>& inputs) {
1228   // Delete dequantize node, we have one dequantize
1229   // for each use of the value
1230   for (auto* dequantized_val : inputs) {
1231     auto* dequantize_node = dequantized_val->node();
1232     TORCH_INTERNAL_ASSERT(
1233         dequantized_val->uses().size() == 1,
1234         "Expect to have one dequantize node for each use");
1235     // Replace useses of dequantized_val with the input of
1236     // dequantize node
1237     dequantized_val->replaceAllUsesWith(dequantize_node->inputs()[0]);
1238     dequantize_node->removeAllInputs();
1239     dequantize_node->destroy();
1240   }
1241 }
1242 
1243 // Check if we need to propagate the quantization ops from input to
1244 // output
getDequantizedInputs(Value * output)1245 std::optional<std::vector<Value*>> getDequantizedInputs(Value* output) {
1246   auto inputs = getPassThroughInputs(output);
1247   if (!inputs.empty()) {
1248     // note that we don't need to recursively check for prim::If
1249     // here because if all inputs of a prim::If is dequantized
1250     // the dequantize will be factored out before we get to this
1251     // point
1252     bool is_dequantized = true;
1253     for (auto* input : inputs) {
1254       GRAPH_DEBUG(
1255           "checking if input:",
1256           input->debugName(),
1257           " in node:",
1258           *input->node(),
1259           "is quantized");
1260       is_dequantized &= input->node()->kind() == Symbol::aten("dequantize");
1261     }
1262     if (is_dequantized) {
1263       return inputs;
1264     }
1265   }
1266   return std::nullopt;
1267 }
1268 
propagateQuantizationOps(Block * block)1269 void InsertQuantDeQuantHelper::propagateQuantizationOps(Block* block) {
1270   for (Node* n : block->nodes()) {
1271     if (n->kind() == prim::If) {
1272       for (Block* subblock : n->blocks()) {
1273         propagateQuantizationOps(subblock);
1274       }
1275       if (n->outputs().empty()) {
1276         continue;
1277       }
1278       if (n->outputs().size() > 1) {
1279         // Factoring out dequantize for if blocks with multiple outputs
1280         // is not supported right now
1281         continue;
1282       }
1283     }
1284     if (isSingleInputGeneralValueAtenFunction(n)) {
1285       for (auto* output : n->outputs()) {
1286         if (isQuantized(output)) {
1287           continue;
1288         }
1289         if (auto inputs = getDequantizedInputs(output)) {
1290           propagateQParams(output, *inputs);
1291           if (isClamp(n)) {
1292             for (size_t i = 1; i <= 2; ++i) {
1293               // propagate qparams for min and max scalar arguments
1294               // for aten::clamp/aten::hardtanh
1295               propagateQParams(n->input(i), *inputs, /* is_scalar */ true);
1296             }
1297           }
1298         }
1299       }
1300     } else if (auto qparams_opt = getFixedQParams(n)) {
1301       for (auto* output : n->outputs()) {
1302         if (isQuantized(output)) {
1303           continue;
1304         }
1305         if (auto inputs = getDequantizedInputs(output)) {
1306           propagateQParams(output, *inputs, /* is_scalar */ false, qparams_opt);
1307         }
1308       }
1309     } else {
1310       // For ops that are quantized by propagating dequantize ops,
1311       // e.g. flatten we need to
1312       // 1. check if we need to propagate dequantize op
1313       // 2. remove the dequantize ops from inputs
1314       // 3. insert dequantize for all outputs
1315       // to make sure it works for ops with multiple outputs
1316       // since removing dequantize from inputs is mutating the graph
1317       // and it will affect future checks for whether all the inputs
1318       // has been quantized or not(since currently we just check if
1319       // the value is produced by dequantize op to decide if the value
1320       // is quantized or not
1321       // list of dequantized input values
1322       std::unordered_set<Value*> dequantized_inputs;
1323       std::vector<Value*> outputs_to_dequantize;
1324       // 1. collect dequantized inputs and outputs we need to dequantize
1325       for (auto* output : n->outputs()) {
1326         if (isQuantized(output)) {
1327           continue;
1328         }
1329         if (auto inputs = getDequantizedInputs(output)) {
1330           std::copy(
1331               inputs->begin(),
1332               inputs->end(),
1333               std::inserter(dequantized_inputs, dequantized_inputs.end()));
1334           outputs_to_dequantize.push_back(output);
1335         }
1336       }
1337       // 2. remove the dequantize ops from inputs
1338       removeDequantizeFromInputs(dequantized_inputs);
1339       // 3. insert dequantize op for outputs
1340       for (auto* output : outputs_to_dequantize) {
1341         insertDeQuantForAllUse(output->owningGraph(), output, output);
1342       }
1343     }
1344 
1345     if (isBinaryOpWithScalarInput(n)) {
1346       // Print warning for add_scalar/mul_scalar when debug is enabled
1347       // since the quantization parameter for these ops depends on
1348       // input and it's too complicated to encode the equations in
1349       // the IR:
1350       // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp#L64-L74
1351       if (debug_) {
1352         TORCH_WARN_ONCE(
1353             "debug option for add_scalar and mul_scalar is not supported, "
1354             "please don't use debug option for models that uses these ops.");
1355       }
1356     }
1357   }
1358 }
1359 
runWeightObserver(Module & module,const std::string & method_name)1360 void InsertQuantDeQuantHelper::runWeightObserver(
1361     Module& module,
1362     const std::string& method_name) {
1363   if (quant_type_ != QuantType::DYNAMIC) {
1364     return;
1365   }
1366 
1367   for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
1368     auto& invoked_module = std::get<0>(invoked_methods);
1369     const auto& invoked_method_name = std::get<1>(invoked_methods);
1370     runWeightObserver(invoked_module, invoked_method_name);
1371   }
1372   Method method = module.get_method(method_name);
1373   auto graph = method.graph();
1374   Value* self = graph->inputs()[0];
1375 
1376   std::vector<Value*> weight_values;
1377   // Visit all blocks in the current graph to find weight values.
1378   std::stack<Block*> blocks_to_visit;
1379   blocks_to_visit.push(graph->block());
1380   while (!blocks_to_visit.empty()) {
1381     Block* b = blocks_to_visit.top();
1382     blocks_to_visit.pop();
1383     for (auto n : b->nodes()) {
1384       for (Value* v : n->outputs()) {
1385         if (!v->type()->isSubtypeOf(*TensorType::get())) {
1386           continue;
1387         }
1388         auto observer_name = findObserverName(v);
1389         if (observer_name && isWeight(module, v)) {
1390           weight_values.push_back(v);
1391         }
1392       }
1393       for (Block* subblock : n->blocks()) {
1394         blocks_to_visit.push(subblock);
1395       }
1396     }
1397   }
1398   // For all the observed weight values, find the corresponding subgraph that
1399   // contributes to the weight tensor, and run that subgraph to observe the
1400   // weight.
1401   for (const auto& v : weight_values) {
1402     extractAndRunWeightObserver(module, self, v);
1403   }
1404 }
1405 
run(Module & module,const std::string & method_name)1406 void InsertQuantDeQuantHelper::run(
1407     Module& module,
1408     const std::string& method_name) {
1409   for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
1410     auto& invoked_module = std::get<0>(invoked_methods);
1411     const auto& invoked_method_name = std::get<1>(invoked_methods);
1412     run(invoked_module, invoked_method_name);
1413   }
1414 
1415   Method method = module.get_method(method_name);
1416   auto graph = method.graph();
1417   // We only need to register new parameters if the graph has
1418   // been quantized before
1419   // TODO: dedup this part with code in quantizeTensors
1420   if (observer_nodes_for_graph_.count(graph.get())) {
1421     for (auto* n : observer_nodes_for_graph_.at(graph.get())) {
1422       auto tp = getQSchemeAndQParamVector(module, n);
1423       checkQScheme(graph.get(), std::get<0>(tp));
1424       auto qparam_map = std::get<1>(tp);
1425       // We check the size here because for some observers (like
1426       // PlaceholderObserver) the qparams might be empty.
1427       if (!qparam_map.empty()) {
1428         TORCH_INTERNAL_ASSERT(
1429             qparam_name_map_for_node_.count(n),
1430             "Expected to have a qparam_name_map for node:",
1431             *n);
1432         auto qparam_name_map = qparam_name_map_for_node_.at(n);
1433         for (auto& pr : qparam_map) {
1434           const auto& name = pr.first;
1435           const auto& qparam = pr.second;
1436           module._ivalue()->setAttr(qparam_name_map.at(name), qparam);
1437         }
1438       }
1439     }
1440     return;
1441   }
1442 
1443   // prim::Param nodes do not belong to the graph. Hence the Insert
1444   // point is the beginning of graph node. This also safe guards against
1445   // observing a potentially mutated value due to some in-place operation
1446   std::vector<Value*> input_values;
1447   for (const auto idx : c10::irange(1, method.num_inputs())) {
1448     auto& v = graph->inputs()[idx];
1449     if (v->type()->isSubtypeOf(*TensorType::get())) {
1450       input_values.push_back(v);
1451     }
1452   }
1453 
1454   std::stack<Block*> blocks_to_visit;
1455   blocks_to_visit.push(graph->block());
1456   while (!blocks_to_visit.empty()) {
1457     Block* b = blocks_to_visit.top();
1458     blocks_to_visit.pop();
1459     for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
1460       Node* n = *it++;
1461       for (Value* v : n->outputs()) {
1462         if (!v->type()->isSubtypeOf(*TensorType::get())) {
1463           continue;
1464         }
1465         collectObserverNodesAndValueToQuantize(module, v);
1466       }
1467 
1468       for (Block* subblock : n->blocks()) {
1469         blocks_to_visit.push(subblock);
1470       }
1471     }
1472   }
1473 
1474   for (Value* v : input_values) {
1475     collectObserverNodesAndValueToQuantize(module, v);
1476   }
1477   GRAPH_DUMP("Before Quantize Tensors:", graph);
1478   Value* self = graph->inputs()[0];
1479   quantizeTensors(module, graph.get(), self);
1480   GRAPH_DUMP("After Quantize Tensors:", graph);
1481 }
1482 
propagateQuantizationOps(Module & module)1483 void InsertQuantDeQuantHelper::propagateQuantizationOps(Module& module) {
1484   SwapFunctionalLinear(module);
1485   auto graph = module.get_method("forward").graph();
1486   Inline(*graph);
1487   ConstantPropagation(graph);
1488   ReplicateChooseQParamsQuantDequant(graph);
1489   RemoveRedundantQuantizationOps(graph);
1490   ReplicateQuant(graph);
1491   ReplicateDeQuant(graph);
1492   // TODO: add filter to the clamp patterns and remove this pass
1493   ReplicateClampScalarArgs(graph);
1494   propagateQuantizationOps(graph->block());
1495   RemoveRedundantDequantize(graph);
1496 }
1497 
1498 // Insert quant and dequant nodes into the graph for both static and dynamic
1499 // quant.
1500 template <>
insertQuantDequantNodes(Value * self,Node * observer,QuantOpParams & qparams,const std::string & quantize_func)1501 Node* insertQuantDequantNodes<QuantOpParams>(
1502     Value* self,
1503     Node* observer,
1504     QuantOpParams& qparams,
1505     const std::string& quantize_func) {
1506   (void)self;
1507   Graph* g = observer->owningGraph();
1508   Value* observer_out = observer->output();
1509   Value* original_val = observer->input(1);
1510   std::vector<Value*> inputs;
1511   // + 1 for tensor to be quantized
1512   inputs.reserve(qparams.qparams.size() + 1);
1513   inputs.push_back({observer_out});
1514   for (const auto& qparam_values : qparams.qparams) {
1515     inputs.push_back(qparam_values);
1516   }
1517   Node* quant = insertQuant(
1518       g,
1519       inputs,
1520       at::Symbol::aten(quantize_func),
1521       original_val->debugName() + ".quant");
1522   // Have to make sure that quant node appears after the values it depends on.
1523   for (Value* v : inputs) {
1524     quant->moveAfter(v->node());
1525   }
1526   Node* dequant = insertDeQuant(g, quant->output(), original_val);
1527   dequant->moveAfter(quant);
1528   return dequant;
1529 }
1530 
checkCalculateQParamsResultTypes(const Node * out)1531 void checkCalculateQParamsResultTypes(const Node* out) {
1532   TORCH_CHECK(
1533       out->outputs().size() == 2,
1534       "calculate_qparams should produce output of size 2 (scale, zero_point).");
1535   Value* scale = out->output(0);
1536   Value* zp = out->output(1);
1537   TORCH_CHECK(
1538       scale->type()->expect<TensorType>(),
1539       "Scale value should be of Tensor type.");
1540   TORCH_CHECK(
1541       zp->type()->expect<TensorType>(), "Scale value should be of float type.");
1542 }
1543 
insertCalculateQParams(script::Module & module,Graph * g,Node * n)1544 QuantOpParams InsertQuantDeQuantHelper::insertCalculateQParams(
1545     script::Module& module,
1546     Graph* g,
1547     Node* n) {
1548   // TODO: refactor findObserverName to take Node* as input
1549   Value* self = g->inputs()[0];
1550   Value* v = n->output();
1551   TORCH_INTERNAL_ASSERT(
1552       v->type()->isSubtypeOf(*TensorType::get()),
1553       "Expected output of observer node to be Tensor");
1554   auto observer_name = findObserverName(v);
1555   TORCH_INTERNAL_ASSERT(
1556       observer_name,
1557       "getQSchemeAndParamMap expects the corresponding observer for ",
1558       v->debugName(),
1559       " exists.");
1560   std::vector<Value*> qparams_graph_values;
1561   QuantOpParams quant_op_params;
1562 
1563   TORCH_CHECK(
1564       !isPlaceholderObserver(n->input(0)),
1565       "Placeholder observers are not supported in ondevice PTQ.");
1566   auto observer_module = module.attr(observer_name.value()).toModule();
1567   Value* observer_module_value = g->insertGetAttr(self, observer_name.value());
1568   auto scalar_type = observer_module.attr("dtype");
1569   TORCH_CHECK(
1570       scalar_type.toScalarType() != at::ScalarType::Undefined,
1571       "dtype of observer can't be undefined");
1572   // Not sure if we need to support this for on device PTQ.
1573   if (scalar_type == at::ScalarType::Half) {
1574     return quant_op_params;
1575   }
1576   auto calculate_qparams = observer_module.get_method("calculate_qparams");
1577   auto calculate_qparams_schema = calculate_qparams.function().getSchema();
1578   MatchedSchema matched_schema = matchSchema(
1579       calculate_qparams_schema,
1580       v->node()->sourceRange(),
1581       *g,
1582       {observer_module_value},
1583       {});
1584   Node* call = g->insertMethodCall("calculate_qparams", matched_schema)->node();
1585   Node* scale_zp_node = g->insertNode(g->createTupleUnpack(call->output(0)));
1586   checkCalculateQParamsResultTypes(scale_zp_node);
1587   auto qscheme = observer_module.attr("qscheme").toQScheme();
1588   quant_op_params.qscheme = qscheme;
1589   quant_op_params.qparams.push_back(scale_zp_node->output(0)); // scale Value*
1590   quant_op_params.qparams.push_back(
1591       scale_zp_node->output(1)); // zero_point Value*
1592   if (isPerChannel(qscheme)) {
1593     Value* ch_axis_value = g->insertGetAttr(observer_module_value, "ch_axis");
1594     quant_op_params.qparams.push_back(ch_axis_value);
1595   }
1596   Value* scalar_type_value = g->insertGetAttr(observer_module_value, "dtype");
1597   quant_op_params.qparams.push_back(scalar_type_value);
1598   return quant_op_params;
1599 }
1600 
insertCalculateQParamsAndQuantizationOps(Module & module,Graph * graph,Value * self)1601 void InsertQuantDeQuantHelper::insertCalculateQParamsAndQuantizationOps(
1602     Module& module,
1603     Graph* graph,
1604     Value* self) {
1605   if (!observer_nodes_for_graph_.count(graph)) {
1606     return;
1607   }
1608   for (auto* n : observer_nodes_for_graph_.at(graph)) {
1609     Graph* g = n->owningGraph();
1610     // Observer output
1611     Value* observer_out = n->output();
1612     // Inserting before insert point
1613     WithInsertPoint insert_qparams_calc(observer_out->node()->next());
1614     auto quant_op_params = insertCalculateQParams(module, g, n);
1615     insertQuantizationOps(
1616         module,
1617         self,
1618         n,
1619         isPerChannel(quant_op_params.qscheme),
1620         quant_op_params,
1621         quant_type_);
1622   }
1623 }
1624 
runForOnDevicePTQ(Module & module,const std::string & method_name)1625 void InsertQuantDeQuantHelper::runForOnDevicePTQ(
1626     Module& module,
1627     const std::string& method_name) {
1628   // In all likelihood this really wont do anything because we expect that
1629   // the input method for quantization's prepare step will be inlined. Thus
1630   // only call methods we will see will belong to observer's forward calls.
1631   for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
1632     auto& invoked_module = std::get<0>(invoked_methods);
1633     const auto& invoked_method_name = std::get<1>(invoked_methods);
1634     runForOnDevicePTQ(invoked_module, invoked_method_name);
1635   }
1636 
1637   Method method = module.get_method(method_name);
1638   auto graph = method.graph();
1639   // Unliked the run method we dont need to extract new qparam values for the
1640   // the same graph used in different call site.
1641   // Reason is that for on device PTQ we dont:
1642   // 1. Run calculate_qparams
1643   // 2. Get the scale and zero point
1644   // 3. get axis and dtype
1645   // 4. register values from 2 and 3 as attributes on the parent module.
1646   // Instead we insert call to calculate_qparams (1) via insertCalculateQParams
1647   // in the graph itself. Then instead of 2 and 3, we get the output Value*
1648   // and for 3, we insert GetAttr for axis and dtype and use those Value*
1649   // with insterQuantizationOps
1650 
1651   // prim::Param nodes do not belong to the graph. Hence the Insert
1652   // point is the beginning of graph node. This also safe guards against
1653   // observing a potentially mutated value due to some in-place operation
1654   std::vector<Value*> input_values;
1655   for (const auto idx : c10::irange(1, method.num_inputs())) {
1656     auto& v = graph->inputs()[idx];
1657     if (v->type()->isSubtypeOf(*TensorType::get())) {
1658       input_values.push_back(v);
1659     }
1660   }
1661 
1662   std::stack<Block*> blocks_to_visit;
1663   blocks_to_visit.push(graph->block());
1664   while (!blocks_to_visit.empty()) {
1665     Block* b = blocks_to_visit.top();
1666     blocks_to_visit.pop();
1667     for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
1668       Node* n = *it++;
1669       for (Value* v : n->outputs()) {
1670         if (!v->type()->isSubtypeOf(*TensorType::get())) {
1671           continue;
1672         }
1673         collectObserverNodesAndValueToQuantize(module, v);
1674       }
1675 
1676       for (Block* subblock : n->blocks()) {
1677         blocks_to_visit.push(subblock);
1678       }
1679     }
1680   }
1681 
1682   for (Value* v : input_values) {
1683     collectObserverNodesAndValueToQuantize(module, v);
1684   }
1685 
1686   GRAPH_DUMP("Before insertCalculateQparamsAndQuantizationOps:", graph);
1687   Value* self = graph->inputs()[0];
1688   insertCalculateQParamsAndQuantizationOps(module, graph.get(), self);
1689   GRAPH_DUMP("After insertCalculateQparamsAndQuantizationOps:", graph);
1690 }
1691 
1692 } // namespace
1693 
ReplicateQuant(std::shared_ptr<Graph> & graph)1694 void ReplicateQuant(std::shared_ptr<Graph>& graph) {
1695   std::stack<Block*> blocks_to_visit;
1696   std::vector<Node*> quant_nodes_to_rewrite;
1697   blocks_to_visit.push(graph->block());
1698   while (!blocks_to_visit.empty()) {
1699     Block* b = blocks_to_visit.top();
1700     blocks_to_visit.pop();
1701     for (Node* n : b->nodes()) {
1702       // find quantize node that quantizes the output of if
1703       if ((n->kind() == Symbol::aten("quantize_per_tensor") ||
1704            n->kind() == Symbol::aten("quantize_per_channel")) &&
1705           n->input(0)->node()->kind() == prim::If) {
1706         quant_nodes_to_rewrite.push_back(n);
1707       }
1708       for (Block* subblock : n->blocks()) {
1709         blocks_to_visit.push(subblock);
1710       }
1711     }
1712   }
1713   for (Node* n : quant_nodes_to_rewrite) {
1714     Node* if_node = n->input(0)->node();
1715     // move the nodes that produces the quantization parameters before
1716     // prim::If
1717     for (const auto i : c10::irange(1, n->inputs().size())) {
1718       n->input(i)->node()->moveBefore(if_node);
1719     }
1720     // replace all uses of the quantized node with the output of if node
1721     n->output()->replaceAllUsesWith(if_node->output());
1722     // add quantize nodes to the end of all blocks
1723     for (Block* if_block : if_node->blocks()) {
1724       TORCH_CHECK(
1725           if_block->outputs().size() == 1,
1726           "replicate quantize only works for `if` node with one output right now");
1727       // the original return value of the block
1728       Value* ret_val = if_block->outputs()[0];
1729       std::vector<Value*> quantize_inputs = n->inputs().vec();
1730       quantize_inputs[0] = ret_val;
1731       WithInsertPoint ins(if_block->return_node());
1732       Node* quant = graph->create(n->kind(), quantize_inputs);
1733       if_block->replaceOutput(0, quant->output());
1734       quant->output()->copyMetadata(ret_val);
1735       graph->insertNode(quant);
1736     }
1737   }
1738 
1739   for (Node* n : quant_nodes_to_rewrite) {
1740     n->removeAllInputs();
1741   }
1742   for (Node* n : quant_nodes_to_rewrite) {
1743     n->destroy();
1744   }
1745 }
1746 
ReplicateDeQuant(std::shared_ptr<Graph> & graph)1747 void ReplicateDeQuant(std::shared_ptr<Graph>& graph) {
1748   std::stack<Block*> blocks_to_visit;
1749   std::vector<Node*> dequant_nodes_to_rewrite;
1750   blocks_to_visit.push(graph->block());
1751   while (!blocks_to_visit.empty()) {
1752     Block* b = blocks_to_visit.top();
1753     blocks_to_visit.pop();
1754     for (Node* n : b->nodes()) {
1755       if (n->kind() == Symbol::aten("dequantize") &&
1756           n->output()->uses().size() > 1) {
1757         dequant_nodes_to_rewrite.push_back(n);
1758       }
1759       for (Block* subblock : n->blocks()) {
1760         blocks_to_visit.push(subblock);
1761       }
1762     }
1763   }
1764   for (Node* n : dequant_nodes_to_rewrite) {
1765     auto* quantized_val = n->input(0);
1766     auto* dequantized_val = n->output();
1767     insertDeQuantForAllUse(graph.get(), quantized_val, dequantized_val);
1768   }
1769 
1770   for (Node* n : dequant_nodes_to_rewrite) {
1771     n->removeAllInputs();
1772   }
1773 
1774   for (Node* n : dequant_nodes_to_rewrite) {
1775     n->destroy();
1776   }
1777 }
1778 
InsertQuantDeQuant(Module & input_module,const std::string & method_name,bool inplace,bool debug,QuantType quant_type)1779 Module InsertQuantDeQuant(
1780     Module& input_module,
1781     const std::string& method_name,
1782     bool inplace,
1783     bool debug,
1784     QuantType quant_type) {
1785   Module module = input_module.clone(inplace);
1786   InsertQuantDeQuantHelper h(quant_type, debug);
1787   h.runWeightObserver(module, method_name);
1788   h.run(module, method_name);
1789   h.cleanup(module);
1790   h.propagateQuantizationOps(module);
1791   return module;
1792 }
1793 
1794 /*
1795  *
1796  * Assumption: method_name method has observer placed
1797  * Objective: modify that method to insert calls to:
1798  * 1. calculate_qparams
1799  * 2. GetAttr for axis and dtype values
1800  * 3. Use Values from above two to insert calls to quant + dequant
1801  * Thus after this step you have a graph of, e.g., observe_forward,
1802  * that has observer nodes, calculate_qparams run on those observer nodes,
1803  * output of which is used by quant-dequant nodes. output of dequant is used
1804  * by the actual op.
1805  * Later on we will replace dequant + op (e.g. linear) with
1806  * 1. prepacked_op context
1807  * 2. unpack
1808  * 3. dequantize
1809  * 4. linear
1810  *
1811  * Of the above pattern 2, 3, and 4 can be replaced by linear_run op
1812  */
1813 // Module InsertQuantDeQuantForOnDevicePTQ(
InsertQuantDeQuantOnDevicePTQ(Module & input_module,const std::string & method_name,bool inplace,bool debug,QuantType quant_type)1814 Module InsertQuantDeQuantOnDevicePTQ(
1815     Module& input_module,
1816     const std::string& method_name,
1817     bool inplace,
1818     bool debug,
1819     QuantType quant_type) {
1820   Module module = input_module.clone(inplace);
1821   const std::string kObserveString = "observe_";
1822   const auto matched_pos = method_name.find(kObserveString);
1823   const auto end_pos = matched_pos + kObserveString.length();
1824   const std::string orig_method_name = method_name.substr(end_pos);
1825   TORCH_CHECK(
1826       matched_pos == 0,
1827       "Quant dequant nodes can only be added to observe_",
1828       orig_method_name,
1829       ". Please make sure to run prepare step for on-device PTQ.");
1830 
1831   std::string quantize_method_name = "quantize_" + orig_method_name;
1832   cloneMethod(module, method_name, quantize_method_name);
1833   InsertQuantDeQuantHelper h(quant_type, debug);
1834   h.runForOnDevicePTQ(module, quantize_method_name);
1835   h.removeObserverNodes(module);
1836   // Dont need:
1837   // ReplicateChooseQParamsQuantDequant: This is propagating dynamic quant's
1838   // quant dequant RemoveRedundantQuantizationOps: THis is removing activation
1839   // observers for dynamic quant when the op related to it is not dynamically
1840   // quantizable. Doesnt really make sense. In our case we wont have those
1841   // anyway since for dynamic quant activations wont be observed We can still
1842   // use this function because the above two methods should really be a noop
1843   h.propagateQuantizationOps(module);
1844   return module;
1845 }
1846 } // namespace jit
1847 } // namespace torch
1848