// torch/csrc/jit/passes/quantization/helper.cpp
#include <torch/csrc/jit/passes/quantization/helper.h>

#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>

#include <utility>

namespace torch {
namespace jit {

using graph_rewrite_helper::getFuncName;

struct FuncArg {
  std::string func_name;
  int arg_index;
};

using AtenFuncArgs = std::vector<FuncArg>;
using CallFuncArgs = std::vector<FuncArg>;

// Lists of allowed quantizable operators
std::vector<std::string> _static_quantizable_call_funcs = {
    "conv2d",
    "linear",
    "batch_norm",
    "hardswish",
    "elu",
    "celu",
    "layer_norm",
    "group_norm",
    "instance_norm",
    "embedding_bag",
};

std::vector<std::string> _static_quantizable_aten_funcs = {
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose1d",
    "conv_transpose2d",
    "linear",
    "hardswish",
    "hardswish_",
    "elu",
    "elu_",
    "celu",
    "celu_",
    "batch_norm",
    "layer_norm",
    "group_norm",
    "instance_norm",
    "embedding_bag",
};

std::vector<std::string> _dynamic_quantizable_call_funcs = {
    "linear",
};

std::vector<std::string> _dynamic_quantizable_aten_funcs = {
    "linear",
};

std::vector<std::string> _static_weight_only_quant_aten_funcs = {
    "embedding_bag",
};
std::vector<std::string> _static_weight_only_quant_call_funcs = {
    "embedding_bag",
};

// These are the prim::CallFunctions that don't require observation and
// have a single input Tensor,
// for example: `prim::CallFunction(%dropout, %input_tensor, ...)`,
// so we propagate the observed property from %input_tensor to the
// output of the `prim::CallFunction`.
// Also, these ops don't do computation on the values of the Tensor; the
// operation only depends on the shape of the Tensor.
std::vector<std::string> _single_input_general_shape_call_funcs = {
    "_max_pool1d",
    "_max_pool2d",
    "_max_pool3d",
    "dropout",
    "relu",
};

// Similar to the prim::CallFunctions above, these are aten ops that don't
// require observation and have a single input Tensor.
// Also, these ops don't do computation on the values of the Tensor; the
// operation only depends on the shape of the Tensor,
// e.g. `aten::flatten(%input_tensor, ...)`.
std::vector<std::string> _single_input_general_shape_aten_funcs = {
    "max_pool1d",
    "max_pool2d",
    "max_pool3d",
    "flatten",
    "max",
    "min",
    "dropout",
    "reshape",
    // Non-inplace resize is deprecated
    "resize_",
    "chunk",
    "view",
    "transpose",
    "contiguous",
    "permute",
    "repeat",
    "repeat_interleave",
    "relu",
    "relu_",
    "squeeze",
    "squeeze_",
    "unsqueeze",
    "unsqueeze_",
    "detach",
    "detach_",
    "stack",
    "__getitem__",
};

// These are prim::CallFunctions for ops that don't require observation and
// have a single input Tensor.
// Unlike the list above, these ops do computation on the values of the Tensor.
// TODO: [Need verify] looks like we can quantize simple functionals that just
// call into aten functions
std::vector<std::string> _single_input_general_value_call_funcs = {
    "avg_pool1d",
    "avg_pool2d",
    "avg_pool3d",
    "adaptive_avg_pool1d",
    "adaptive_avg_pool2d",
    "adaptive_avg_pool3d",
    "interpolate",
    "upsample",
    "upsample_bilinear",
    "upsample_nearest",
    "hardtanh",
    "leaky_relu",
};

// These are aten functions for ops that don't require observation and
// have a single input Tensor.
// Unlike the shape-only list above, these ops do computation on the values
// of the Tensor, e.g. `aten::avg_pool2d(%input_tensor, ...)`.
std::vector<std::string> _single_input_general_value_aten_funcs = {
    "avg_pool1d",
    "avg_pool2d",
    "avg_pool3d",
    "adaptive_avg_pool1d",
    "adaptive_avg_pool2d",
    "adaptive_avg_pool3d",
    "mean",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "upsample_nearest3d",
    "upsample_linear1d",
    "upsample_bilinear2d",
    "upsample_trilinear3d",
    "upsample_bicubic2d",
    "clamp",
    // "clamp_",  // Enable when quantized `clamp_` is ready
    "hardtanh",
    "hardtanh_",
    "leaky_relu",
    "leaky_relu_",
};

std::vector<std::string> _clamp_funcs = {
    "hardtanh",
    "hardtanh_",
    "clamp",
    // "clamp_",  // Enable when quantized `clamp_` is ready
};

const float _asym_scale = 1.0f / 256.0f;
const int _asym_zero_point = 0;
const float _sym_scale = 2.0f / 256.0f;
const int _sym_zero_point = 128;
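// A quick sanity check of these constants: with quint8 values q in [0, 255],
// dequantization computes (q - zero_point) * scale. The asymmetric params
// (scale = 1/256, zero_point = 0) therefore cover [0, 255/256] ~ [0, 1), and
// the symmetric params (scale = 2/256, zero_point = 128) cover
// [-1, 127/128] ~ [-1, 1), matching the fixed output ranges of the ops below.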
// quantization parameters for ops with range 0 to 1
// for example: aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
std::tuple<c10::QScheme, QParamVector> _per_tensor_asym_qparam =
    std::make_tuple(
        c10::kPerTensorAffine,
        QParamVector(
            {std::make_pair(".scale", IValue(_asym_scale)),
             std::make_pair(".zero_point", IValue(_asym_zero_point)),
             std::make_pair(".scalar_type", IValue(c10::kQUInt8))}));

// quantization parameters for ops with range -1 to 1
// for example: aten/src/ATen/native/quantized/cpu/qtanh.cpp
std::tuple<c10::QScheme, QParamVector> _per_tensor_sym_qparam = std::make_tuple(
    c10::kPerTensorAffine,
    QParamVector(
        {std::make_pair(".scale", IValue(_sym_scale)),
         std::make_pair(".zero_point", IValue(_sym_zero_point)),
         std::make_pair(".scalar_type", IValue(c10::kQUInt8))}));

// Map from aten op symbol to the quantization parameters
// for the ops with fixed quantization parameters
std::unordered_map<NodeKind, std::tuple<c10::QScheme, QParamVector>>
    _fixed_qparams_map = {
        {Symbol::aten("hardsigmoid"), _per_tensor_asym_qparam},
        {Symbol::aten("hardsigmoid_"), _per_tensor_asym_qparam},
        {Symbol::aten("sigmoid"), _per_tensor_asym_qparam},
        {Symbol::aten("sigmoid_"), _per_tensor_asym_qparam},
        {Symbol::aten("tanh"), _per_tensor_sym_qparam},
        {Symbol::aten("tanh_"), _per_tensor_sym_qparam},
};

// Special checks for ops that do not require observers for all input tensors.
// For each operator in these lists, an observer is inserted only for the
// input at the specified index.
AtenFuncArgs _observe_inputs_aten_func = {};
CallFuncArgs _observe_inputs_call_func = {{"batch_norm", 1}};
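// For the functional batch_norm entry above: in a call like
// `prim::CallFunction(%batch_norm_fn, %input, %running_mean, %running_var, ...)`
// (names illustrative), input 0 is the function value itself, so index 1 is
// the activation tensor. Only that activation is observed; the running stats,
// weight and bias are not.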

// Aten functions for getting tensor information
std::vector<std::string> _tensor_info_funcs = {"size", "len", "dim", "numel"};

// Aten functions whose output will be quantized or not quantized depending
// on whether the input tensor is quantized
std::vector<std::string> _propagate_quant_single_input_ops = {"cat"};

// Rules are slightly different for binary ops like `aten::add`: for these ops,
// if both of the inputs are Tensors, we'll quantize the output only if both of
// the inputs are quantized;
// if the second input is a Scalar, we'll only look at the first input to decide
// if we need to quantize the output.
std::vector<std::string> _propagate_quant_binary_ops = {
    "add",
    "add_",
    "mul",
    "mul_"};
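// A rough sketch of the decision for `aten::add` (not executable IR):
//   %c = aten::add(%a, %b, %alpha)   // Tensor-Tensor: %c is quantized only if
//                                    // both %a and %b are quantized
//   %d = aten::add(%a, %s, %alpha)   // Tensor-Scalar: %d follows whatever was
//                                    // decided for %a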

// Check if `use` is a call to an aten function named `func_name` and, if `n`
// is provided, whether the used value is passed as the nth argument of that
// call.
bool matchAtenFuncToUse(
    const Use& use,
    const std::string& func_name,
    std::optional<int> n) {
  Node* node = use.user;
  return node->kind() == Symbol::aten(func_name) &&
      (!n.has_value() || static_cast<size_t>(n.value()) == use.offset);
}

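// Same check for `prim::CallFunction` uses. Note that input 0 of a
// prim::CallFunction node is the function value itself, so argument indices
// here are shifted by one relative to the Python call (see e.g. the
// CallFuncArgs entries for "linear" below, where weight = 2 and bias = 3).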
bool matchCallFuncToUse(
    const Use& use,
    const std::string& func_name,
    std::optional<int> n) {
  Node* node = use.user;
  return node->kind() == prim::CallFunction &&
      getFuncName(node->inputs()[0]) == func_name &&
      (!n.has_value() || static_cast<size_t>(n.value()) == use.offset);
}

// Check whether any use of `v` matches one of the given aten function call
// or prim::CallFunction patterns.
static bool matchArgPattern(
    Value* v,
    const AtenFuncArgs& aten_func_args,
    const CallFuncArgs& call_func_args) {
  for (const Use& u : v->uses()) {
    for (const auto& func_arg : aten_func_args) {
      if (matchAtenFuncToUse(u, func_arg.func_name, func_arg.arg_index)) {
        return true;
      }
    }

    for (const auto& func_arg : call_func_args) {
      if (matchCallFuncToUse(u, func_arg.func_name, func_arg.arg_index)) {
        return true;
      }
    }
  }
  return false;
}

// TODO add other op signatures.
bool isWeight(Value* v) {
  bool result = matchArgPattern(
      v,
      // aten::embedding_bag(%weight, %input, %offsets, %scale_grad_by_freq,
      // %mode_enum, %sparse, %per_sample_weights, %include_last_offset)
      AtenFuncArgs(
          {{"conv1d", 1},
           {"conv2d", 1},
           {"conv3d", 1},
           {"conv_transpose1d", 1},
           {"conv_transpose2d", 1},
           {"linear", 1},
           {"embedding_bag", 0}}),
      // embedding_bag - prim::CallFunction(%func, %input.1, %weight,
      // %offsets.1, %max_norm, %norm_type, %scale_grad_by_freq, %mode, %sparse,
      // %per_sample_weights.1, %include_last_offset)
      CallFuncArgs({{"linear", 2}, {"embedding_bag", 2}}));
  return result;
}

bool isBiasOfConvOrLinear(Value* v) {
  bool result = matchArgPattern(
      v,
      AtenFuncArgs(
          {{"conv1d", 2},
           {"conv2d", 2},
           {"conv3d", 2},
           {"conv_transpose1d", 2},
           {"conv_transpose2d", 2},
           {"linear", 2}}),
      CallFuncArgs({{"linear", 3}}));
  return result;
}

bool isEmbeddingBagNonInput(Value* v) {
  bool result = matchArgPattern(
      v,
      AtenFuncArgs({{"embedding_bag", 2}, {"embedding_bag", 6}}),
      CallFuncArgs({}));
  return result;
}

std::optional<Use> getClampScalarInputUse(Value* v) {
  for (const auto& use : v->uses()) {
    for (const auto& aten_func : _clamp_funcs) {
      if (matchAtenFuncToUse(use, aten_func, 1) ||
          matchAtenFuncToUse(use, aten_func, 2)) {
        return use;
      }
    }
  }
  return std::nullopt;
}

void cloneMethod(
    Module& module,
    const std::string& orig_method_name,
    const std::string& new_method_name) {
  const Function& method = module.get_method(orig_method_name).function();
  auto graph = toGraphFunction(method).graph()->copy();
  const auto& schema = method.getSchema();
  const auto this_method_name =
      c10::QualifiedName(*module.type()->name(), new_method_name);
  auto copied = module._ivalue()->compilation_unit()->create_function(
      this_method_name, std::move(graph));
  module.type()->addMethod(copied);
  copied->setSchema(schema);
}

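// For a value `v`, returns the values whose observed/quantized state should be
// propagated to `v` through the node that produces it (single-input general
// ops, prim::If outputs, list/tuple construct and unpack, etc.).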
std::vector<Value*> getPassThroughInputs(Value* v) {
  Node* n = v->node();
  if (isSingleInputGeneralCallFunction(n)) {
    return {n->input(1)};
  } else if (
      isSingleInputGeneralAtenFunction(n) ||
      (n->kind() == Symbol::aten("sort") && v->offset() == 0)) {
    return {n->input(0)};
  } else if (n->kind() == prim::If && n->outputs().size() == 1) {
    std::vector<Value*> inputs;
    for (Block* subblock : n->blocks()) {
      if (alwaysRaisesException(subblock)) {
        continue;
      }
      auto* output = subblock->outputs()[0];
      inputs.push_back(output);
    }
    return inputs;
  } else if (n->kind() == prim::ListUnpack || n->kind() == prim::TupleUnpack) {
    // only propagate dequantize for Tensor
    if (v->type()->isSubtypeOf(*TensorType::get())) {
      return {n->input(0)};
    } else {
      return {};
    }
  } else if (
      n->kind() == prim::ListConstruct &&
      v->type()->isSubtypeOf(*ListType::ofTensors())) {
    std::vector<Value*> inputs;
    for (auto* v : n->inputs()) {
      inputs.push_back(v);
    }
    return inputs;
  } else if (n->kind() == prim::TupleConstruct) {
    std::vector<Value*> inputs;
    for (auto* input : n->inputs()) {
      if (input->type()->isSubtypeOf(*TensorType::get())) {
        inputs.push_back(input);
      }
    }
    return inputs;
  } else if (n->kind() == Symbol::aten("append")) {
    std::vector<Value*> inputs;
    for (auto* input : n->inputs()) {
      inputs.push_back(input);
    }
    return inputs;
  }

  return {};
}

static std::vector<NodeKind> toAtenSymbol(
    const std::vector<std::string>& func_names) {
  std::vector<NodeKind> symbols;
  std::transform(
      func_names.begin(),
      func_names.end(),
      std::back_inserter(symbols),
      Symbol::aten);
  return symbols;
}

static bool isAtenFunc(Node* n, const std::vector<NodeKind>& aten_funcs) {
  return std::find(aten_funcs.begin(), aten_funcs.end(), n->kind()) !=
      aten_funcs.end();
}

static bool isAtenFunc(Node* n, const std::vector<std::string>& aten_funcs) {
  const auto& symbols = toAtenSymbol(aten_funcs);
  return isAtenFunc(n, symbols);
}

// TODO: factor out isCallFunc
static bool isFunctionNode(
    Node* n,
    const std::vector<std::string>& call_funcs,
    const std::vector<std::string>& aten_funcs) {
  bool is_func_node = isAtenFunc(n, aten_funcs);
  if (n->kind() == prim::CallFunction) {
    auto func_name = getFuncName(n->inputs()[0]);
    is_func_node |=
        std::find(call_funcs.begin(), call_funcs.end(), func_name) !=
        call_funcs.end();
  }
  return is_func_node;
}

bool isSingleInputGeneralShapeAtenFunction(Node* n) {
  return isAtenFunc(n, _single_input_general_shape_aten_funcs);
}

bool isSingleInputGeneralValueAtenFunction(Node* n) {
  return isAtenFunc(n, _single_input_general_value_aten_funcs) ||
      isBinaryOpWithScalarInput(n);
}

bool isSingleInputGeneralCallFunction(Node* n) {
  static std::vector<std::string> single_input_general_call_funcs;
  std::copy(
      _single_input_general_shape_call_funcs.begin(),
      _single_input_general_shape_call_funcs.end(),
      std::back_inserter(single_input_general_call_funcs));
  std::copy(
      _single_input_general_value_call_funcs.begin(),
      _single_input_general_value_call_funcs.end(),
      std::back_inserter(single_input_general_call_funcs));
  return isFunctionNode(
      n,
      /* call_funcs = */ single_input_general_call_funcs,
      /* aten_funcs = */ {});
}

bool isSingleInputGeneralAtenFunction(Node* n) {
  static std::vector<NodeKind> fixed_qparams_aten_funcs;
  std::transform(
      _fixed_qparams_map.begin(),
      _fixed_qparams_map.end(),
      std::back_inserter(fixed_qparams_aten_funcs),
      [](auto pair) { return pair.first; });

  return isSingleInputGeneralValueAtenFunction(n) ||
      isSingleInputGeneralShapeAtenFunction(n) ||
      isAtenFunc(n, fixed_qparams_aten_funcs);
}

bool isClamp(Node* n) {
  return isAtenFunc(n, _clamp_funcs);
}

bool isTensorInfoNode(Node* n) {
  return isAtenFunc(n, _tensor_info_funcs);
}

bool isPropagateQuantSingleInputOp(Node* n) {
  return isAtenFunc(n, _propagate_quant_single_input_ops);
}

bool isPropagateQuantBinaryOp(Node* n) {
  return isAtenFunc(n, _propagate_quant_binary_ops);
}

bool isPropagateQuantOp(Node* n) {
  return isPropagateQuantSingleInputOp(n) || isPropagateQuantBinaryOp(n);
}

bool isBinaryOpWithScalarInput(Node* n) {
  return isPropagateQuantBinaryOp(n) && isScalar(n->input(1));
}

std::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(Node* n) {
  static std::vector<NodeKind> fixed_qparam_funcs;
  std::transform(
      _fixed_qparams_map.begin(),
      _fixed_qparams_map.end(),
      std::back_inserter(fixed_qparam_funcs),
      [](const auto& pair) { return pair.first; });
  if (isAtenFunc(n, fixed_qparam_funcs)) {
    return _fixed_qparams_map.at(n->kind());
  }
  return std::nullopt;
}

bool userDefinedCallFunction(Node* n) {
  return n->kind() == prim::CallFunction &&
      !isSingleInputGeneralCallFunction(n) &&
      !isFunctionNode(n, _static_quantizable_call_funcs, {});
}

bool isWeightOnlyStaticQuantOp(Node* n) {
  return isFunctionNode(
      n,
      _static_weight_only_quant_call_funcs,
      _static_weight_only_quant_aten_funcs);
}

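// Returns true if `n` is an op we want to quantize under the given
// quantization type, i.e. it appears in the static or dynamic quantizable
// op lists defined at the top of this file.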
bool nodeQuantizable(Node* n, QuantType quant_type) {
  bool is_dynamic = quant_type == QuantType::DYNAMIC;
  return isFunctionNode(
      n,
      /* call_funcs = */
      is_dynamic ? _dynamic_quantizable_call_funcs
                 : _static_quantizable_call_funcs,
      /* aten_funcs = */
      is_dynamic ? _dynamic_quantizable_aten_funcs
                 : _static_quantizable_aten_funcs);
}

bool useQuantizable(const Use& use, QuantType quant_type) {
  if (quant_type == QuantType::STATIC) {
    for (const auto& func_input : _observe_inputs_aten_func) {
      if (matchAtenFuncToUse(use, func_input.func_name, std::nullopt)) {
        return use.offset == static_cast<size_t>(func_input.arg_index);
      }
    }

    for (const auto& func_input : _observe_inputs_call_func) {
      if (matchCallFuncToUse(use, func_input.func_name, std::nullopt)) {
        return use.offset == static_cast<size_t>(func_input.arg_index);
      }
    }
  }

  return nodeQuantizable(use.user, quant_type);
}

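// Returns the graph of the function invoked by a prim::CallFunction node
// (the function value is input 0 of the node). Quantization only supports
// graph functions here, not builtins.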
std::shared_ptr<Graph> getCallFunctionGraph(Node* n) {
  auto* func_node = n->input(0)->node();
  auto func = func_node->output()->type()->expectRef<FunctionType>().function();
  auto graphFunc = tryToGraphFunction(*func);
  TORCH_CHECK(graphFunc, "Quantization only works for graph function");
  return graphFunc->graph();
}

// Block helper functions
bool alwaysRaisesException(Block* block) {
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::RaiseException) {
      return true;
    }
    if (n->kind() == prim::If) {
      bool exception = true;
      for (Block* b : n->blocks()) {
        exception &= alwaysRaisesException(b);
      }
      if (exception) {
        return true;
      }
    }
  }
  return false;
}

// Check if a value in the graph is a Scalar value
bool isScalar(Value* v) {
  auto iv = toIValue(v);
  return v->type()->isSubtypeOf(*NumberType::get()) ||
      (v->type()->isSubtypeOf(*TensorType::get()) && iv && iv->isTensor() &&
       iv->toTensor().dim() == 0);
}

// =================== Graph/Module analysis helper functions ============
// Check if a value is an input of the graph
bool hitGraphInput(Value* value) {
  Graph* graph = value->owningGraph();
  const auto& inputs = graph->inputs();
  return std::find(inputs.begin(), inputs.end(), value) != inputs.end();
}

// Get the module access path for a Value representing a module instance
// by tracing back the GetAttr nodes and recording all the attribute
// names along the way.
// Assuming 'self.sub.basic_block.conv1',
// Input1: Value instance of conv1
// Input2: Value instance of self
// Output: ['sub', 'basic_block', 'conv1']
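// In the IR the traced-back chain looks roughly like this (sketch, using the
// example above):
//   %sub = prim::GetAttr[name="sub"](%self)
//   %basic_block = prim::GetAttr[name="basic_block"](%sub)
//   %conv1 = prim::GetAttr[name="conv1"](%basic_block)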
std::vector<std::string> getModuleAccessPath(Value* instance, Value* self) {
  std::vector<std::string> path;
  // Iterator to traverse back the GetAttr calls
  Value* iter = instance;
  // trace back the instance to recover the path of the submodule
  while (!hitGraphInput(iter) && iter->node()->kind() == prim::GetAttr) {
    Node* get_attr = iter->node();
    // record the name of GetAttr
    path.push_back(get_attr->s(attr::name));
    // trace back the chain of GetAttr
    iter = get_attr->inputs()[0];
  }
  TORCH_CHECK(
      iter == self,
      "Can't handle the access pattern of GetAttr "
      " in getModuleAccessPath, traced back to:",
      iter->debugName(),
      " which is not self:",
      self->debugName());
  std::reverse(path.begin(), path.end());
  return path;
}

// Assuming self.foo.bar.conv1,
// Input1: Module instance of self
// Input2: ['foo', 'bar', 'conv1']
// Output: Module instance of conv1
Module findChildModule(
    const Module& module,
    const std::vector<std::string>& path) {
  Module m = module;
  for (const auto& p : path) {
    m = m.attr(p).toModule();
  }
  return m;
}

Module getInvokedModule(Module& module, Node* n, Value* self) {
  auto* instance = n->inputs()[0];
  auto path = getModuleAccessPath(instance, self);
  return findChildModule(module, path);
}

std::optional<Module> getInvokedModuleOpt(
    const Module& module,
    Node* n,
    Value* self) {
  auto* instance = n->inputs()[0];
  auto path = getModuleAccessPath(instance, self);
  Module m = module;
  for (const auto& p : path) {
    if (m.attr(p).isModule()) {
      m = m.attr(p).toModule();
    } else {
      return std::nullopt;
    }
  }
  return m;
}


// ==================== filter functions for matches ==============
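// These helpers follow the match-filter signature used when rewriting
// quantization patterns: they take a Match plus the pattern's value map and
// return whether the matched subgraph should be accepted (e.g. checking that
// a matched value is a particular constant, functional, or module type).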
bool is_int_constant(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap,
    const std::string& vname,
    int value) {
  const auto& match_vmap = match.values_map;
  auto v = toIValue(match_vmap.at(vmap.at(vname)));
  return v && v->isInt() && v->toInt() == value;
}

static bool is_functional(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap,
    const std::string& vname,
    const std::string& functional) {
  const auto& match_vmap = match.values_map;
  Value* v = match_vmap.at(vmap.at(vname));
  return v->type()->cast<FunctionType>() && getFuncName(v) == functional;
}

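// Strips the ".___torch_mangle_<N>" suffix that TorchScript inserts into
// qualified names, e.g. "__torch__.___torch_mangle_3.Linear" becomes
// "__torch__.Linear".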
std::string removeTorchMangle(const std::string& orig_name) {
  static std::regex mangle_re("\\.___torch_mangle_\\d+");
  auto qualified_name = std::regex_replace(orig_name, mangle_re, "");
  return qualified_name;
}

std::optional<std::string> getModuleName(Value* value) {
  auto type = value->type()->cast<ClassType>();
  if (type && type->name()) {
    return removeTorchMangle(type->name()->qualifiedName());
  }
  return std::nullopt;
}

static bool is_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap,
    const std::string& vname,
    const std::string& module_qualified_name) {
  const auto& match_vmap = match.values_map;
  Value* v = match_vmap.at(vmap.at(vname));
  auto module_name = getModuleName(v);
  if (module_name.has_value()) {
    return module_name.value() == module_qualified_name;
  }
  return false;
}

bool aten_add_alpha_is_one(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_int_constant(match, vmap, "alpha", 1);
}

bool is_functional_relu(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_functional(match, vmap, "relu", "relu");
}

bool is_relu_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "relu", "__torch__.torch.nn.modules.activation.ReLU");
}

bool is_linear_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "linear", "__torch__.torch.nn.modules.linear.Linear");
}

bool is_conv1d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv1d");
}

bool is_conv2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv2d");
}

bool is_conv3d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv3d");
}

bool is_conv_transpose1d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose1d");
}

bool is_conv_transpose2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose2d");
}

bool is_batchnorm2d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  bool regnorm = is_module(
      match,
      vmap,
      "batchnorm",
      "__torch__.torch.nn.modules.batchnorm.BatchNorm2d");
  bool naivenorm = is_module(
      match,
      vmap,
      "batchnorm",
      "__torch__.mobile_cv.arch.layers.batch_norm.NaiveSyncBatchNorm");
  return (regnorm || naivenorm);
}

bool is_batchnorm3d_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap) {
  return is_module(
      match,
      vmap,
      "batchnorm",
      "__torch__.torch.nn.modules.batchnorm.BatchNorm3d");
}

} // namespace jit
} // namespace torch