xref: /aosp_15_r20/external/executorch/backends/mediatek/quantizer/quantizer.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerfrom torch.ao.quantization.quantizer import Quantizer
8*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx import GraphModule
9*523fa7a6SAndroid Build Coastguard Worker
10*523fa7a6SAndroid Build Coastguard Workerfrom .._passes.decompose_scaled_dot_product_attention import (
11*523fa7a6SAndroid Build Coastguard Worker    DecomposeScaledDotProductAttention,
12*523fa7a6SAndroid Build Coastguard Worker)
13*523fa7a6SAndroid Build Coastguard Workerfrom .annotator import annotate
14*523fa7a6SAndroid Build Coastguard Workerfrom .qconfig import get_quant_config, Precision
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Worker
class NeuropilotQuantizer(Quantizer):
    """Quantizer that annotates an FX graph using a config from ``get_quant_config``.

    The config is derived from three private settings kept on the instance:
    ``_precision``, ``_is_per_channel`` and ``_is_qat``. Only the precision
    can currently be changed through a public method (``setup_precision``).
    """

    def __init__(self):
        """Initialise the base ``Quantizer`` and set the default settings."""
        super().__init__()

        # TODO: Provide setter functions for those attributes
        # (only ``_precision`` has one so far, see ``setup_precision``).
        self._precision = Precision.A8W8
        self._is_per_channel = True
        # NOTE(review): "qat" presumably stands for quantization-aware
        # training -- confirm against ``get_quant_config``.
        self._is_qat = False

    def setup_precision(self, precision: Precision) -> None:
        """Replace the precision used when building the quantization config.

        Args:
            precision: The ``Precision`` member to store on this quantizer.
        """
        self._precision = precision

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        """Run ``DecomposeScaledDotProductAttention`` on *model* before annotation.

        Args:
            model: The graph module about to be annotated.

        Returns:
            The ``graph_module`` attribute of the pass result.
        """
        decompose_pass = DecomposeScaledDotProductAttention()
        pass_result = decompose_pass(model)
        return pass_result.graph_module

    def annotate(self, model: GraphModule) -> GraphModule:
        """Annotate *model* via ``_annotate`` and hand the same object back.

        Args:
            model: The graph module to annotate.

        Returns:
            The *model* argument itself (not a copy).
        """
        self._annotate(model)
        return model

    def validate(self, model: GraphModule) -> None:
        """Perform no validation; this override is intentionally a no-op."""
        return None

    def _annotate(self, gm: GraphModule) -> None:
        """Build the quantization config and apply it to the graph of *gm*.

        Args:
            gm: The graph module whose ``graph`` is passed to the annotator.
        """
        config = get_quant_config(
            self._precision,
            self._is_per_channel,
            self._is_qat,
        )
        # Inside a method body the bare name ``annotate`` resolves to the
        # module-level function imported from ``.annotator``, not to the
        # ``annotate`` method defined on this class.
        annotate(gm.graph, config)
46