// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/quantization/insert_quant_dequant.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/api/module.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <torch/csrc/jit/passes/quantization/quantization_type.h>
6 
7 namespace torch {
8 namespace jit {
9 
10 /** Replicate quantize node for prim::If blocks, so that we can match
11  *  quantization patterns in prim::If blocks
12  */
13 TORCH_API void ReplicateQuant(std::shared_ptr<Graph>& graph);
14 
15 /** Replicate dequantize node for each use, so that we can match
16  *  quantization patterns
17  */
18 TORCH_API void ReplicateDeQuant(std::shared_ptr<Graph>& graph);
19 
20 /** \brief Insert quantize - dequantize calls to the Tensors
21  *  that are observed in insert_observers pass
22  *
23  * For each Tensor that is observed, get the observer module and call
24  * calculate_qparam on the observer module to get quantization parameters
25  * and add quantize - int_repr - dequantize function calls using these
26  * parameters we also have special handling for quantizing "bias" right now.
27  *
28  * \param module the input module
29  * \param method_name the method we want to insert quantization calls for
30  */
31 TORCH_API Module InsertQuantDeQuant(
32     Module& module,
33     const std::string& method_name,
34     bool inplace,
35     bool debug,
36     QuantType quant_type = QuantType::STATIC);
37 
38 TORCH_API Module InsertQuantDeQuantOnDevicePTQ(
39     Module& module,
40     const std::string& method_name,
41     bool inplace,
42     bool debug,
43     QuantType quant_type = QuantType::STATIC);
44 
45 } // namespace jit
46 } // namespace torch
47