# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from torch.ao.quantization.quantizer import Quantizer
from torch.fx import GraphModule

from .._passes.decompose_scaled_dot_product_attention import (
    DecomposeScaledDotProductAttention,
)
from .annotator import annotate
from .qconfig import get_quant_config, Precision


class NeuropilotQuantizer(Quantizer):
    """Quantizer that annotates an FX graph using a `get_quant_config` result.

    The configuration is derived from three private attributes set in
    `__init__`: the precision (defaults to `Precision.A8W8`), a per-channel
    flag (defaults to True) and a QAT flag (defaults to False). Only the
    precision can currently be changed, through `setup_precision`.
    """

    def __init__(self):
        """Initialize the base `Quantizer` and the default quantization settings."""
        super().__init__()

        # TODO: Provide setter functions for those attributes
        self._precision = Precision.A8W8
        self._is_per_channel = True
        self._is_qat = False

    def setup_precision(self, precision: Precision) -> None:
        """Replace the precision used when the quantization config is built.

        Args:
            precision: The `Precision` value stored for later `annotate` calls.
        """
        self._precision = precision

    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        """Run the `DecomposeScaledDotProductAttention` pass on `model`.

        Args:
            model: The graph module to transform before annotation.

        Returns:
            The `graph_module` attribute of the pass result.
        """
        model = DecomposeScaledDotProductAttention()(model).graph_module
        return model

    def annotate(self, model: GraphModule) -> GraphModule:
        """Annotate `model` via `_annotate` and return the same object.

        Args:
            model: The graph module whose graph is annotated.

        Returns:
            The `model` object that was passed in.
        """
        self._annotate(model)
        return model

    def validate(self, model: GraphModule) -> None:
        """Do nothing; no validation is implemented for this quantizer."""
        pass

    def _annotate(self, gm: GraphModule) -> None:
        """Build the quantization config and apply it to `gm.graph`.

        Args:
            gm: The graph module whose `graph` is handed to the annotator.
        """
        quant_config = get_quant_config(
            self._precision, self._is_per_channel, self._is_qat
        )
        # `annotate` here is the module-level function imported from
        # `.annotator`, not the `annotate` method of this class: class-body
        # names are not visible from inside a method without `self.`.
        annotate(gm.graph, quant_config)