// torch/csrc/jit/passes/quantization/finalize.cpp
#include <torch/csrc/jit/passes/quantization/finalize.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/quantization/quantization_patterns.h>
#include <torch/csrc/jit/passes/quantization/register_packed_params.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#include <utility>

namespace torch {
namespace jit {

namespace {

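// Runs the SubgraphRewriter over the graph once per linear prepack/unpack
// rewrite pattern, inserting prepack and unpack calls around quantized
// linear ops.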
void insertPrepackUnpackForLinear(std::shared_ptr<Graph>& graph) {
  std::vector<QuantFusionInfo> patterns_and_replacements =
      linear_prepack_unpack_patterns();

  for (const auto& entry : patterns_and_replacements) {
    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(entry.pattern, entry.replacement);
    rewriter.runOnGraph(graph, entry.filters);
  }
}

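// Same as above, but for the conv prepack/unpack rewrite patterns.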
void insertPrepackUnpackForConv(std::shared_ptr<Graph>& graph) {
  std::vector<QuantFusionInfo> patterns_and_replacements =
      conv_prepack_unpack_patterns();

  for (const auto& entry : patterns_and_replacements) {
    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(entry.pattern, entry.replacement);
    rewriter.runOnGraph(graph, entry.filters);
  }
}

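// Deletes prim::SetAttr nodes that either store a registered packed param
// (matched by attribute name) or target a module whose access path appears
// in packed_param_attr_names (i.e. reset the fp weights). Inputs are
// detached before the nodes are destroyed, and ConstantPooling + DCE then
// clean up the now-unused producers.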
void removePackedParamInsertionAndFPWeightsSetAttr(
    std::shared_ptr<Graph>& g,
    const std::unordered_set<std::string>& packed_param_attr_names) {
  DepthFirstGraphNodeIterator it(g);
  Node* n = nullptr;
  std::vector<Node*> nodes_to_delete;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == prim::SetAttr) {
      const std::string& attr_name = n->s(attr::name);
      if (packed_param_attr_names.count(attr_name)) {
        nodes_to_delete.push_back(n);
      } else {
        Value* v = n->input(0);
        Value* self = g->inputs()[0];
        std::vector<std::string> paths = getModuleAccessPath(v, self);
        std::string path = joinPaths(paths);
        if (packed_param_attr_names.count(path)) {
          nodes_to_delete.push_back(n);
        }
      }
    }
  }
  for (auto node : nodes_to_delete) {
    node->removeAllInputs();
  }
  for (auto node : nodes_to_delete) {
    node->destroy();
  }
  ConstantPooling(g);
  EliminateDeadCode(g);
}

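// Removes calculate_qparams CallMethod nodes invoked on observer modules
// (GetAttrs whose name contains "_observer_"). CallMethods have side
// effects, so DCE alone would not eliminate them; once they are destroyed,
// DCE removes the leftover observer GetAttrs.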
void removeObserverCallMethods(std::shared_ptr<Graph>& g) {
  DepthFirstGraphNodeIterator it(g);
  Node* n = nullptr;
  std::vector<Node*> nodes_to_delete;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == prim::CallMethod) {
      const std::string& attr_name = n->s(attr::name);
      if (attr_name == "calculate_qparams") {
        auto observer_node = n->input(0)->node();
        if (observer_node->kind() == prim::GetAttr &&
            observer_node->s(attr::name).find("_observer_") !=
                std::string::npos) {
          nodes_to_delete.push_back(n);
        }
      }
    }
  }
  for (auto node : nodes_to_delete) {
    node->removeAllInputs();
  }
  for (auto node : nodes_to_delete) {
    node->destroy();
  }
  EliminateDeadCode(g);
}

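// Rewrites method_name so that it returns None instead of its original
// outputs and updates the schema to match; the subsequent DCE then strips
// everything except the packed-param generation and its SetAttrs.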
void keepOnlyPackedParamsGeneration(Module& m, const std::string& method_name) {
  auto g = m.get_method(method_name).graph();
  Function& function = m.get_method(method_name).function();
  const auto& schema = function.getSchema();
  auto new_schema = schema.cloneWithReturns({Argument("", NoneType::get())});
  // Erase index 0 on each iteration: output indices shift down as outputs
  // are removed, so erasing output i from the shrinking list would skip some.
  for (size_t i = 0, output_size = g->outputs().size(); i < output_size; i++) {
    g->eraseOutput(0);
  }
  Node* none_node = g->createNone();
  g->registerOutput(none_node->output());
  none_node->insertBefore(g->return_node());
  function.setSchema(std::move(new_schema));
  EliminateDeadCode(g);
}

} // namespace

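// Fuses quant/dequant + op patterns into their quantized counterparts. For
// dynamic quantization, the dynamic quantized linear patterns (which do not
// quantize activations) are appended to the dynamic fusion patterns.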
void QuantFusion(std::shared_ptr<Graph>& graph, QuantType quant_type) {
  std::vector<QuantFusionInfo> patterns;
  if (quant_type == QuantType::DYNAMIC) {
    patterns = dynamic_quant_fusion_pattern_and_replacements();
    std::vector<QuantFusionInfo> patterns_wo_dynamic_activation_quant =
        dynamic_quantized_linear_pattern_and_replacements();
    patterns.insert(
        patterns.end(),
        patterns_wo_dynamic_activation_quant.begin(),
        patterns_wo_dynamic_activation_quant.end());
  } else {
    patterns = quant_fusion_pattern_and_replacements();
  }
  for (const auto& info : patterns) {
    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(info.pattern, info.replacement);
    rewriter.runOnGraph(graph, info.filters);
  }
}

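// Inserts prepack/unpack pairs for both linear and conv ops in the graph.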
void InsertPrepackUnpack(std::shared_ptr<Graph>& graph) {
  insertPrepackUnpackForLinear(graph);
  insertPrepackUnpackForConv(graph);
}

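// Applies InsertPrepackUnpack to every method graph of the module and
// recurses into its children.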
void InsertPrepackUnpack(Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    InsertPrepackUnpack(graph);
  }
  for (Module m : module.children()) {
    InsertPrepackUnpack(m);
  }
}

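// Folds the listed quantized prepacking ops (linear, conv1d/2d/3d, and
// conv_transpose1d/2d) into module attributes via PrePackingOpsFolder, so
// that packing runs once when the module is finalized rather than on every
// forward call.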
void FoldQuantizedPrepackingOps(Module& module) {
  auto filter_fn = [](const Node* n) -> bool {
    return (
        n->kind() == Symbol::fromQualString("quantized::linear_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv3d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose1d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose2d_prepack"));
  };
  PrePackingOpsFolder(module, filter_fn, "quantized");
}

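// Registers the outputs of the same set of prepacking ops in method_name as
// module attributes and returns the names of the registered packed-param
// attributes.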
static std::unordered_set<std::string> RegisterPrePackingParams(
    Module& module,
    const std::string& method_name) {
  auto filter_fn = [](const Node* n) -> bool {
    return (
        n->kind() == Symbol::fromQualString("quantized::linear_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv3d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose1d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose2d_prepack"));
  };
  return RegisterPrePackParams(module, method_name, filter_fn, "");
}

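// Finalizes a quantized module for inference: clears profiled shape info,
// inserts prepack/unpack ops into forward, fuses the quantized patterns,
// freezes the module, and folds the prepacking ops into attributes.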
Module Finalize(
    Module& module,
    QuantType quant_type,
    const std::vector<std::string>& preserved_attrs) {
  // Tracing annotates the resulting graph with shape information. In many
  // cases users apply different input shapes to the traced graph, and it is
  // on the user to know that doing so is correct. The quantized module
  // therefore needs to be cleaned up: to keep JIT optimizations from
  // leveraging the annotated shape info, clear the shape information from
  // the graph.
  for (auto func : module.type()->methods()) {
    ClearProfilingInformation(toGraphFunction(*func).graph());
  }

  auto graph = module.get_method("forward").graph();
  InsertPrepackUnpack(graph);
  GRAPH_DUMP("Before QuantFusion:", graph);
  QuantFusion(graph, quant_type);
  auto frozen = freeze_module(module, preserved_attrs);
  FoldQuantizedPrepackingOps(frozen);
  return frozen;
}

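// Finalizes a module for on-device PTQ. Unlike Finalize, the module is not
// frozen: quantize_<method> is kept to generate and set the packed params,
// and a quantized_<method> clone is created that runs the quantized ops
// against those params.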
Module FinalizeOnDevicePTQ(
    Module& module,
    QuantType quant_type,
    const std::string& method_name) {
  // Tracing annotates the resulting graph with shape information. In many
  // cases users apply different input shapes to the traced graph, and it is
  // on the user to know that doing so is correct. The quantized module
  // therefore needs to be cleaned up: to keep JIT optimizations from
  // leveraging the annotated shape info, clear the shape information from
  // the graph.
  for (auto func : module.type()->methods()) {
    ClearProfilingInformation(toGraphFunction(*func).graph());
  }

  const std::string kQuantizeString = "quantize_";
  const auto matched_pos = method_name.find(kQuantizeString);
  // Validate the method name before deriving orig_method_name from it;
  // calling substr first would misbehave when the prefix is absent.
  TORCH_CHECK(
      matched_pos == 0,
      "Quantized ops can only be added to methods named quantize_<name>, got ",
      method_name,
      ". Please make sure to run the quant/dequant node insertion step for on-device PTQ.");
  const auto end_pos = matched_pos + kQuantizeString.length();
  const std::string orig_method_name = method_name.substr(end_pos);

  const std::string quantized_method_name = "quantized_" + orig_method_name;
  auto graph = module.get_method(method_name).graph();
  // Apply some AOT optimizations here. Of these, CSE seems to be required:
  // without it, in some experiments the serialized model is incorrect, i.e.,
  // it cannot be deserialized. The rest are included as canonical
  // optimizations that are not inference-specific.
  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);
  PeepholeOptimize(graph);
  ConstantPropagation(graph);
  UnrollConstantLoops(graph);
  ConstantPooling(graph);

  InsertPrepackUnpack(graph);
  GRAPH_DUMP("Before QuantFusion:", graph);
  QuantFusion(graph, quant_type);
  auto packed_param_attr_names = RegisterPrePackingParams(module, method_name);
  GRAPH_DUMP("After QuantFusion + packed param registration:", graph);

  // At this point we have:
  // 1. Quantized the weights and created packed params
  // 2. Registered the packed params as module attributes
  // 3. Inserted the quantized ops
  // What remains is to:
  // 1. Replicate this method as quantized_<method>
  // 2. In the clone, remove the SetAttrs for the fp weights that are reset
  //    by quantize_<method>
  // 3. In the clone, remove the packed-param SetAttr nodes, which lets the
  //    nodes producing the packed params be optimized away
  // 4. Strip quantize_<method> down to just the packed-param generation and
  //    its SetAttrs
  cloneMethod(module, method_name, quantized_method_name);
  // removeWeightSetAttrs(module, quantized_method_name);
  auto quantized_graph = module.get_method(quantized_method_name).graph();
  removePackedParamInsertionAndFPWeightsSetAttr(
      quantized_graph, packed_param_attr_names);
  // Removing the packed-param SetAttrs is not sufficient on its own: DCE
  // will not touch the observers' GetAttrs and CallMethods, because
  // CallMethods have side effects.
  removeObserverCallMethods(quantized_graph);
  // This step removes the original return outputs from the graph, and the
  // subsequent DCE removes all the remaining ops. After that, the only
  // thing left should be the packed-param generation.
  keepOnlyPackedParamsGeneration(module, method_name);
  return module;
}

} // namespace jit
} // namespace torch