1 #include <ATen/Context.h>
2 #include <torch/csrc/jit/mobile/module.h>
3 #include <torch/csrc/jit/mobile/quantization.h>
4
5 namespace torch::jit::mobile::quantization {
6
quantize_dynamic(torch::jit::mobile::Module & m,const std::string & method_name)7 void PTQQuanizationHelper::quantize_dynamic(
8 torch::jit::mobile::Module& m,
9 const std::string& method_name) {
10 at::globalContext().setReleaseWeightsWhenPrepacking(false);
11 std::string reset_observers_method_name = "reset_observers_" + method_name;
12 std::string observe_method_name = "observe_" + method_name;
13 std::string quantize_method_name = "quantize_" + method_name;
14 std::string quantized_method_name = "quantized_" + method_name;
15
16 TORCH_CHECK(
17 m.find_method(reset_observers_method_name).has_value(),
18 "PTQ ready module must have",
19 reset_observers_method_name,
20 " method.");
21 TORCH_CHECK(
22 m.find_method(observe_method_name),
23 "PTQ ready module must have",
24 reset_observers_method_name,
25 " method.");
26 TORCH_CHECK(
27 m.find_method(quantize_method_name),
28 "PTQ ready module must have",
29 quantize_method_name,
30 " method.");
31 TORCH_CHECK(
32 m.find_method(quantized_method_name),
33 "PTQ ready module must have",
34 quantized_method_name,
35 " method.");
36 TORCH_CHECK(
37 m.find_method("get_all_bundled_inputs"),
38 "PTQ ready module must have get_all_bundled_inputs method.");
39
40 auto inputs = m.run_method("get_all_bundled_inputs")
41 .toList()
42 .get(0)
43 .toTupleRef()
44 .elements()
45 .vec();
46 m.get_method(reset_observers_method_name)({});
47 m.get_method(observe_method_name)(inputs);
48 m.get_method(quantize_method_name)(inputs);
49
50 m.compareMethodSchemas(method_name, quantized_method_name);
51 m.unsafeRemoveMethod(method_name);
52 const Function& to_be_copied =
53 m.find_method(quantized_method_name).value().function();
54 m.unsafeCopyMethod(method_name, to_be_copied);
55 m.unsafeRemoveMethod(quantized_method_name);
56 m.unsafeRemoveMethod(quantize_method_name);
57 m.unsafeRemoveMethod(observe_method_name);
58 m.unsafeRemoveMethod(reset_observers_method_name);
59 }
60 } // namespace torch::jit::mobile::quantization
61