xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/backend_config/executorch.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# TODO: rename executorch to qnnpack_executorch since executorch is a general runtime
2# not a specific backend
3
4import operator
5from typing import List
6
7import torch
8import torch.ao.nn.qat as nnqat
9import torch.ao.nn.quantized.reference as nnqr
10import torch.nn as nn
11import torch.nn.functional as F
12from torch.ao.quantization.fuser_method_mappings import (
13    _sequential_wrapper2,
14    fuse_conv_bn,
15    fuse_conv_bn_relu,
16)
17
18from ._common_operator_config_utils import _Conv2dMetadata
19from .backend_config import (
20    BackendConfig,
21    BackendPatternConfig,
22    DTypeConfig,
23    DTypeWithConstraints,
24    ObservationType,
25)
26from .qnnpack import (
27    qnnpack_default_op_qint8_symmetric_dtype_config,
28    qnnpack_weighted_op_qint8_symmetric_dtype_config,
29)
30
31
# Public API of this module: only the top-level backend config getter is
# exported; the per-op `_get_*_configs` helpers are private.
__all__ = [
    "get_executorch_backend_config",
]
35
36
# ===================
# |  DTYPE CONFIGS  |
# ===================

# Static quantization for weighted ops: quint8 activations, qint8 weights,
# float bias.
executorch_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# Static quantization for ops without weights: quint8 in and out only.
executorch_default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# Dynamic quantization: quint8 activations, qint8 weights, float output.
executorch_default_dynamic_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# qint8 activation constraint: quantization scale must be >= 2**-12.
executorch_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints(
    dtype=torch.qint8,
    scale_min_lower_bound=2**-12,
)

# qint8 weight constraint: symmetric quant range [-127, 127] and
# scale >= 2**-12.
executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints(
    dtype=torch.qint8,
    quant_min_lower_bound=-127,
    quant_max_upper_bound=127,
    scale_min_lower_bound=2**-12,
)

