# mypy: allow-untyped-defs
import copy
import operator
from collections import namedtuple
from typing import Callable, Dict, List, Union

import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.qat as nnqat
import torch.ao.nn.quantized.reference as nnqr
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.fuser_method_mappings import (
    _sequential_wrapper2,
    fuse_conv_bn,
    fuse_conv_bn_relu,
    fuse_convtranspose_bn,
    fuse_linear_bn,
)

from .backend_config import (
    BackendPatternConfig,
    DTypeConfig,
    DTypeWithConstraints,
    ObservationType,
)


__all__: List[str] = []

# TODO: rename to be more explicit, e.g. qat_conv_relu
_ConvMetadata = namedtuple(
    "_ConvMetadata",
    [
        "root",
        "transpose",
        "bn",
        "reference",
        "transpose_reference",
        "fused_conv_relu",
        "fused_conv_bn",
        "fused_conv_bn_relu",
        "qat",
        "relu_qat",
        "bn_qat",
        "bn_relu_qat",
        "func",
        "func_transpose",
    ],
)
_Conv1dMetadata = _ConvMetadata(
    nn.Conv1d,
    nn.ConvTranspose1d,
    nn.BatchNorm1d,
    nnqr.Conv1d,
    nnqr.ConvTranspose1d,
    nni.ConvReLU1d,
    nni.ConvBn1d,
    nni.ConvBnReLU1d,
    nnqat.Conv1d,
    nniqat.ConvReLU1d,
    nniqat.ConvBn1d,
    nniqat.ConvBnReLU1d,
    F.conv1d,
    F.conv_transpose1d,
)
_Conv2dMetadata = _ConvMetadata(
    nn.Conv2d,
    nn.ConvTranspose2d,
    nn.BatchNorm2d,
    nnqr.Conv2d,
    nnqr.ConvTranspose2d,
    nni.ConvReLU2d,
    nni.ConvBn2d,
    nni.ConvBnReLU2d,
    nnqat.Conv2d,
    nniqat.ConvReLU2d,
    nniqat.ConvBn2d,
    nniqat.ConvBnReLU2d,
    F.conv2d,
    F.conv_transpose2d,
)
_Conv3dMetadata = _ConvMetadata(
    nn.Conv3d,
    nn.ConvTranspose3d,
    nn.BatchNorm3d,
    nnqr.Conv3d,
    nnqr.ConvTranspose3d,
    nni.ConvReLU3d,
    nni.ConvBn3d,
    nni.ConvBnReLU3d,
    nnqat.Conv3d,
    nniqat.ConvReLU3d,
    nniqat.ConvBn3d,
    nniqat.ConvBnReLU3d,
    F.conv3d,
    F.conv_transpose3d,
)

# Add constraints for fixed qparams ops like sigmoid and tanh to ensure values
# fall within the proper ranges, e.g. [0, 1] for sigmoid, [-1, 1] for tanh
_FIXED_QPARAM_OP_0TO1_CONSTRAINTS = DTypeWithConstraints(
    dtype=torch.quint8,
    quant_min_lower_bound=0,
    quant_max_upper_bound=255,
    scale_exact_match=1.0 / 256.0,
    zero_point_exact_match=0,
)
_FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS = DTypeWithConstraints(
    dtype=torch.quint8,
    quant_min_lower_bound=0,
    quant_max_upper_bound=255,
    scale_exact_match=2.0 / 256.0,
    zero_point_exact_match=128,
)
_FIXED_QPARAMS_OP_TO_CONSTRAINTS: Dict[Union[Callable, str], DTypeWithConstraints] = {
    torch.nn.Hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    torch.nn.functional.hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    "hardsigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    "hardsigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    torch.nn.Sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    torch.sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    "sigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    "sigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    torch.nn.Softmax: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS,
    torch.nn.Tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
    torch.tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
    "tanh": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
    "tanh_": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS,
}
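# Illustrative arithmetic for the constraints above (explanatory comment only, not used
# at runtime): with scale = 1/256 and zero_point = 0, a quint8 value q dequantizes to
# q / 256, so the representable range is [0, 255/256], which covers sigmoid/softmax
# outputs in [0, 1]; with scale = 2/256 and zero_point = 128, q dequantizes to
# (q - 128) * 2 / 256, giving [-1, 127/128], which covers tanh outputs in [-1, 1].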


def _get_binary_op_configs(
    dtype_configs: List[DTypeConfig],
) -> List[BackendPatternConfig]:
    binary_op_configs: List[BackendPatternConfig] = []
    num_tensor_args_to_observation_type_mapping = {
        # TODO: this is not used right now since we have an extra check in prepare;
        # we will need to change this to NO_OBSERVER later once Tensor dtype
        # inference is implemented properly
        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
    }
    for op_with_quantized_bop_scalar_variant in [
        operator.add,
        torch.add,
        operator.mul,
        torch.mul,
    ]:
        bop_patterns = [
            (op_with_quantized_bop_scalar_variant, nn.ReLU),
            (op_with_quantized_bop_scalar_variant, F.relu),
            (op_with_quantized_bop_scalar_variant, torch.relu),
            op_with_quantized_bop_scalar_variant,
        ]
        for bop_pattern in bop_patterns:
            binary_op_configs.append(
                BackendPatternConfig(bop_pattern)
                .set_dtype_configs(dtype_configs)  # noqa: E131
                ._set_num_tensor_args_to_observation_type(
                    num_tensor_args_to_observation_type_mapping
                )
            )
    # matmul
    binary_op_configs.append(
        BackendPatternConfig(torch.matmul).set_dtype_configs(
            dtype_configs
        )  # noqa: E131
    )
    return binary_op_configs

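# How num_tensor_args_to_observation_type_mapping above is read (an explanatory note):
# for a call like torch.add(x, y) with two quantized Tensor arguments
# (num_tensor_args == 2), the output is observed with its own observer, while for a
# Tensor-scalar call like torch.add(x, 2.0) (num_tensor_args == 1) the output shares
# the input's observer instead of being observed separately.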

def _get_linear_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
    """
    Return all configs related to linear modules and ops.
    """
    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    linear_configs: List[BackendPatternConfig] = []

    # (1) Single linear modules/functions
    # -------------------------------------
    # linear module
    linear_configs.append(
        BackendPatternConfig(torch.nn.Linear)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
        .set_qat_module(nnqat.Linear)
    )
    # linear qat module
    linear_configs.append(
        BackendPatternConfig(nnqat.Linear)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
    )
    # functional linear
    linear_configs.append(
        BackendPatternConfig(torch.nn.functional.linear)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 1, "bias": 2})
    )

    # (2) Linear + relu
    # -------------------
    # 2.1 linear module + relu fusion config
    # linear relu, linear module + relu module
    linear_configs.append(
        BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU))
        .set_dtype_configs(dtype_configs)  # noqa: E131
        .set_fuser_method(_sequential_wrapper2(nni.LinearReLU))
        .set_fused_module(nni.LinearReLU)
    )
    # linear relu, linear module + functional relu
    linear_configs.append(
        BackendPatternConfig((torch.nn.Linear, torch.nn.functional.relu))
        .set_dtype_configs(dtype_configs)  # noqa: E131
        .set_fuser_method(_sequential_wrapper2(nni.LinearReLU))
        .set_fused_module(nni.LinearReLU)
    )

    # 2.2 linear module + relu, fused module configs
    # linear relu, fused module
    linear_configs.append(
        BackendPatternConfig(nni.LinearReLU)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
        .set_qat_module(nniqat.LinearReLU)
    )
    # linear relu, qat fused module
    linear_configs.append(
        BackendPatternConfig(nniqat.LinearReLU)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
    )
    # 2.3 functional linear + relu configs
    # linear relu, functional linear + relu module
    linear_configs.append(
        BackendPatternConfig((F.linear, torch.nn.ReLU))
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )
    # linear relu, functional linear + functional relu
    linear_configs.append(
        BackendPatternConfig((F.linear, F.relu))
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )

    # (3) Linear + batchnorm
    # ------------------------
    # 3.1 linear bn fusion
    linear_configs.append(
        BackendPatternConfig((nn.Linear, nn.BatchNorm1d))
        .set_dtype_configs(dtype_configs)  # noqa: E131
        .set_fuser_method(fuse_linear_bn)
        .set_fused_module(nni.LinearBn1d)
    )

    # 3.2 linear bn fused
    # linear bn, fused module
    linear_configs.append(
        BackendPatternConfig(nni.LinearBn1d)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
        .set_qat_module(nniqat.LinearBn1d)
    )
    # linear bn, qat fused module
    linear_configs.append(
        BackendPatternConfig(nniqat.LinearBn1d)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(dtype_configs)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
    )
    return linear_configs

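# A minimal assembly sketch (illustrative only; "example_backend" and the exact dtype
# choices are assumptions, not part of this module): the config lists built by the
# helpers in this file are typically merged into a single BackendConfig, along the
# lines of what torch.ao.quantization.backend_config.native does:
#
#     from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig
#
#     dtype_configs = [
#         DTypeConfig(
#             input_dtype=torch.quint8,
#             output_dtype=torch.quint8,
#             weight_dtype=torch.qint8,
#             bias_dtype=torch.float,
#         )
#     ]
#     backend_config = (
#         BackendConfig("example_backend")
#         .set_backend_pattern_configs(_get_linear_configs(dtype_configs))
#         .set_backend_pattern_configs(_get_conv_configs(dtype_configs))
#     )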

def _get_conv_configs(dtype_configs):
    """
    Return all configs related to conv modules and ops.
    """
    conv_configs = []
    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    for convs in [_Conv1dMetadata, _Conv2dMetadata, _Conv3dMetadata]:
        # (1) Single conv modules/functions
        # -----------------------------------
        # conv module
        conv_configs.append(
            BackendPatternConfig(convs.root)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
            .set_qat_module(convs.qat)
        )
        # conv qat module
        conv_configs.append(
            BackendPatternConfig(convs.qat)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )
        # functional conv
        conv_configs.append(
            BackendPatternConfig(convs.func)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            ._set_input_type_to_index({"weight": 1, "bias": 2})
        )

        # (2) Conv + relu
        # -----------------
        # 2.1 conv module + relu fusion configs
        # conv relu fusion, conv module + relu module
        conv_configs.append(
            BackendPatternConfig((convs.root, torch.nn.ReLU))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
            .set_fused_module(convs.fused_conv_relu)
        )
        # conv relu fusion, conv module + functional relu
        conv_configs.append(
            BackendPatternConfig((convs.root, F.relu))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
            .set_fused_module(convs.fused_conv_relu)
        )
        # 2.2 conv module + relu fused module configs
        # conv relu, fused module
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_relu)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
            .set_qat_module(convs.relu_qat)
        )
        # conv relu, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.relu_qat)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )
        # 2.3 functional conv + relu configs
        # conv relu, functional conv + relu module
        conv_configs.append(
            BackendPatternConfig((convs.func, torch.nn.ReLU))
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
        )
        # conv relu, functional conv + functional relu
        conv_configs.append(
            BackendPatternConfig((convs.func, F.relu))
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
        )

        # fused conv relu
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_relu)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_qat_module(convs.relu_qat)
        )

        conv_configs.append(
            BackendPatternConfig(convs.relu_qat)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )

        # (3) Conv + batchnorm (+ relu)
        # -------------------------------
        # 3.1 conv bn fusion configs
        # conv + bn fusion
        conv_configs.append(
            BackendPatternConfig((convs.root, convs.bn))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(fuse_conv_bn)
            .set_fused_module(convs.fused_conv_bn)
        )
        # conv + bn + relu module fusion
        conv_configs.append(
            BackendPatternConfig((convs.root, convs.bn, nn.ReLU))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(fuse_conv_bn_relu)
            .set_fused_module(convs.fused_conv_bn_relu)
        )
        # conv + bn + relu functional fusion
        conv_configs.append(
            BackendPatternConfig((convs.root, convs.bn, F.relu))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_root_module(convs.root)
            .set_fuser_method(fuse_conv_bn_relu)
            .set_fused_module(convs.fused_conv_bn_relu)
        )
        # TODO: we can add fusion for torch.relu as well

        # 3.2 conv + bn (+ relu) fused module configs
        # fused conv bn
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_bn)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_qat_module(convs.bn_qat)
        )

        # fused conv bn relu
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_bn_relu)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_qat_module(convs.bn_relu_qat)
        )

        # conv bn, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.bn_qat)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )
        # conv bn relu, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.bn_relu_qat)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )

        # (4) conv transpose and its fusion
        # 4.1 conv transpose config
        conv_configs.append(
            BackendPatternConfig(convs.transpose)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_root_module(convs.transpose)
            .set_reference_quantized_module(convs.transpose_reference)
        )

        # 4.2 conv transpose + bn fusion
        conv_configs.append(
            BackendPatternConfig((convs.transpose, convs.bn))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(fuse_convtranspose_bn)
            .set_root_module(convs.transpose)
            .set_reference_quantized_module(convs.transpose_reference)
        )

        # 4.3 functional conv transpose
        conv_configs.append(
            BackendPatternConfig(convs.func_transpose)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            ._set_input_type_to_index({"weight": 1, "bias": 2})
        )

    return conv_configs

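# How the fusion-related settings above fit together (a hedged sketch of the FX
# quantization flow, not a spec): a matched pattern such as (Conv2d, BatchNorm2d, ReLU)
# is combined by the configured fuser method (here fuse_conv_bn_relu) into the module
# named by set_fused_module; QAT preparation then swaps that fused module for the
# configured qat_module, and conversion lowers the result via the configured
# root_module / reference_quantized_module pair.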

def _get_cat_config(dtype_configs: List[DTypeConfig]) -> BackendPatternConfig:
    return (
        BackendPatternConfig(torch.cat)
        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
        .set_dtype_configs(dtype_configs)
    )


def _get_ln_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
    ln_configs = []
    ln_configs.append(
        BackendPatternConfig(torch.nn.LayerNorm)
        .set_observation_type(
            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        )  # noqa: E131
        .set_dtype_configs(dtype_configs)
    )
    ln_configs.append(
        BackendPatternConfig(torch.nn.functional.layer_norm)
        .set_observation_type(
            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        )  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 2, "bias": 3})
    )
    return ln_configs


def _get_default_op_configs(
    dtype_configs: List[DTypeConfig],
) -> List[BackendPatternConfig]:
    configs = []
    default_ops = [
        torch.nn.ELU,
        torch.nn.LeakyReLU,
        torch.nn.Hardswish,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.Dropout,
        torch.nn.PReLU,
        torch.nn.functional.elu,
        torch.nn.functional.hardswish,
        torch.nn.functional.leaky_relu,
        torch.nn.functional.dropout,
    ]
    for op in default_ops:
        configs.append(
            BackendPatternConfig(op)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(dtype_configs)
        )

    configs.append(
        BackendPatternConfig(torch.nn.functional.group_norm)
        .set_observation_type(
            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        )  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 2, "bias": 3})
    )

    configs.append(
        BackendPatternConfig(torch.nn.functional.instance_norm)
        .set_observation_type(
            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
        )  # noqa: E131
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 3, "bias": 4})
    )
    return configs

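# Note on the _set_input_type_to_index calls in this file (an explanatory comment; the
# argument positions come from the functional signatures): the mapping tells the
# quantization workflow which positional argument of a functional op carries the weight
# and which carries the bias, e.g. F.linear(input, weight, bias) -> {"weight": 1, "bias": 2},
# F.layer_norm(input, normalized_shape, weight, bias) and
# F.group_norm(input, num_groups, weight, bias) -> {"weight": 2, "bias": 3}, and
# F.instance_norm(input, running_mean, running_var, weight, bias) -> {"weight": 3, "bias": 4}.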

def _add_fixed_qparams_to_dtype_configs(
    dtype_configs: List[DTypeConfig],
    constraints: DTypeWithConstraints,
) -> List[DTypeConfig]:
    """
    Return a copy of the list of DTypeConfigs where activations are subject to the specified
    constraints required for fixed qparams ops.

    If the data type doesn't match the one in the constraints, simply leave the corresponding
    DTypeConfig unchanged.

    If `scale_min_lower_bound` or `scale_max_upper_bound` is specified in the activations,
    throw an exception since these settings are incompatible with fixed qparams ops.
    """
    new_dtype_configs = []
    for dtype_config in dtype_configs:
        dc = copy.deepcopy(dtype_config)
        for orig_constraints in [
            dc.input_dtype_with_constraints,
            dc.output_dtype_with_constraints,
        ]:
            if orig_constraints.dtype != constraints.dtype:
                continue
            if orig_constraints.scale_min_lower_bound is not None:
                raise ValueError(
                    f"scale_min_lower_bound is invalid for fixed qparams ops: {dtype_config}"
                )
            if orig_constraints.scale_max_upper_bound is not None:
                raise ValueError(
                    f"scale_max_upper_bound is invalid for fixed qparams ops: {dtype_config}"
                )
            orig_constraints.quant_min_lower_bound = constraints.quant_min_lower_bound
            orig_constraints.quant_max_upper_bound = constraints.quant_max_upper_bound
            orig_constraints.scale_exact_match = constraints.scale_exact_match
            orig_constraints.zero_point_exact_match = constraints.zero_point_exact_match
        new_dtype_configs.append(dc)
    return new_dtype_configs

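# A minimal usage sketch for the helper above (illustrative; the DTypeConfig here is an
# assumed example, not one defined in this module):
#
#     dtype_config = DTypeConfig(input_dtype=torch.quint8, output_dtype=torch.quint8)
#     [constrained] = _add_fixed_qparams_to_dtype_configs(
#         [dtype_config], _FIXED_QPARAM_OP_0TO1_CONSTRAINTS
#     )
#     # constrained.output_dtype_with_constraints.scale_exact_match is now 1.0 / 256.0
#     # and zero_point_exact_match is 0, while the original dtype_config is unchanged.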

def _get_fixed_qparams_op_configs(
    dtype_configs: List[DTypeConfig],
) -> List[BackendPatternConfig]:
    fixed_qparams_op_configs = []
    for fixed_qparam_op, constraints in _FIXED_QPARAMS_OP_TO_CONSTRAINTS.items():
        new_dtype_configs = _add_fixed_qparams_to_dtype_configs(
            dtype_configs, constraints
        )
        fixed_qparams_op_configs.append(
            BackendPatternConfig(fixed_qparam_op)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(new_dtype_configs)
        )
    return fixed_qparams_op_configs


def _get_share_qparams_op_configs(dtype_configs):
    """Get the operator configs for operators that work on both float and quantized input.

    If the input is quantized, the output Tensor shares the same quantization
    parameters as the input.
    Example operators: avgpool2d, reshape, transpose, maxpool2d
    Example observed operator:
    observer_0 - avgpool2d - observer_0 (same observer instance as input)
    """

    def _get_share_qparams_op_backend_config(op):
        return (
            BackendPatternConfig(op)
            .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
            .set_dtype_configs(dtype_configs)
        )

    share_qparams_ops = [
        torch.nn.AdaptiveAvgPool1d,
        torch.nn.AdaptiveAvgPool2d,
        torch.nn.AdaptiveAvgPool3d,
        torch.nn.AvgPool1d,
        torch.nn.AvgPool2d,
        torch.nn.AvgPool3d,
        torch.nn.Hardtanh,
        torch.nn.Identity,
        torch.nn.MaxPool1d,
        torch.nn.MaxPool2d,
        torch.nn.MaxPool3d,
        torch.nn.PixelShuffle,
        torch.nn.PixelUnshuffle,
        torch.nn.ReLU,
        torch.nn.ReLU6,
        torch.adaptive_avg_pool1d,
        torch.nn.functional.adaptive_avg_pool2d,
        torch.nn.functional.adaptive_avg_pool3d,
        torch.nn.functional.hardtanh,
        torch.nn.functional.hardtanh_,
        torch.nn.functional.interpolate,
        torch.nn.functional.max_pool1d,
        torch.nn.functional.max_pool2d,
        torch.nn.functional.max_pool3d,
        torch.nn.functional.pixel_shuffle,
        torch.nn.functional.pixel_unshuffle,
        torch.nn.functional.relu,
        torch.nn.functional.relu6,
        torch.avg_pool1d,
        torch._C._nn.avg_pool2d,
        torch._C._nn.avg_pool3d,
        torch.clamp,
        torch.flatten,
        torch.mean,
        torch.narrow,
        torch.repeat_interleave,
        torch.transpose,
        torch.squeeze,
        torch.stack,
        torch.unsqueeze,
        operator.floordiv,
        "contiguous",
        "clamp",
        "detach",
        "detach_",
        "mean",
        "permute",
        "repeat",
        "repeat_interleave",
        "reshape",
        "resize_",
        "relu",
        "relu_",
        "squeeze",
        "squeeze_",
        "transpose",
        "unsqueeze",
        "unsqueeze_",
        "view",
    ]
    return [_get_share_qparams_op_backend_config(op) for op in share_qparams_ops]


def _get_bn_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
    """Get configs related to batchnorm."""
    bn_configs = []
    bn_to_fused_bn = {
        torch.nn.BatchNorm2d: nni.BNReLU2d,
        torch.nn.BatchNorm3d: nni.BNReLU3d,
    }
    for bn in bn_to_fused_bn.keys():
        fused_bn = bn_to_fused_bn[bn]
        # bn module + relu module fusion config
        bn_configs.append(
            BackendPatternConfig((bn, nn.ReLU))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(_sequential_wrapper2(fused_bn))
            .set_fused_module(fused_bn)
        )
        # bn module + F.relu fusion config
        bn_configs.append(
            BackendPatternConfig((bn, F.relu))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(_sequential_wrapper2(fused_bn))
            .set_fused_module(fused_bn)
        )
        bn_configs.append(
            BackendPatternConfig(bn)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(dtype_configs)
        )

    # fused bn configs
    for fused_bn in bn_to_fused_bn.values():
        bn_configs.append(
            BackendPatternConfig(fused_bn)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(dtype_configs)
        )
    return bn_configs


def _get_rnn_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]:
    rnn_op_configs = []
    for rnn_op, ref_rnn_op in [
        (nn.GRUCell, nnqr.GRUCell),
        (nn.LSTMCell, nnqr.LSTMCell),
        (nn.RNNCell, nnqr.RNNCell),
        (nn.LSTM, nnqr.LSTM),
        (nn.GRU, nnqr.GRU),
    ]:
        rnn_op_configs.append(
            BackendPatternConfig(rnn_op)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(rnn_op)
            .set_reference_quantized_module(ref_rnn_op)
        )
    return rnn_op_configs


def _get_embedding_op_configs(
    dtype_configs: List[DTypeConfig],
) -> List[BackendPatternConfig]:
    embedding_op_configs = []
    for embedding_op, qat_embedding_op, ref_embedding_op in [
        (nn.Embedding, nnqat.Embedding, nnqr.Embedding),
        (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag),
    ]:
        embedding_op_configs.append(
            BackendPatternConfig(embedding_op)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_qat_module(qat_embedding_op)
            .set_root_module(embedding_op)
            .set_reference_quantized_module(ref_embedding_op)
        )

        # config for qat op
        embedding_op_configs.append(
            BackendPatternConfig(qat_embedding_op)
            .set_observation_type(
                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
            )  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(embedding_op)
            .set_reference_quantized_module(ref_embedding_op)
        )
    return embedding_op_configs


def _get_tensor_info_op_configs(dtype_configs):
    """
    These ops work on tensors of different dtypes but return non-tensors
    containing information about the input tensor.
    """

    def _get_config(op):
        return (
            BackendPatternConfig(op)
            .set_observation_type(ObservationType.INPUT_OUTPUT_NOT_OBSERVED)
            .set_dtype_configs(dtype_configs)
        )

    return [_get_config(op) for op in ("shape", "size")]
796