# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule

from .._passes.decompose_scaled_dot_product_attention import (
    DecomposeScaledDotProductAttention,
)
from .annotator import annotate
from .qconfig import get_quant_config, Precision


class NeuropilotQuantizer(Quantizer):
    """Quantizer that annotates a graph using a precision-driven quant config.

    The quantization config is built by ``get_quant_config`` from three
    instance settings: the precision, whether quantization is per-channel,
    and whether quantization-aware training (QAT) is enabled.
    """

    def __init__(self):
        """Initialize with A8W8 precision, per-channel on, and QAT off."""
        super().__init__()

        # TODO: Provide setter functions for those attributes
        self._precision = Precision.A8W8
        self._is_per_channel = True
        self._is_qat = False

    def setup_precision(self, precision: Precision) -> None:
        """Replace the precision used when building the quant config.

        Args:
            precision: The ``Precision`` value to use for later annotation.
        """
        self._precision = precision

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        """Run the scaled-dot-product-attention decomposition pass.

        Args:
            model: The graph module to transform before annotation.

        Returns:
            The ``graph_module`` attribute of the pass result.
        """
        decompose_pass = DecomposeScaledDotProductAttention()
        return decompose_pass(model).graph_module

    def annotate(self, model: GraphModule) -> GraphModule:
        """Annotate ``model`` and hand the same object back.

        Args:
            model: The graph module whose graph is annotated.

        Returns:
            The ``model`` argument that was passed in.
        """
        self._annotate(model)
        return model

    def validate(self, model: GraphModule) -> None:
        """Perform no validation; this hook is intentionally a no-op."""
        pass

    def _annotate(self, gm: GraphModule) -> None:
        """Build the quant config from current settings and annotate ``gm.graph``.

        Args:
            gm: The graph module whose ``graph`` is passed to the annotator.
        """
        config = get_quant_config(
            self._precision,
            self._is_per_channel,
            self._is_qat,
        )
        annotate(gm.graph, config)