#pragma once

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/quantization/quantization_type.h>

#include <functional>
#include <regex>

namespace torch {
namespace jit {

using graph_rewrite_helper::getFuncName;

// Vector of a module and the name of its method
using ModuleMethodVector = std::vector<std::pair<Module, std::string>>;
// Map of quantization parameter name and value,
// for example _scale, _zero_point,
// _scalar_type and _axis (for per-channel quantization)
using QParamVector = std::vector<std::pair<std::string, IValue>>;

// =========== helper functions for Value =========
// Check if a value is a weight, since we need to use a weight observer
// for weights
TORCH_API bool isWeight(Value* v);

// Check if a value is the bias of conv or linear, which we do not
// quantize
TORCH_API bool isBiasOfConvOrLinear(Value* v);

TORCH_API bool isEmbeddingBagNonInput(Value* v);

// Get the use of the given value as the scalar input of a clamp op
std::optional<Use> getClampScalarInputUse(Value* v);

// For a given value `v`, get the list of values whose observed/quantized
// status we need to check; if they are observed/quantized, we can say that
// `v` is also observed/quantized, since we can derive the quantization
// parameters for `v` from that list of values.
TORCH_API std::vector<Value*> getPassThroughInputs(Value* v);

// Clones the method named `orig_method_name` into a new method named
// `new_method_name`
TORCH_API void cloneMethod(
    Module& module,
    const std::string& orig_method_name,
    const std::string& new_method_name);

// Check if a value in the graph is a Scalar value
TORCH_API bool isScalar(Value* v);

// Check if a value is an input of the graph
TORCH_API bool hitGraphInput(Value* value);

// Converts a mangled name, such as
//   __torch__.torch.ao.nn.quantized.modules.conv.___torch_mangle_7.Conv2d
// into an unmangled name, such as
//   __torch__.torch.ao.nn.quantized.modules.conv.Conv2d
TORCH_API std::string removeTorchMangle(const std::string& orig_name);

// Return the module name that corresponds to the value.
TORCH_API std::optional<std::string> getModuleName(Value* value);

// =========== helper functions for Node =========
TORCH_API bool isSingleInputGeneralShapeAtenFunction(Node* n);

TORCH_API bool isSingleInputGeneralValueAtenFunction(Node* n);

TORCH_API bool isSingleInputGeneralCallFunction(Node* n);

TORCH_API bool isSingleInputGeneralAtenFunction(Node* n);

TORCH_API bool isClamp(Node* n);

// Check if the node will produce the same result regardless of whether
// the input tensor is quantized or not, e.g. aten::size
TORCH_API bool isTensorInfoNode(Node* n);

// Check if this is a quantization-propagating op with a single input,
// e.g. aten::cat
TORCH_API bool isPropagateQuantSingleInputOp(Node* n);
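// Illustrative sketch, not part of this header's API: one way a pass could
// use getPassThroughInputs to decide observedness transitively. The function
// name and the caller-supplied predicate `isObservedDirectly` are
// hypothetical; a real pass would also need <algorithm> for std::all_of and
// some form of visited-set tracking.
//
//   bool isObservedTransitively(
//       Value* v,
//       const std::function<bool(Value*)>& isObservedDirectly) {
//     if (isObservedDirectly(v)) {
//       return true;
//     }
//     std::vector<Value*> inputs = getPassThroughInputs(v);
//     if (inputs.empty()) {
//       // No pass-through inputs to derive quantization parameters from.
//       return false;
//     }
//     return std::all_of(inputs.begin(), inputs.end(), [&](Value* in) {
//       return isObservedTransitively(in, isObservedDirectly);
//     });
//   }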
// Check if this is a quantization-propagating op with two inputs,
// e.g. aten::add
TORCH_API bool isPropagateQuantBinaryOp(Node* n);

// Check if this is a node that we'll quantize or not quantize depending on
// whether the input of the node is quantized, e.g. aten::cat
TORCH_API bool isPropagateQuantOp(Node* n);

// Check if the node is a binary op like aten::add or aten::mul; if
// input 1 is a Scalar, these ops will be quantized to
// quantized::{op}_scalar
TORCH_API bool isBinaryOpWithScalarInput(Node* n);

TORCH_API std::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(
    Node* n);

// We don't want to analyze the graph for some `builtin` CallFunctions
// like `linear` because we want to preserve the op boundary
TORCH_API bool userDefinedCallFunction(Node* n);

// Check if the node has a scalar input
TORCH_API bool hasScalarInput(Node* n);

// Check if a node is quantizable
TORCH_API bool nodeQuantizable(
    Node* n,
    QuantType quant_type = QuantType::STATIC);

// Nodes which only require quantization of the weight value,
// e.g. embedding_bag
bool isWeightOnlyStaticQuantOp(Node* n);

// Check if a use of the value is quantizable; this depends on
// both the use node and the offset
TORCH_API bool useQuantizable(const Use& use, QuantType quant_type);

// Given a CallFunction node, extract the graph of the called function
TORCH_API std::shared_ptr<Graph> getCallFunctionGraph(Node* n);

// Check if `use` is a CallFunction of name `func_name` and if the used value
// is the nth argument of the function (when `nth_arg` is provided)
bool matchCallFuncToUse(
    const Use& use,
    const std::string& func_name,
    std::optional<int> nth_arg);

// Check if `use` is an aten function of name `func_name` and if the used
// value is the nth argument of the function (when `nth_arg` is provided)
bool matchAtenFuncToUse(
    const Use& use,
    const std::string& func_name,
    std::optional<int> nth_arg);

// =========== helper functions for Block =========
// Checks if a block will always raise an Exception
TORCH_API bool alwaysRaisesException(Block* block);

// =========== helper functions for Module ==========
// TODO: remove
TORCH_API std::vector<std::string> getModuleAccessPath(
    Value* instance,
    Value* self);
// TODO: remove
TORCH_API Module
findChildModule(const Module& module, const std::vector<std::string>& path);

// Given a CallMethod node, get the module instance corresponding
// to the instance Value
// TODO: refactor all current uses of this function to the Opt one
TORCH_API Module getInvokedModule(Module& module, Node* n, Value* self);

// Given a CallMethod node, get the module instance corresponding
// to the instance Value if the instance is a module; otherwise return
// std::nullopt
std::optional<Module> getInvokedModuleOpt(
    const Module& module,
    Node* n,
    Value* self);

// ==================== filter functions for matches ==============
// Filter to check that Value `vname` is a constant with int value `value`
bool is_int_constant(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap,
    const std::string& vname,
    int value);

// Filter to check that the %alpha argument of aten::add is the constant 1
bool aten_add_alpha_is_one(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);
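// Illustrative sketch, assuming the SubgraphRewriter API declared in
// torch/csrc/jit/passes/subgraph_rewrite.h: filters with this signature can
// be passed to runOnGraph so that a registered rewrite only fires on matches
// that satisfy every filter. `pattern`, `replacement` and `graph` are
// placeholders supplied by the caller.
//
//   SubgraphRewriter rewriter;
//   rewriter.RegisterRewritePattern(pattern, replacement);
//   rewriter.runOnGraph(graph, {aten_add_alpha_is_one});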
// Filter to check if the functional in CallFunction is relu
bool is_functional_relu(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// Filter to check if the module is torch.nn.ReLU
bool is_relu_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

bool is_linear_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

// TODO: add a macro to declare the filters
bool is_conv1d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

bool is_conv2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

bool is_conv3d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

bool is_conv_transpose1d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

bool is_conv_transpose2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

bool is_batchnorm2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

bool is_batchnorm3d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap);

} // namespace jit
} // namespace torch
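// Illustrative sketch of how a module filter above might be implemented
// (simplified; `is_module` is assumed to be an internal helper in the
// corresponding .cpp that checks the qualified type name of the module
// value bound by the pattern):
//
//   bool is_relu_module(
//       const Match& match,
//       const std::unordered_map<std::string, Value*>& vmap) {
//     // "relu" is the name the pattern assigns to the matched module value.
//     return is_module(
//         match, vmap, "relu", "__torch__.torch.nn.modules.activation.ReLU");
//   }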