# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import functools
from typing import Any, Callable, Dict, Optional

import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    _convert_scalars_to_attrs,
    OP_TO_ANNOTATOR,
    propagate_annotation,
    QuantizationConfig,
)
from torch.fx import Node


__all__ = [
    "VulkanQuantizer",
    "get_weight_quantization_config",
]


@functools.lru_cache
def get_weight_quantization_config(
    is_per_channel: bool = True,
    weight_qmin: int = -128,
    weight_qmax: int = 127,
) -> QuantizationConfig:
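    """Build a weight-only QuantizationConfig: symmetric int8 weights,
    per-channel (along axis 0) by default or per-tensor otherwise, with
    activations and bias left unquantized.
    """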

    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
    )
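    # eps lower-bounds the smallest scale the observers will produce, so
    # near-constant weight channels cannot yield a degenerate scale; 2**-12
    # matches the default used by the XNNPACK quantizer config.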
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=weight_qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    quantization_config = QuantizationConfig(
        input_activation=None,
        output_activation=None,
        weight=weight_quantization_spec,
        bias=None,
        is_qat=False,
    )
    return quantization_config


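# Patterns this quantizer annotates; each name keys into the annotator table
# shared with the XNNPACK quantizer (OP_TO_ANNOTATOR).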
_SUPPORTED_OPS = [
    "linear",
]


class VulkanQuantizer(Quantizer):
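    """PT2E quantizer for the Vulkan backend. Applies a single global
    QuantizationConfig (set via ``set_global``) to all supported op patterns;
    only static quantization is currently supported.
    """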

    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None

    def set_global(self, quantization_config: QuantizationConfig) -> VulkanQuantizer:
        self.global_config = quantization_config
        return self

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
88        """Transforms scalar values to tensor attributes"""
        return _convert_scalars_to_attrs(model)

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Currently only static quantization is supported on Vulkan.
        model = self._annotate_for_static_quantization_config(model)
        propagate_annotation(model)
        return model

    def _annotate_all_static_patterns(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        if quantization_config is None:
            return model

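        # Dispatch to the shared pattern annotators; filter_fn, when given,
        # limits annotation to matching nodes.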
        for op in _SUPPORTED_OPS:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        self._annotate_all_static_patterns(
            model,
            self.global_config,
        )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
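
# Example usage (an illustrative sketch, not part of this module; `model` and
# `example_inputs` are placeholders, and the export entry point varies by
# PyTorch version):
#
#     from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#     quantizer = VulkanQuantizer().set_global(get_weight_quantization_config())
#     exported = torch.export.export_for_training(model, example_inputs).module()
#     prepared = prepare_pt2e(exported, quantizer)  # inserts weight observers
#     prepared(*example_inputs)  # calibrate
#     quantized = convert_pt2e(prepared)  # lowers observers to q/dq ops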