# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule

from .._passes.decompose_scaled_dot_product_attention import (
    DecomposeScaledDotProductAttention,
)
from .annotator import annotate
from .qconfig import get_quant_config, Precision


class NeuropilotQuantizer(Quantizer):
    """Quantizer for MediaTek NeuroPilot backends.

    Decomposes scaled-dot-product-attention nodes before annotation, then
    annotates the FX graph with quantization configs derived from the
    configured precision / per-channel / QAT settings.
    """

    def __init__(self):
        super().__init__()

        # Quantization settings; adjust via the setup_* methods below.
        self._precision = Precision.A8W8
        self._is_per_channel = True
        self._is_qat = False

    def setup_precision(self, precision: Precision) -> None:
        """Set the target quantization precision (default: A8W8)."""
        self._precision = precision

    def setup_per_channel(self, is_per_channel: bool) -> None:
        """Enable or disable per-channel quantization (default: True)."""
        self._is_per_channel = is_per_channel

    def setup_qat(self, is_qat: bool) -> None:
        """Enable or disable quantization-aware-training configs (default: False)."""
        self._is_qat = is_qat

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        """Decompose SDPA nodes so their internal ops can be annotated."""
        model = DecomposeScaledDotProductAttention()(model).graph_module
        return model

    def annotate(self, model: GraphModule) -> GraphModule:
        """Attach quantization annotations to the graph and return it."""
        self._annotate(model)
        return model

    def validate(self, model: GraphModule) -> None:
        # No backend-specific validation is performed.
        pass

    def _annotate(self, gm: GraphModule) -> None:
        # Resolve the quant config from the current settings, then
        # annotate the graph in place.
        quant_config = get_quant_config(
            self._precision, self._is_per_channel, self._is_qat
        )
        annotate(gm.graph, quant_config)