1 #include <torch/csrc/jit/passes/quantization/helper.h>
2
3 #include <torch/csrc/jit/api/function_impl.h>
4 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
5
6 #include <utility>
7
8 namespace torch {
9 namespace jit {
10
11 using graph_rewrite_helper::getFuncName;
12
// (func_name, arg_index) pair describing "argument `arg_index` of the
// function named `func_name`" when pattern-matching call sites below.
struct FuncArg {
  std::string func_name;
  int arg_index;
};

// Argument patterns for aten ops (e.g. aten::conv2d).
using AtenFuncArgs = std::vector<FuncArg>;
// Argument patterns for prim::CallFunction targets (torch.nn.functional ops).
using CallFuncArgs = std::vector<FuncArg>;
20
// Lists of allowed quantizable operators

// prim::CallFunction targets that can be statically quantized.
std::vector<std::string> _static_quantizable_call_funcs = {
    "conv2d",
    "linear",
    "batch_norm",
    "hardswish",
    "elu",
    "celu",
    "layer_norm",
    "group_norm",
    "instance_norm",
    "embedding_bag",
};

// aten ops that can be statically quantized (including in-place variants).
std::vector<std::string> _static_quantizable_aten_funcs = {
    "conv1d",
    "conv2d",
    "conv3d",
    "conv_transpose1d",
    "conv_transpose2d",
    "linear",
    "hardswish",
    "hardswish_",
    "elu",
    "elu_",
    "celu",
    "celu_",
    "batch_norm",
    "layer_norm",
    "group_norm",
    "instance_norm",
    "embedding_bag",
};

// prim::CallFunction targets supported by dynamic quantization.
std::vector<std::string> _dynamic_quantizable_call_funcs = {
    "linear",
};

// aten ops supported by dynamic quantization.
std::vector<std::string> _dynamic_quantizable_aten_funcs = {
    "linear",
};

// Ops quantized with weight-only static quantization (aten form).
std::vector<std::string> _static_weight_only_quant_aten_funcs = {
    "embedding_bag",
};
// Ops quantized with weight-only static quantization (CallFunction form).
std::vector<std::string> _static_weight_only_quant_call_funcs = {
    "embedding_bag",
};
69
// These are the prim::CallFunctions that don't require observation and
// have a single input Tensor
// example: `prim::CallFunction(%dropout, %input_tensor, ...)
// so we propagate observed property from %input_tensor to the
// output of the `prim::CallFunction`
// Also these ops don't do computation on the value of Tensor, the
// operation only depends on the shape of the Tensor
std::vector<std::string> _single_input_general_shape_call_funcs = {
    "_max_pool1d",
    "_max_pool2d",
    "_max_pool3d",
    "dropout",
    "relu",
};

// Similar to prim::CallFunctions, there are aten ops that don't
// require observation and have a single input Tensor
// Also these ops don't do computation on the value of Tensor, the
// operation only depends on the shape of the Tensor
// e.g. `aten::flatten(%input_tensor, ...)`
std::vector<std::string> _single_input_general_shape_aten_funcs = {
    "max_pool1d",
    "max_pool2d",
    "max_pool3d",
    "flatten",
    "max",
    "min",
    "dropout",
    "reshape",
    // Non-inplace resize is deprecated
    "resize_",
    "chunk",
    "view",
    "transpose",
    "contiguous",
    "permute",
    "repeat",
    "repeat_interleave",
    "relu",
    "relu_",
    "squeeze",
    "squeeze_",
    "unsqueeze",
    "unsqueeze_",
    "detach",
    "detach_",
    "stack",
    "__getitem__",
};
119
// These are prim::CallFunctions for ops that don't require observation and
// have a single input Tensor
// Also these ops do computation on the value of Tensor
// TODO: [Need verify] looks like we can quantize simple functionals that just
// call into aten functions
std::vector<std::string> _single_input_general_value_call_funcs = {
    "avg_pool1d",
    "avg_pool2d",
    "avg_pool3d",
    "adaptive_avg_pool1d",
    "adaptive_avg_pool2d",
    "adaptive_avg_pool3d",
    "interpolate",
    "upsample",
    "upsample_bilinear",
    "upsample_nearest",
    "hardtanh",
    "leaky_relu",
};

// These are aten functions for ops that don't require observation and
// have a single input Tensor
// Also these ops do computation on the value of Tensor
// e.g. `aten::avg_pool2d(%input_tensor, ...)`
std::vector<std::string> _single_input_general_value_aten_funcs = {
    "avg_pool1d",
    "avg_pool2d",
    "avg_pool3d",
    "adaptive_avg_pool1d",
    "adaptive_avg_pool2d",
    "adaptive_avg_pool3d",
    "mean",
    "upsample_nearest1d",
    "upsample_nearest2d",
    "upsample_nearest3d",
    "upsample_linear1d",
    "upsample_bilinear2d",
    "upsample_trilinear3d",
    "upsample_bicubic2d",
    "clamp",
    // "clamp_", // Enable when quantized `clamp_` is ready
    "hardtanh",
    "hardtanh_",
    "leaky_relu",
    "leaky_relu_",
};

// Clamp-family ops; used to locate scalar min/max arguments that also
// need (fixed) quantization handling.
std::vector<std::string> _clamp_funcs = {
    "hardtanh",
    "hardtanh_",
    "clamp",
    // "clamp_", // Enable when quantized `clamp_` is ready
};
173
// Fixed quantization parameters.
// Asymmetric params: for ops whose output range is [0, 1] (e.g. sigmoid);
// scale 1/256 with zero_point 0 maps [0, 1) onto quint8.
const float _asym_scale = 1.0f / 256.0f;
const int _asym_zero_point = 0;
// Symmetric params: for ops whose output range is [-1, 1] (e.g. tanh);
// scale 2/256 with zero_point 128 centers that range on quint8.
const float _sym_scale = 2.0f / 256.0f;
const int _sym_zero_point = 128;
// quantization parameters for ops with range 0 to 1
// for example: aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
std::tuple<c10::QScheme, QParamVector> _per_tensor_asym_qparam =
    std::make_tuple(
        c10::kPerTensorAffine,
        QParamVector(
            {std::make_pair(".scale", IValue(_asym_scale)),
             std::make_pair(".zero_point", IValue(_asym_zero_point)),
             std::make_pair(".scalar_type", IValue(c10::kQUInt8))}));

// quantization parameters for ops with range -1 to 1
// for example: aten/src/ATen/native/quantized/cpu/qtanh.cpp
std::tuple<c10::QScheme, QParamVector> _per_tensor_sym_qparam = std::make_tuple(
    c10::kPerTensorAffine,
    QParamVector(
        {std::make_pair(".scale", IValue(_sym_scale)),
         std::make_pair(".zero_point", IValue(_sym_zero_point)),
         std::make_pair(".scalar_type", IValue(c10::kQUInt8))}));

// Map from aten op symbol to the quantization parameters
// for the ops with fixed quantization parameters
std::unordered_map<NodeKind, std::tuple<c10::QScheme, QParamVector>>
    _fixed_qparams_map = {
        {Symbol::aten("hardsigmoid"), _per_tensor_asym_qparam},
        {Symbol::aten("hardsigmoid_"), _per_tensor_asym_qparam},
        {Symbol::aten("sigmoid"), _per_tensor_asym_qparam},
        {Symbol::aten("sigmoid_"), _per_tensor_asym_qparam},
        {Symbol::aten("tanh"), _per_tensor_sym_qparam},
        {Symbol::aten("tanh_"), _per_tensor_sym_qparam},
};
208
// Special checks for ops that do not require observers for all input tensors.
// For each operator in this list observers are inserted for the input based
// on the index specified.
AtenFuncArgs _observe_inputs_aten_func = {};
// For F.batch_norm only the input tensor (arg 1; arg 0 is the function
// object itself) is observed.
CallFuncArgs _observe_inputs_call_func = {{"batch_norm", 1}};

// Aten functions for getting tensor information
std::vector<std::string> _tensor_info_funcs = {"size", "len", "dim", "numel"};

// Aten functions whose output will be quantized or not quantized depending
// on input tensor
std::vector<std::string> _propagate_quant_single_input_ops = {"cat"};

// Rules are slightly different for binary ops like `aten::add`, for these ops,
// if both of the inputs are Tensor, we'll quantize the output only if both of
// the inputs are quantized
// if the second input is a Scalar, we'll only look at the first input to decide
// if we need to quantize the output
std::vector<std::string> _propagate_quant_binary_ops = {
    "add",
    "add_",
    "mul",
    "mul_"};
232
233 // Check if `use` is an aten function of name `func_name` and if value
234 // `v` is the nth argument (if provided) of the function.
matchAtenFuncToUse(const Use & use,const std::string & func_name,std::optional<int> n)235 bool matchAtenFuncToUse(
236 const Use& use,
237 const std::string& func_name,
238 std::optional<int> n) {
239 Node* node = use.user;
240 return node->kind() == Symbol::aten(func_name) &&
241 (!n.has_value() || static_cast<size_t>(n.value()) == use.offset);
242 }
243
matchCallFuncToUse(const Use & use,const std::string & func_name,std::optional<int> n)244 bool matchCallFuncToUse(
245 const Use& use,
246 const std::string& func_name,
247 std::optional<int> n) {
248 Node* node = use.user;
249 return node->kind() == prim::CallFunction &&
250 getFuncName(node->inputs()[0]) == func_name &&
251 (!n.has_value() || static_cast<size_t>(n.value()) == use.offset);
252 }
253
254 // Check any use of `v` matches the aten function call
255 // or CallFunction patterns
matchArgPattern(Value * v,const AtenFuncArgs & aten_func_args,const CallFuncArgs & call_func_args)256 static bool matchArgPattern(
257 Value* v,
258 const AtenFuncArgs& aten_func_args,
259 const CallFuncArgs& call_func_args) {
260 for (const Use& u : v->uses()) {
261 for (const auto& func_arg : aten_func_args) {
262 if (matchAtenFuncToUse(u, func_arg.func_name, func_arg.arg_index)) {
263 return true;
264 }
265 }
266
267 for (const auto& func_arg : call_func_args) {
268 if (matchCallFuncToUse(u, func_arg.func_name, func_arg.arg_index)) {
269 return true;
270 }
271 }
272 }
273 return false;
274 }
275
276 // TODO add other op signatures.
isWeight(Value * v)277 bool isWeight(Value* v) {
278 bool result = matchArgPattern(
279 v,
280 // ate::embedding_bag(%weight, %input, %offsets, %scale_grad_by_freq,
281 // %mode_enum, %sparse, %per_sample_weights, %include_last_offset)
282 AtenFuncArgs(
283 {{"conv1d", 1},
284 {"conv2d", 1},
285 {"conv3d", 1},
286 {"conv_transpose1d", 1},
287 {"conv_transpose2d", 1},
288 {"linear", 1},
289 {"embedding_bag", 0}}),
290 // embedding_bag - prim::CallFunction(%func, %input.1, %weight,
291 // %offsets.1, %max_norm, %norm_type, %scale_grad_by_freq, %mode, %sparse,
292 // %per_sample_weights.1, %include_last_offset)
293 CallFuncArgs({{"linear", 2}, {"embedding_bag", 2}}));
294 return result;
295 }
296
isBiasOfConvOrLinear(Value * v)297 bool isBiasOfConvOrLinear(Value* v) {
298 bool result = matchArgPattern(
299 v,
300 AtenFuncArgs(
301 {{"conv1d", 2},
302 {"conv2d", 2},
303 {"conv3d", 2},
304 {"conv_transpose1d", 2},
305 {"conv_transpose2d", 2},
306 {"linear", 2}}),
307 CallFuncArgs({{"linear", 3}}));
308 return result;
309 }
310
isEmbeddingBagNonInput(Value * v)311 bool isEmbeddingBagNonInput(Value* v) {
312 bool result = matchArgPattern(
313 v,
314 AtenFuncArgs({{"embedding_bag", 2}, {"embedding_bag", 6}}),
315 CallFuncArgs({}));
316 return result;
317 }
318
getClampScalarInputUse(Value * v)319 std::optional<Use> getClampScalarInputUse(Value* v) {
320 for (const auto& use : v->uses()) {
321 for (const auto& aten_func : _clamp_funcs) {
322 if (matchAtenFuncToUse(use, aten_func, 1) ||
323 matchAtenFuncToUse(use, aten_func, 2)) {
324 return use;
325 }
326 }
327 }
328 return std::nullopt;
329 }
330
cloneMethod(Module & module,const std::string & orig_method_name,const std::string & new_method_name)331 void cloneMethod(
332 Module& module,
333 const std::string& orig_method_name,
334 const std::string& new_method_name) {
335 const Function& method = module.get_method(orig_method_name).function();
336 auto graph = toGraphFunction(method).graph()->copy();
337 const auto& schema = method.getSchema();
338 const auto this_method_name =
339 c10::QualifiedName(*module.type()->name(), new_method_name);
340 auto copied = module._ivalue()->compilation_unit()->create_function(
341 this_method_name, std::move(graph));
342 module.type()->addMethod(copied);
343 copied->setSchema(schema);
344 }
345
// For a value `v` produced by a "pass-through" node (shape ops,
// single-input general ops, control flow, container construct/unpack),
// return the input values whose observed/quantized property should
// propagate to `v`. An empty vector means no propagation.
std::vector<Value*> getPassThroughInputs(Value* v) {
  Node* n = v->node();
  if (isSingleInputGeneralCallFunction(n)) {
    // prim::CallFunction: input 0 is the function object, input 1 the Tensor.
    return {n->input(1)};
  } else if (
      isSingleInputGeneralAtenFunction(n) ||
      (n->kind() == Symbol::aten("sort") && v->offset() == 0)) {
    // aten::sort returns (values, indices); only the values output
    // (offset 0) carries quantization state through.
    return {n->input(0)};
  } else if (n->kind() == prim::If && n->outputs().size() == 1) {
    std::vector<Value*> inputs;
    for (Block* subblock : n->blocks()) {
      // A branch that always raises can never produce the output; skip it.
      if (alwaysRaisesException(subblock)) {
        continue;
      }
      auto* output = subblock->outputs()[0];
      inputs.push_back(output);
    }
    return inputs;
  } else if (n->kind() == prim::ListUnpack || n->kind() == prim::TupleUnpack) {
    // only propagate dequantize for Tensor
    if (v->type()->isSubtypeOf(*TensorType::get())) {
      return {n->input(0)};
    } else {
      return {};
    }
  } else if (
      n->kind() == prim::ListConstruct &&
      v->type()->isSubtypeOf(*ListType::ofTensors())) {
    // A list of Tensors propagates through all of its elements.
    std::vector<Value*> inputs;
    for (auto* v : n->inputs()) {
      inputs.push_back(v);
    }
    return inputs;
  } else if (n->kind() == prim::TupleConstruct) {
    // Only Tensor elements of a tuple participate in propagation.
    std::vector<Value*> inputs;
    for (auto* input : n->inputs()) {
      if (input->type()->isSubtypeOf(*TensorType::get())) {
        inputs.push_back(input);
      }
    }
    return inputs;
  } else if (n->kind() == Symbol::aten("append")) {
    std::vector<Value*> inputs;
    for (auto* input : n->inputs()) {
      inputs.push_back(input);
    }
    return inputs;
  }

  return {};
}
397
toAtenSymbol(const std::vector<std::string> & func_names)398 static std::vector<NodeKind> toAtenSymbol(
399 const std::vector<std::string>& func_names) {
400 std::vector<NodeKind> symbols;
401 std::transform(
402 func_names.begin(),
403 func_names.end(),
404 std::back_inserter(symbols),
405 Symbol::aten);
406 return symbols;
407 }
408
isAtenFunc(Node * n,const std::vector<NodeKind> & aten_funcs)409 static bool isAtenFunc(Node* n, const std::vector<NodeKind>& aten_funcs) {
410 return std::find(aten_funcs.begin(), aten_funcs.end(), n->kind()) !=
411 aten_funcs.end();
412 }
413
isAtenFunc(Node * n,const std::vector<std::string> & aten_funcs)414 static bool isAtenFunc(Node* n, const std::vector<std::string>& aten_funcs) {
415 const auto& symbols = toAtenSymbol(aten_funcs);
416 return isAtenFunc(n, symbols);
417 }
418
419 // TODO: factor out isCallFunc
isFunctionNode(Node * n,const std::vector<std::string> & call_funcs,const std::vector<std::string> & aten_funcs)420 static bool isFunctionNode(
421 Node* n,
422 const std::vector<std::string>& call_funcs,
423 const std::vector<std::string>& aten_funcs) {
424 bool is_func_node = isAtenFunc(n, aten_funcs);
425 if (n->kind() == prim::CallFunction) {
426 auto func_name = getFuncName(n->inputs()[0]);
427 is_func_node |=
428 std::find(call_funcs.begin(), call_funcs.end(), func_name) !=
429 call_funcs.end();
430 }
431 return is_func_node;
432 }
433
isSingleInputGeneralShapeAtenFunction(Node * n)434 bool isSingleInputGeneralShapeAtenFunction(Node* n) {
435 return isAtenFunc(n, _single_input_general_shape_aten_funcs);
436 }
437
isSingleInputGeneralValueAtenFunction(Node * n)438 bool isSingleInputGeneralValueAtenFunction(Node* n) {
439 return isAtenFunc(n, _single_input_general_value_aten_funcs) ||
440 isBinaryOpWithScalarInput(n);
441 }
442
isSingleInputGeneralCallFunction(Node * n)443 bool isSingleInputGeneralCallFunction(Node* n) {
444 static std::vector<std::string> single_input_general_call_funcs;
445 std::copy(
446 _single_input_general_shape_call_funcs.begin(),
447 _single_input_general_shape_call_funcs.end(),
448 std::back_inserter(single_input_general_call_funcs));
449 std::copy(
450 _single_input_general_value_call_funcs.begin(),
451 _single_input_general_value_call_funcs.end(),
452 std::back_inserter(single_input_general_call_funcs));
453 return isFunctionNode(
454 n,
455 /* call_funcs = */ single_input_general_call_funcs,
456 /* aten_funcs = */ {});
457 }
458
isSingleInputGeneralAtenFunction(Node * n)459 bool isSingleInputGeneralAtenFunction(Node* n) {
460 static std::vector<NodeKind> fixed_qparams_aten_funcs;
461 std::transform(
462 _fixed_qparams_map.begin(),
463 _fixed_qparams_map.end(),
464 std::back_inserter(fixed_qparams_aten_funcs),
465 [](auto pair) { return pair.first; });
466
467 return isSingleInputGeneralValueAtenFunction(n) ||
468 isSingleInputGeneralShapeAtenFunction(n) ||
469 isAtenFunc(n, fixed_qparams_aten_funcs);
470 }
471
isClamp(Node * n)472 bool isClamp(Node* n) {
473 return isAtenFunc(n, _clamp_funcs);
474 }
475
isTensorInfoNode(Node * n)476 bool isTensorInfoNode(Node* n) {
477 return isAtenFunc(n, _tensor_info_funcs);
478 }
479
isPropagateQuantSingleInputOp(Node * n)480 bool isPropagateQuantSingleInputOp(Node* n) {
481 return isAtenFunc(n, _propagate_quant_single_input_ops);
482 }
483
isPropagateQuantBinaryOp(Node * n)484 bool isPropagateQuantBinaryOp(Node* n) {
485 return isAtenFunc(n, _propagate_quant_binary_ops);
486 }
487
isPropagateQuantOp(Node * n)488 bool isPropagateQuantOp(Node* n) {
489 return isPropagateQuantSingleInputOp(n) || isPropagateQuantBinaryOp(n);
490 }
491
isBinaryOpWithScalarInput(Node * n)492 bool isBinaryOpWithScalarInput(Node* n) {
493 return isPropagateQuantBinaryOp(n) && isScalar(n->input(1));
494 }
495
getFixedQParams(Node * n)496 std::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(Node* n) {
497 static std::vector<NodeKind> fixed_qparam_funcs;
498 std::transform(
499 _fixed_qparams_map.begin(),
500 _fixed_qparams_map.end(),
501 std::back_inserter(fixed_qparam_funcs),
502 [](const auto& pair) { return pair.first; });
503 if (isAtenFunc(n, fixed_qparam_funcs)) {
504 return _fixed_qparams_map.at(n->kind());
505 }
506 return std::nullopt;
507 }
508
userDefinedCallFunction(Node * n)509 bool userDefinedCallFunction(Node* n) {
510 return n->kind() == prim::CallFunction &&
511 !isSingleInputGeneralCallFunction(n) &&
512 !isFunctionNode(n, _static_quantizable_call_funcs, {});
513 }
514
isWeightOnlyStaticQuantOp(Node * n)515 bool isWeightOnlyStaticQuantOp(Node* n) {
516 return isFunctionNode(
517 n,
518 _static_weight_only_quant_call_funcs,
519 _static_weight_only_quant_aten_funcs);
520 }
521
nodeQuantizable(Node * n,QuantType quant_type)522 bool nodeQuantizable(Node* n, QuantType quant_type) {
523 bool is_dynamic = quant_type == QuantType::DYNAMIC;
524 return isFunctionNode(
525 n,
526 /* call_funcs = */
527 is_dynamic ? _dynamic_quantizable_call_funcs
528 : _static_quantizable_call_funcs,
529 /* aten_funcs = */
530 is_dynamic ? _dynamic_quantizable_aten_funcs
531 : _static_quantizable_aten_funcs);
532 }
533
useQuantizable(const Use & use,QuantType quant_type)534 bool useQuantizable(const Use& use, QuantType quant_type) {
535 if (quant_type == QuantType::STATIC) {
536 for (const auto& func_input : _observe_inputs_aten_func) {
537 if (matchAtenFuncToUse(use, func_input.func_name, std::nullopt)) {
538 return use.offset == static_cast<size_t>(func_input.arg_index);
539 }
540 }
541
542 for (const auto& func_input : _observe_inputs_call_func) {
543 if (matchCallFuncToUse(use, func_input.func_name, std::nullopt)) {
544 return use.offset == static_cast<size_t>(func_input.arg_index);
545 }
546 }
547 }
548
549 return nodeQuantizable(use.user, quant_type);
550 }
551
getCallFunctionGraph(Node * n)552 std::shared_ptr<Graph> getCallFunctionGraph(Node* n) {
553 auto* func_node = n->input(0)->node();
554 auto func = func_node->output()->type()->expectRef<FunctionType>().function();
555 auto graphFunc = tryToGraphFunction(*func);
556 TORCH_CHECK(graphFunc, "Quantization only works for graph function");
557 return graphFunc->graph();
558 }
559
560 // Block helper functions
alwaysRaisesException(Block * block)561 bool alwaysRaisesException(Block* block) {
562 for (Node* n : block->nodes()) {
563 if (n->kind() == prim::RaiseException) {
564 return true;
565 }
566 if (n->kind() == prim::If) {
567 bool exception = true;
568 for (Block* b : n->blocks()) {
569 exception &= alwaysRaisesException(b);
570 }
571 if (exception) {
572 return true;
573 }
574 }
575 }
576 return false;
577 }
578
579 // Check if a value in the graph is a Scalar value
isScalar(Value * v)580 bool isScalar(Value* v) {
581 auto iv = toIValue(v);
582 return v->type()->isSubtypeOf(*NumberType::get()) ||
583 (v->type()->isSubtypeOf(*TensorType::get()) && iv && iv->isTensor() &&
584 iv->toTensor().dim() == 0);
585 }
586
587 // =================== Graph/Module analysis helper functions ============
588 // Check if value is the input of the graph
hitGraphInput(Value * value)589 bool hitGraphInput(Value* value) {
590 Graph* graph = value->owningGraph();
591 const auto& inputs = graph->inputs();
592 return std::find(inputs.begin(), inputs.end(), value) != inputs.end();
593 }
594
595 // Get the module access path for a Value representing a module instance
596 // by tracing back the GetAttr nodes and recording all the attribute
597 // names along the way.
598 // Assuming 'self.sub.basic_block.conv1',
599 // Input1: Value instance of conv1
600 // Input2: Value instance of self
601 // Output: ['sub', 'basic_block', 'conv1']
getModuleAccessPath(Value * instance,Value * self)602 std::vector<std::string> getModuleAccessPath(Value* instance, Value* self) {
603 std::vector<std::string> path;
604 // Iterator to traverse back the GetAttr calls
605 Value* iter = instance;
606 // trace back the instance to recover the path of the submodule
607 while (!hitGraphInput(iter) && iter->node()->kind() == prim::GetAttr) {
608 Node* get_attr = iter->node();
609 // record the name of GetAttr
610 path.push_back(get_attr->s(attr::name));
611 // trace back the chain of GetAttr
612 iter = get_attr->inputs()[0];
613 }
614 TORCH_CHECK(
615 iter == self,
616 "Can't handle the access pattern of GetAttr "
617 " in getModuleAccessPath, traced back to:",
618 iter->debugName(),
619 " which is not self:",
620 self->debugName());
621 std::reverse(path.begin(), path.end());
622 return path;
623 }
624
625 // Assuming self.foo.bar.conv1,
626 // Input1: Module instance of self
627 // Input2: ['foo', 'bar', 'conv1']
628 // Output: Module instance of conv1
findChildModule(const Module & module,const std::vector<std::string> & path)629 Module findChildModule(
630 const Module& module,
631 const std::vector<std::string>& path) {
632 Module m = module;
633 for (const auto& p : path) {
634 m = m.attr(p).toModule();
635 }
636 return m;
637 }
638
getInvokedModule(Module & module,Node * n,Value * self)639 Module getInvokedModule(Module& module, Node* n, Value* self) {
640 auto* instance = n->inputs()[0];
641 auto path = getModuleAccessPath(instance, self);
642 return findChildModule(module, path);
643 }
644
getInvokedModuleOpt(const Module & module,Node * n,Value * self)645 std::optional<Module> getInvokedModuleOpt(
646 const Module& module,
647 Node* n,
648 Value* self) {
649 auto* instance = n->inputs()[0];
650 auto path = getModuleAccessPath(instance, self);
651 Module m = module;
652 for (const auto& p : path) {
653 if (m.attr(p).isModule()) {
654 m = m.attr(p).toModule();
655 } else {
656 return std::nullopt;
657 }
658 }
659 return m;
660 }
661
662 // ==================== filter functions for matches ==============
is_int_constant(const Match & match,const std::unordered_map<std::string,Value * > & vmap,const std::string & vname,int value)663 bool is_int_constant(
664 const Match& match,
665 const std::unordered_map<std::string, Value*>& vmap,
666 const std::string& vname,
667 int value) {
668 const auto& match_vmap = match.values_map;
669 auto v = toIValue(match_vmap.at(vmap.at(vname)));
670 return v && v->isInt() && v->toInt() == value;
671 }
672
// True if the matched value named `vname` is a function object whose name
// is `functional` (e.g. torch.nn.functional.relu).
static bool is_functional(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap,
    const std::string& vname,
    const std::string& functional) {
  Value* matched = match.values_map.at(vmap.at(vname));
  if (!matched->type()->cast<FunctionType>()) {
    return false;
  }
  return getFuncName(matched) == functional;
}
682
// Strips every TorchScript mangle suffix (".___torch_mangle_<N>") from a
// qualified name, recovering the original name.
std::string removeTorchMangle(const std::string& orig_name) {
  static const std::regex kMangleRe("\\.___torch_mangle_\\d+");
  return std::regex_replace(orig_name, kMangleRe, "");
}
688
getModuleName(Value * value)689 std::optional<std::string> getModuleName(Value* value) {
690 auto type = value->type()->cast<ClassType>();
691 if (type && type->name()) {
692 return removeTorchMangle(type->name()->qualifiedName());
693 }
694 return std::nullopt;
695 }
696
// True if the matched value named `vname` is an instance of the module
// class `module_qualified_name` (after demangling).
static bool is_module(
    const Match& match,
    const std::unordered_map<std::string, Value*>& vmap,
    const std::string& vname,
    const std::string& module_qualified_name) {
  Value* matched = match.values_map.at(vmap.at(vname));
  const auto module_name = getModuleName(matched);
  return module_name.has_value() &&
      module_name.value() == module_qualified_name;
}
710
aten_add_alpha_is_one(const Match & match,const std::unordered_map<std::string,Value * > & vmap)711 bool aten_add_alpha_is_one(
712 const Match& match,
713 const std::unordered_map<std::string, Value*>& vmap) {
714 return is_int_constant(match, vmap, "alpha", 1);
715 }
716
is_functional_relu(const Match & match,const std::unordered_map<std::string,Value * > & vmap)717 bool is_functional_relu(
718 const Match& match,
719 const std::unordered_map<std::string, Value*>& vmap) {
720 return is_functional(match, vmap, "relu", "relu");
721 }
722
is_relu_module(const Match & match,const std::unordered_map<std::string,Value * > & vmap)723 bool is_relu_module(
724 const Match& match,
725 const std::unordered_map<std::string, Value*>& vmap) {
726 return is_module(
727 match, vmap, "relu", "__torch__.torch.nn.modules.activation.ReLU");
728 }
729
is_linear_module(const Match & match,const std::unordered_map<std::string,Value * > & vmap)730 bool is_linear_module(
731 const Match& match,
732 const std::unordered_map<std::string, Value*>& vmap) {
733 return is_module(
734 match, vmap, "linear", "__torch__.torch.nn.modules.linear.Linear");
735 }
736
is_conv1d_module(const Match & match,const std::unordered_map<std::string,Value * > & vmap)737 bool is_conv1d_module(
738 const Match& match,
739 const std::unordered_map<std::string, Value*>& vmap) {
740 return is_module(
741 match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv1d");
742 }
743
is_conv2d_module(const Match & match,const std::unordered_map<std::string,Value * > & vmap)744 bool is_conv2d_module(
745 const Match& match,
746 const std::unordered_map<std::string, Value*>& vmap) {
747 return is_module(
748 match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv2d");
749 }
750
is_conv3d_module(const Match & match,const std::unordered_map<std::string,Value * > & vmap)751 bool is_conv3d_module(
752 const Match& match,
753 const std::unordered_map<std::string, Value*>& vmap) {
754 return is_module(
755 match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv3d");
756 }
757
is_conv_transpose1d_module(const Match & match,const std::unordered_map<std::string,Value * > & vmap)758 bool is_conv_transpose1d_module(
759 const Match& match,
760 const std::unordered_map<std::string, Value*>& vmap) {
761 return is_module(
762 match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose1d");
763 }
764
is_conv_transpose2d_module(const Match & match,const std::unordered_map<std::string,Value * > & vmap)765 bool is_conv_transpose2d_module(
766 const Match& match,
767 const std::unordered_map<std::string, Value*>& vmap) {
768 return is_module(
769 match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose2d");
770 }
771
is_batchnorm2d_module(const Match & match,const std::unordered_map<std::string,Value * > & vmap)772 bool is_batchnorm2d_module(
773 const Match& match,
774 const std::unordered_map<std::string, Value*>& vmap) {
775 bool regnorm = is_module(
776 match,
777 vmap,
778 "batchnorm",
779 "__torch__.torch.nn.modules.batchnorm.BatchNorm2d");
780 bool naivenorm = is_module(
781 match,
782 vmap,
783 "batchnorm",
784 "__torch__.mobile_cv.arch.layers.batch_norm.NaiveSyncBatchNorm");
785 return (regnorm || naivenorm);
786 }
787
is_batchnorm3d_module(const Match & match,const std::unordered_map<std::string,Value * > & vmap)788 bool is_batchnorm3d_module(
789 const Match& match,
790 const std::unordered_map<std::string, Value*>& vmap) {
791 return is_module(
792 match,
793 vmap,
794 "batchnorm",
795 "__torch__.torch.nn.modules.batchnorm.BatchNorm3d");
796 }
797
798 } // namespace jit
799 } // namespace torch
800