xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/quantization.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Context.h>
2 #include <torch/csrc/jit/mobile/module.h>
3 #include <torch/csrc/jit/mobile/quantization.h>
4 
5 namespace torch::jit::mobile::quantization {
6 
quantize_dynamic(torch::jit::mobile::Module & m,const std::string & method_name)7 void PTQQuanizationHelper::quantize_dynamic(
8     torch::jit::mobile::Module& m,
9     const std::string& method_name) {
10   at::globalContext().setReleaseWeightsWhenPrepacking(false);
11   std::string reset_observers_method_name = "reset_observers_" + method_name;
12   std::string observe_method_name = "observe_" + method_name;
13   std::string quantize_method_name = "quantize_" + method_name;
14   std::string quantized_method_name = "quantized_" + method_name;
15 
16   TORCH_CHECK(
17       m.find_method(reset_observers_method_name).has_value(),
18       "PTQ ready module must have",
19       reset_observers_method_name,
20       " method.");
21   TORCH_CHECK(
22       m.find_method(observe_method_name),
23       "PTQ ready module must have",
24       reset_observers_method_name,
25       " method.");
26   TORCH_CHECK(
27       m.find_method(quantize_method_name),
28       "PTQ ready module must have",
29       quantize_method_name,
30       " method.");
31   TORCH_CHECK(
32       m.find_method(quantized_method_name),
33       "PTQ ready module must have",
34       quantized_method_name,
35       " method.");
36   TORCH_CHECK(
37       m.find_method("get_all_bundled_inputs"),
38       "PTQ ready module must have get_all_bundled_inputs method.");
39 
40   auto inputs = m.run_method("get_all_bundled_inputs")
41                     .toList()
42                     .get(0)
43                     .toTupleRef()
44                     .elements()
45                     .vec();
46   m.get_method(reset_observers_method_name)({});
47   m.get_method(observe_method_name)(inputs);
48   m.get_method(quantize_method_name)(inputs);
49 
50   m.compareMethodSchemas(method_name, quantized_method_name);
51   m.unsafeRemoveMethod(method_name);
52   const Function& to_be_copied =
53       m.find_method(quantized_method_name).value().function();
54   m.unsafeCopyMethod(method_name, to_be_copied);
55   m.unsafeRemoveMethod(quantized_method_name);
56   m.unsafeRemoveMethod(quantize_method_name);
57   m.unsafeRemoveMethod(observe_method_name);
58   m.unsafeRemoveMethod(reset_observers_method_name);
59 }
60 } // namespace torch::jit::mobile::quantization
61