# Dynamic quantization using the constrained qint8 activation/weight
# dtypes defined above.
executorch_default_dynamic_qint8_dtype_config = DTypeConfig(
    input_dtype=executorch_act_qint8_scale_min_2_neg_12,
    output_dtype=torch.float,
    weight_dtype=executorch_weight_qint8_neg_127_to_127_scale_min_2_neg_12,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Dynamic quantization with float16 activations and weights.
executorch_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Weight-only quantization: compute stays in float, only the weight is
# stored as quint8.
executorch_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)
94
95
96# =============================
97# |  BACKEND PATTERN CONFIGS  |
98# =============================
99
100
101def _get_linear_configs() -> List[BackendPatternConfig]:
102    """
103    Return all configs related to linear modules and ops.
104    """
105    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
106    dtype_configs = [
107        qnnpack_weighted_op_qint8_symmetric_dtype_config,
108        executorch_weighted_op_int8_dtype_config,
109        executorch_default_dynamic_quint8_dtype_config,
110        executorch_default_dynamic_qint8_dtype_config,
111        executorch_default_dynamic_float16_dtype_config,
112    ]
113    linear_configs: List[BackendPatternConfig] = []
114    # linear module
115    linear_configs.append(
116        BackendPatternConfig(torch.nn.Linear)
117        .set_observation_type(observation_type)  # noqa: E131
118        .set_dtype_configs(dtype_configs)
119        .set_root_module(torch.nn.Linear)
120        .set_reference_quantized_module(nnqr.Linear)
121        .set_qat_module(nnqat.Linear)
122    )
123    # linear qat module
124    linear_configs.append(
125        BackendPatternConfig(nnqat.Linear)
126        .set_observation_type(observation_type)  # noqa: E131
127        .set_dtype_configs(dtype_configs)
128        .set_root_module(torch.nn.Linear)
129        .set_reference_quantized_module(nnqr.Linear)
130    )
131    # functional linear
132    linear_configs.append(
133        BackendPatternConfig(torch.nn.functional.linear)
134        .set_observation_type(observation_type)  # noqa: E131
135        .set_dtype_configs(dtype_configs)
136        ._set_input_type_to_index({"weight": 1, "bias": 2})
137    )
138    return linear_configs
139
140
def _get_conv_configs() -> List[BackendPatternConfig]:
    """
    Return all configs related to conv modules and ops.

    Covers (1) single conv (module, QAT module, functional), (2) conv + relu
    combinations, and (3) conv + batchnorm (+ relu) fusion patterns, all
    using the same weighted-op dtype configs. Ordering matters: some
    patterns (e.g. the fused conv-relu module) are registered more than
    once, and a later config for the same pattern supersedes the earlier one.
    """
    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    dtype_configs = [
        qnnpack_weighted_op_qint8_symmetric_dtype_config,
        executorch_weighted_op_int8_dtype_config,
    ]
    # Holds the BackendPatternConfig for every conv-related pattern below.
    conv_configs: List[BackendPatternConfig] = []
    # Only 2d convs are configured here (metadata bundles the module,
    # functional, QAT, reference, and fused variants for Conv2d).
    for convs in [_Conv2dMetadata]:
        # (1) Single conv modules/functions
        # -----------------------------------
        # conv module
        conv_configs.append(
            BackendPatternConfig(convs.root)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
            .set_qat_module(convs.qat)
        )
        # conv qat module
        conv_configs.append(
            BackendPatternConfig(convs.qat)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )
        # functional conv: weight is arg 1, bias is arg 2
        conv_configs.append(
            BackendPatternConfig(convs.func)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            ._set_input_type_to_index({"weight": 1, "bias": 2})
        )

        # (2) Conv + relu
        # -----------------------------------
        # conv module + relu module
        conv_configs.append(
            BackendPatternConfig((convs.root, nn.ReLU))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
            .set_fused_module(convs.fused_conv_relu)
        )
        # conv module + functional relu
        conv_configs.append(
            BackendPatternConfig((convs.root, F.relu))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
            .set_fused_module(convs.fused_conv_relu)
        )
        # fused conv relu module
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_relu)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
            .set_qat_module(convs.relu_qat)
        )
        # conv relu, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.relu_qat)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )
        # functional conv + relu module
        conv_configs.append(
            BackendPatternConfig((convs.func, nn.ReLU))
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
        )
        # functional conv + functional relu
        conv_configs.append(
            BackendPatternConfig((convs.func, F.relu))
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
        )
        # fused conv relu
        # NOTE(review): convs.fused_conv_relu was already configured above
        # (with an observation type); this later entry for the same pattern
        # carries only dtype + qat settings — confirm which registration is
        # intended to win before reordering.
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_relu)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_qat_module(convs.relu_qat)
        )

        # NOTE(review): second registration for convs.relu_qat as well; see
        # the note above.
        conv_configs.append(
            BackendPatternConfig(convs.relu_qat)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )

        # (3) Conv + batchnorm (+ relu)
        # -------------------------------
        # conv + batchnorm (+ relu)
        conv_configs.append(
            BackendPatternConfig((convs.root, convs.bn))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(fuse_conv_bn)
            .set_fused_module(convs.fused_conv_bn)
        )
        # conv + bn + relu module fusion
        conv_configs.append(
            BackendPatternConfig((convs.root, convs.bn, nn.ReLU))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(fuse_conv_bn_relu)
            .set_fused_module(convs.fused_conv_bn_relu)
        )
        # conv + bn + relu functional fusion
        conv_configs.append(
            BackendPatternConfig((convs.root, convs.bn, F.relu))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_root_module(convs.root)
            .set_fuser_method(fuse_conv_bn_relu)
            .set_fused_module(convs.fused_conv_bn_relu)
        )
        # TODO: we can add fusion for torch.relu as well
        # 3.2 conv + bn (+ relu) fused module configs
        # fused conv bn
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_bn)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_qat_module(convs.bn_qat)
        )

        # fused conv bn relu
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_bn_relu)
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_qat_module(convs.bn_relu_qat)
        )

        # conv bn, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.bn_qat)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )
        # conv bn relu, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.bn_relu_qat)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(convs.root)
            .set_reference_quantized_module(convs.reference)
        )
    return conv_configs
