# xref: /aosp_15_r20/external/executorch/backends/mediatek/quantizer/quantizer.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

7from torch.ao.quantization.quantizer import Quantizer
8from torch.fx import GraphModule
9
10from .._passes.decompose_scaled_dot_product_attention import (
11    DecomposeScaledDotProductAttention,
12)
13from .annotator import annotate
14from .qconfig import get_quant_config, Precision
15
16
class NeuropilotQuantizer(Quantizer):
    """PT2E quantizer for MediaTek NeuroPilot backends.

    Configures how a captured ``GraphModule`` is annotated for quantization:
    the numeric precision (e.g. ``Precision.A8W8``), per-channel vs per-tensor
    weight quantization, and post-training vs quantization-aware training.
    Use the ``setup_*`` methods to change settings before calling ``annotate``.
    """

    def __init__(self):
        super().__init__()

        # Defaults; override via the setup_* methods below.
        self._precision = Precision.A8W8
        self._is_per_channel = True
        self._is_qat = False

    def setup_precision(self, precision: Precision) -> None:
        """Set the target quantization precision (e.g. ``Precision.A8W8``)."""
        self._precision = precision

    def setup_per_channel(self, is_per_channel: bool) -> None:
        """Enable (True) or disable (False) per-channel weight quantization."""
        self._is_per_channel = is_per_channel

    def setup_qat(self, is_qat: bool) -> None:
        """Select quantization-aware training (True) or post-training (False)."""
        self._is_qat = is_qat

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        """Pre-annotation transform required by the ``Quantizer`` interface.

        Decomposes scaled-dot-product-attention into primitive ops so each
        constituent op can be annotated individually.
        """
        model = DecomposeScaledDotProductAttention()(model).graph_module
        return model

    def annotate(self, model: GraphModule) -> GraphModule:
        """Annotate ``model``'s graph in place and return the same module."""
        self._annotate(model)
        return model

    def validate(self, model: GraphModule) -> None:
        """No-op validation hook required by the ``Quantizer`` interface."""
        pass

    def _annotate(self, gm: GraphModule) -> None:
        # Resolve the quant config from the current settings, then walk the
        # graph and attach annotations node by node.
        quant_config = get_quant_config(
            self._precision, self._is_per_channel, self._is_qat
        )
        annotate(gm.graph, quant_config)
46