# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List

import torch
from executorch.backends.example.example_operators.ops import module_to_annotator
from torch import fx
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OperatorConfig


def get_uint8_tensor_spec(observer_or_fake_quant_ctr):
    """Build a static, per-tensor-affine uint8 QuantizationSpec around the given observer."""
    return QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
    )

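
# A hedged sketch, not part of the original example: the same constructor
# extends to a symmetric per-channel int8 weight spec. The dtype, quant range,
# and ch_axis=0 (one scale per output channel) are illustrative assumptions;
# pass an observer such as PerChannelMinMaxObserver as the ctr argument.
def get_int8_per_channel_weight_spec(observer_or_fake_quant_ctr):
    return QuantizationSpec(
        dtype=torch.int8,
        quant_min=-127,
        quant_max=127,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,  # assumption: weights laid out [out_channels, ...]
        is_dynamic=False,
        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
    )
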

@dataclass
class ExampleQuantConfig:
    """Quantization specs for each tensor kind consumed by the pattern annotators."""

    input_quant_spec: QuantizationSpec
    output_quant_spec: QuantizationSpec
    weight_quant_spec: QuantizationSpec
    bias_quant_spec: QuantizationSpec


default_static_config = ExampleQuantConfig(
    get_uint8_tensor_spec(HistogramObserver),
    get_uint8_tensor_spec(HistogramObserver),
    get_uint8_tensor_spec(MinMaxObserver),
    # pyre-fixme[6]: Incompatible parameter type [6]: In call `ExampleQuantConfig.__init__`, for 4th positional argument, expected `QuantizationSpec` but got `None`.
    None,  # bias quantization can be configured here or done in a pass later on.
)

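
# A hedged sketch, not part of the original example: a caller can assemble a
# custom config to swap observers per tensor kind (here the cheaper
# MinMaxObserver everywhere) and pass it in via ExampleQuantizer(quant_config=...).
# The name example_minmax_config is hypothetical.
example_minmax_config = ExampleQuantConfig(
    get_uint8_tensor_spec(MinMaxObserver),
    get_uint8_tensor_spec(MinMaxObserver),
    get_uint8_tensor_spec(MinMaxObserver),
    None,  # bias left unquantized here, as in default_static_config
)
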

def check_for_outside_users(partitions) -> bool:
    """
    Make sure that, for every partition except the last, all users stay within
    the delegatable subgraph. If we quantize a partition whose outputs have
    users outside this subgraph, delegating the partition to the backend will
    not be possible.
    """
    for source_partition in partitions[:-1]:
        if len(source_partition.output_nodes) != 1:
            return True
        if len(source_partition.output_nodes[0].users) != 1:
            return True
    return False


class ExampleQuantizer(Quantizer):
    def __init__(self, quantizer_supported_modules=None, quant_config=None):
        super().__init__()
        if quantizer_supported_modules is not None:
            self.quantizer_supported_modules = quantizer_supported_modules
            for module in self.quantizer_supported_modules:
                if module not in module_to_annotator:
                    raise ValueError(f"{module} is not supported by this quantizer")
        else:
            self.quantizer_supported_modules = module_to_annotator.keys()
        if quant_config is not None:
            self.quant_config = quant_config
        else:
            self.quant_config = default_static_config

    def annotate(self, model):
        for supported_modules in self.quantizer_supported_modules:
            fused_partitions = find_sequential_partitions(
                model,
                list(supported_modules),
            )

            for partitions in fused_partitions:
                # Skip patterns whose intermediate outputs have users outside the
                # subgraph; such patterns cannot be delegated as a single unit.
                if check_for_outside_users(partitions):
                    continue

                annotator = module_to_annotator[supported_modules].annotate_handle
                annotator(partitions, self.quant_config)

        return model

    def validate(self, model: fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return []
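

# A hedged end-to-end sketch, not part of the original file: how this quantizer
# is typically driven through the PT2E flow. The toy _DemoModule, the input
# shape, and the assumption that a lone Conv2d matches one of the patterns in
# module_to_annotator are all illustrative; API spellings such as
# torch.export.export and prepare_pt2e/convert_pt2e vary across torch releases.
if __name__ == "__main__":
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    class _DemoModule(torch.nn.Module):  # hypothetical model for illustration
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

        def forward(self, x):
            return self.conv(x)

    example_inputs = (torch.randn(1, 3, 16, 16),)
    exported = torch.export.export(_DemoModule(), example_inputs).module()

    quantizer = ExampleQuantizer()
    prepared = prepare_pt2e(exported, quantizer)  # insert observers per the annotations
    prepared(*example_inputs)  # calibrate the observers on sample data
    quantized = convert_pt2e(prepared)  # replace observers with quantize/dequantize ops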
103