295
296
297def _get_binary_ops_configs() -> List[BackendPatternConfig]:
298    """
299    Return all configs related to binary ops.
300    """
301    dtype_configs = [
302        qnnpack_default_op_qint8_symmetric_dtype_config,
303        executorch_weighted_op_int8_dtype_config,
304    ]
305    num_tensor_args_to_observation_type_mapping = {
306        # TODO: this is not used right now since we have extra check in prepare
307        # will need to change this to NO_OBSERVER later after we implemented
308        # Tensor dtype inference properly
309        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
310        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
311        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
312    }
313    binary_op_configs: List[BackendPatternConfig] = []
314    for op in [
315        operator.add,
316        torch.add,
317        operator.sub,
318        torch.sub,
319        operator.mul,
320        torch.mul,
321    ]:
322        bop_patterns = [
323            (op, torch.nn.ReLU),
324            (op, torch.nn.functional.relu),
325            (op, torch.relu),
326            op,
327        ]
328        for bop_pattern in bop_patterns:
329            binary_op_configs.append(
330                BackendPatternConfig(bop_pattern)
331                .set_dtype_configs(dtype_configs)  # noqa: E131
332                ._set_num_tensor_args_to_observation_type(
333                    num_tensor_args_to_observation_type_mapping
334                )
335            )
336    return binary_op_configs
337
338
339def _get_share_qparams_ops_configs() -> List[BackendPatternConfig]:
340    """
341    Return the operator configs for the operators that works for both float and quantized
342    input if input is quantized, the output Tensor shares the same quantization parameter
343    with input.
344
345    Example operator: avgpool2d, reshape, transpose, maxpool2d
346    Example observed operator:
347    observer_0 - avgpool2d - observer_0 (same observer instance as input)
348    """
349    observation_type = ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT
350    dtype_configs = [
351        qnnpack_default_op_qint8_symmetric_dtype_config,
352        executorch_default_op_quint8_dtype_config,
353    ]
354    share_qparams_ops = [
355        torch.nn.Flatten,
356        F.adaptive_avg_pool2d,
357        F.elu,
358        F.hardtanh,
359        F.max_pool2d,
360        F.pad,
361        F.relu,
362        F.relu6,
363        F.leaky_relu,
364        F.leaky_relu_,
365        torch.nn.AdaptiveAvgPool2d,
366        torch.nn.ConstantPad2d,
367        torch.nn.ELU,
368        torch.nn.MaxPool2d,
369        torch.nn.ReLU6,
370        torch.nn.Hardtanh,
371        torch.nn.LeakyReLU,
372        torch.clamp,
373        torch.flatten,
374        torch.mean,
375        torch.permute,
376        torch.permute_copy,
377        torch.squeeze,
378        "clamp",
379        "mean",
380        "permute",
381        "reshape",
382        "relu",
383        "relu_",
384        "squeeze",
385        "squeeze_",
386        "leaky_relu",
387    ]
388    share_qparams_op_configs: List[BackendPatternConfig] = []
389    for op in share_qparams_ops:
390        share_qparams_op_configs.append(
391            BackendPatternConfig(op)
392            .set_observation_type(observation_type)  # noqa: E131
393            .set_dtype_configs(dtype_configs)
394        )
395    return share_qparams_op_configs
396
397
398def _get_bn_configs() -> List[BackendPatternConfig]:
399    """
400    Return all configs related to batchnorm.
401    """
402    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
403    dtype_configs = [
404        qnnpack_default_op_qint8_symmetric_dtype_config,
405        executorch_default_op_quint8_dtype_config,
406    ]
407    bn_configs = []
408    bn_configs.append(
409        BackendPatternConfig(nn.BatchNorm2d)
410        .set_observation_type(observation_type)  # noqa: E131
411        .set_dtype_configs(dtype_configs)
412    )
413    return bn_configs
414
415
416def _get_cat_configs() -> List[BackendPatternConfig]:
417    dtype_configs = [
418        qnnpack_default_op_qint8_symmetric_dtype_config,
419        executorch_default_op_quint8_dtype_config,
420    ]
421    cat_configs = []
422    cat_configs.append(
423        BackendPatternConfig(torch.cat)
424        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
425        .set_dtype_configs(dtype_configs)
426    )
427    cat_configs.append(
428        BackendPatternConfig(torch.concat)
429        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
430        .set_dtype_configs(dtype_configs)
431    )
432    cat_configs.append(
433        BackendPatternConfig(torch.concatenate)
434        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
435        .set_dtype_configs(dtype_configs)
436    )
437    return cat_configs
438
439
440def _get_embedding_op_configs() -> List[BackendPatternConfig]:
441    dtype_configs = [
442        executorch_weight_only_quint8_dtype_config,
443    ]
444    embedding_op_configs = []
445    for embedding_op, qat_embedding_op, ref_embedding_op in [
446        (nn.Embedding, nnqat.Embedding, nnqr.Embedding),
447        (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag),
448    ]:
449        embedding_op_configs.append(
450            BackendPatternConfig(embedding_op)
451            .set_observation_type(
452                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
453            )  # noqa: E131
454            .set_dtype_configs(dtype_configs)
455            .set_qat_module(qat_embedding_op)
456            .set_root_module(embedding_op)
457            .set_reference_quantized_module(ref_embedding_op)
458        )
459        # config for qat op
460        embedding_op_configs.append(
461            BackendPatternConfig(qat_embedding_op)
462            .set_observation_type(
463                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
464            )  # noqa: E131
465            .set_dtype_configs(dtype_configs)
466            .set_root_module(embedding_op)
467            .set_reference_quantized_module(ref_embedding_op)
468        )
469
470        # config for functional embedding
471        embedding_op_configs.append(
472            BackendPatternConfig(torch.nn.functional.embedding)
473            .set_observation_type(
474                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
475            )  # noqa: E131
476            .set_dtype_configs(dtype_configs)
477            ._set_input_type_to_index({"weight": 1})
478        )
479    return embedding_op_configs
480
481
482# =====================
483# |  BACKEND CONFIGS  |
484# =====================
485
486
487def get_executorch_backend_config() -> BackendConfig:
488    """
489    Return the `BackendConfig` for backends PyTorch lowers to through the Executorch stack.
490    """
491    return (
492        BackendConfig("executorch")
493        .set_backend_pattern_configs(_get_linear_configs())
494        .set_backend_pattern_configs(_get_conv_configs())
495        .set_backend_pattern_configs(_get_binary_ops_configs())
496        .set_backend_pattern_configs(_get_share_qparams_ops_configs())
497        .set_backend_pattern_configs(_get_bn_configs())
498        .set_backend_pattern_configs(_get_cat_configs())
499        .set_backend_pattern_configs(_get_embedding_op_configs())
500    )
501