// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/quantization/helper.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <torch/csrc/jit/api/module.h>
3 #include <torch/csrc/jit/ir/ir.h>
4 #include <torch/csrc/jit/ir/subgraph_matcher.h>
5 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
6 #include <torch/csrc/jit/passes/quantization/quantization_type.h>
7 
8 #include <functional>
9 #include <regex>
10 
namespace torch {
namespace jit {

using graph_rewrite_helper::getFuncName;

// Vector of (module, method name) pairs — a module together with the
// name of one of its methods.
using ModuleMethodVector = std::vector<std::pair<Module, std::string>>;
// Ordered list of (quantization parameter name, value) pairs,
// for example _scale, _zero_point,
// _scalar_type and _axis (for per-channel quantization)
using QParamVector = std::vector<std::pair<std::string, IValue>>;
22 
// =========== helper functions for Value =========
// Check if a value is weight, since we need to use weight observer
// for weight
TORCH_API bool isWeight(Value* v);

// Check if a value is bias for conv and linear, which we do not
// quantize
TORCH_API bool isBiasOfConvOrLinear(Value* v);

// NOTE(review): presumably checks whether `v` is a non-input argument
// (e.g. weight) of an embedding_bag op — confirm against helper.cpp.
TORCH_API bool isEmbeddingBagNonInput(Value* v);

// Get the use of `v` as the scalar input of a clamp op, if any
// (std::nullopt when no such use exists).
std::optional<Use> getClampScalarInputUse(Value* v);

// For a given value `v`, get the list of values that we need to check
// if they are observed/quantized or not, if so, we can say the
// `v` is also observed/quantized, since we can derive
// the quantization parameters for `v` given the list of values
TORCH_API std::vector<Value*> getPassThroughInputs(Value* v);

// Clones the method named `orig_method_name` of `module` into a new
// method named `new_method_name`.
TORCH_API void cloneMethod(
    Module& module,
    const std::string& orig_method_name,
    const std::string& new_method_name);

// Check if a value in the graph is a Scalar value
TORCH_API bool isScalar(Value* v);

// Check if value is the input of the graph
TORCH_API bool hitGraphInput(Value* value);

// Converts a mangled name, such as
//   __torch__.torch.ao.nn.quantized.modules.conv.___torch_mangle_7.Conv2d
// into an unmangled name, such as
//   __torch__.torch.ao.nn.quantized.modules.conv.Conv2d
TORCH_API std::string removeTorchMangle(const std::string& orig_name);

// Return the module name that corresponds to the value, or
// std::nullopt if it cannot be determined.
TORCH_API std::optional<std::string> getModuleName(Value* value);
63 
// =========== helper functions for Node =========
// The four predicates below classify "single input general" ops
// (shape-only vs. value-affecting, aten ops vs. CallFunctions);
// the exact op lists live in the implementation file.
TORCH_API bool isSingleInputGeneralShapeAtenFunction(Node* n);

TORCH_API bool isSingleInputGeneralValueAtenFunction(Node* n);

TORCH_API bool isSingleInputGeneralCallFunction(Node* n);

TORCH_API bool isSingleInputGeneralAtenFunction(Node* n);

// Check if the node is a clamp op.
TORCH_API bool isClamp(Node* n);

// Check if the node will produce the same result regardless of whether
// the input tensor is quantized or not, example: aten::size
TORCH_API bool isTensorInfoNode(Node* n);

// Check if this is the propagate op that has a single input, e.g. aten::cat
TORCH_API bool isPropagateQuantSingleInputOp(Node* n);

// Check if this is the propagate op that has two inputs, e.g. aten::add
TORCH_API bool isPropagateQuantBinaryOp(Node* n);

// Check if this is a node that we'll quantize or not quantize depending on
// whether the input of the node is quantized, example: aten::cat
TORCH_API bool isPropagateQuantOp(Node* n);

// Check if the node is a binary op like aten::add and aten::mul and
// if the input 1 is a scalar; these ops will be quantized to
// quantized::{op}_scalar
TORCH_API bool isBinaryOpWithScalarInput(Node* n);

// Return the fixed (qscheme, qparams) pair for nodes that have fixed
// quantization parameters, or std::nullopt for all other nodes.
TORCH_API std::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(
    Node* n);

// We don't want to analyze the graph for some `builtin` CallFunctions
// like `linear` because we want to preserve the op boundary
TORCH_API bool userDefinedCallFunction(Node* n);

// Check if the node has scalar input
TORCH_API bool hasScalarInput(Node* n);

// Check if a node is quantizable under the given quantization type
// (static quantization by default)
TORCH_API bool nodeQuantizable(
    Node* n,
    QuantType quant_type = QuantType::STATIC);

// Nodes which only require quantization of weight value, eg. embedding_bag
bool isWeightOnlyStaticQuantOp(Node* n);

// Check if a use of the value is quantizable, this depends on
// both the use node and the offset
TORCH_API bool useQuantizable(const Use& use, QuantType quant_type);

// Given a CallFunction node, extract the graph of the called function
TORCH_API std::shared_ptr<Graph> getCallFunctionGraph(Node* n);

// Check if `use` is a CallFunction of name `func_name` and, when
// `nth_arg` is provided, whether the used value is the nth argument
// of the function
bool matchCallFuncToUse(
    const Use& use,
    const std::string& func_name,
    std::optional<int> nth_arg);

// Check if `use` is an aten function of name `func_name` and, when
// `nth_arg` is provided, whether the used value is the nth argument
// of the function
bool matchAtenFuncToUse(
    const Use& use,
    const std::string& func_name,
    std::optional<int> nth_arg);
132 
// =========== helper functions for Block =========
// checks if a block will always raise an Exception
TORCH_API bool alwaysRaisesException(Block* block);

// =========== helper functions for Module  ==========
// Returns the chain of submodule names leading from `self` to
// `instance` (the module access path).
// TODO: remove
TORCH_API std::vector<std::string> getModuleAccessPath(
    Value* instance,
    Value* self);
// Returns the submodule of `module` reached by following `path`
// (as produced by getModuleAccessPath).
// TODO: remove
TORCH_API Module
findChildModule(const Module& module, const std::vector<std::string>& path);

// Given an CallMethod node, get the module instance corresponding
// to the instance Value
// TODO: refactor all current uses of this function to the Opt one
TORCH_API Module getInvokedModule(Module& module, Node* n, Value* self);

// Given an CallMethod node, get the module instance corresponding
// to the instance Value if the instance is a module, otherwise return
// std::nullopt
std::optional<Module> getInvokedModuleOpt(
    const Module& module,
    Node* n,
    Value* self);
158 
// ==================== filter functions for matches ==============
// The predicates below are match filters for the subgraph rewriter:
// each takes the Match and the pattern's name->Value map and returns
// whether the match should be accepted.

// filter to check Value `vname` is a constant of int value `value`
bool is_int_constant(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap,
    const std::string& vname,
    int value);

// filter to check if the %alpha argument of aten::add is constant 1
bool aten_add_alpha_is_one(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// filter to check if the functional in CallFunction is relu
bool is_functional_relu(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// filter to check if the module is torch.nn.ReLU
bool is_relu_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// filter to check if the module is a Linear module (cf. is_relu_module)
bool is_linear_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// TODO: add a macro to declare the filters
// filter to check if the module is a Conv1d module
bool is_conv1d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// filter to check if the module is a Conv2d module
bool is_conv2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// filter to check if the module is a Conv3d module
bool is_conv3d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// filter to check if the module is a ConvTranspose1d module
bool is_conv_transpose1d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// filter to check if the module is a ConvTranspose2d module
bool is_conv_transpose2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// filter to check if the module is a BatchNorm2d module
bool is_batchnorm2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// filter to check if the module is a BatchNorm3d module
bool is_batchnorm3d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

} // namespace jit
} // namespace torch
217