# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import functools
from typing import Any, Callable, Dict, Optional

import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    _convert_scalars_to_attrs,
    OP_TO_ANNOTATOR,
    propagate_annotation,
    QuantizationConfig,
)
from torch.fx import Node


__all__ = [
    "VulkanQuantizer",
    "get_weight_quantization_config",
]


@functools.lru_cache
def get_weight_quantization_config(
    is_per_channel: bool = True,
    weight_qmin: int = -128,
    weight_qmax: int = 127,
) -> QuantizationConfig:
    """Returns a weight-only QuantizationConfig: symmetric int8 weights
    (per-channel along dim 0 by default), with no quantization of
    activations or bias."""
    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
    )
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=weight_qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    quantization_config = QuantizationConfig(
        input_activation=None,
        output_activation=None,
        weight=weight_quantization_spec,
        bias=None,
        is_qat=False,
    )
    return quantization_config


_SUPPORTED_OPS = [
    "linear",
]


class VulkanQuantizer(Quantizer):
    """PT2E quantizer for the Vulkan backend. Annotates the supported op
    patterns (currently only linear) with the global QuantizationConfig
    set via set_global()."""

    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None

    def set_global(self, quantization_config: QuantizationConfig) -> VulkanQuantizer:
        self.global_config = quantization_config
        return self

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Transforms scalar values to tensor attributes"""
        return _convert_scalars_to_attrs(model)

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Currently only static quantization is supported on Vulkan.
        model = self._annotate_for_static_quantization_config(model)
        propagate_annotation(model)
        return model

    def _annotate_all_static_patterns(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        if quantization_config is None:
            return model

        for op in _SUPPORTED_OPS:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        self._annotate_all_static_patterns(
            model,
            self.global_config,
        )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
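

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module's API): shows how this
# quantizer is expected to plug into the standard PT2E quantization flow.
# `prepare_pt2e`/`convert_pt2e` are the stock torch.ao.quantization entry
# points; the toy model and example inputs are hypothetical stand-ins for a
# real workload, and `torch.export.export_for_training` assumes PyTorch >= 2.5
# (older releases captured the graph differently).
if __name__ == "__main__":
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    class _ToyModel(torch.nn.Module):  # hypothetical example model
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(16, 8)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)

    example_inputs = (torch.randn(1, 16),)
    exported = torch.export.export_for_training(_ToyModel(), example_inputs).module()

    quantizer = VulkanQuantizer().set_global(get_weight_quantization_config())
    prepared = prepare_pt2e(exported, quantizer)  # insert observers
    prepared(*example_inputs)  # calibrate with representative data
    quantized = convert_pt2e(prepared)  # fold observers into quant/dequant ops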