# mypy: allow-untyped-defs
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node


__all__ = [
    "Quantizer",
    "QuantizationSpecBase",
    "QuantizationSpec",
    "FixedQParamsQuantizationSpec",
    "EdgeOrNode",
    "SharedQuantizationSpec",
    "DerivedQuantizationSpec",
    "QuantizationAnnotation",
]


class QuantizationSpecBase(ABC):  # noqa: B024
    """Base class for the different types of quantization specs that allow users to
    specify how to quantize a Tensor (an input/output of a Node) in the model.
    """


@dataclass(eq=True, frozen=True)
class QuantizationSpec(QuantizationSpecBase):
    """Quantization spec for common operators that allows users to specify how to
    quantize a Tensor; this includes dtype, quant_min, quant_max, etc.
    """

    dtype: torch.dtype
    # observer or fake_quantize constructor such as
    # MinMaxObserver, PerChannelMinMaxObserver, HistogramObserver, etc.,
    # optionally with custom args attached,
    # e.g. MinMaxObserver.with_args(eps=eps)
    observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None
    is_dynamic: bool = False

    def __post_init__(self):
        # TODO: add init for quant_min/quant_max
        # quant_min must be <= quant_max
        if (
            self.quant_min is not None
            and self.quant_max is not None
            and self.quant_min > self.quant_max
        ):
            raise ValueError(
                f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
            )

        # ch_axis must be less than the number of channels,
        # but there is no way to check that here, so just check that it is non-negative
        if self.ch_axis is not None and self.ch_axis < 0:
            raise ValueError(f"ch_axis {self.ch_axis} must be non-negative.")

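# Example (a minimal sketch, not part of this module's API): an int8, per-tensor
# affine activation spec backed by a HistogramObserver, similar to what backend
# quantizers construct:
#
#   from torch.ao.quantization.observer import HistogramObserver
#
#   act_qspec = QuantizationSpec(
#       dtype=torch.int8,
#       quant_min=-128,
#       quant_max=127,
#       qscheme=torch.per_tensor_affine,
#       is_dynamic=False,
#       observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
#   )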

@dataclass(eq=True, frozen=True)
class FixedQParamsQuantizationSpec(QuantizationSpecBase):
    dtype: torch.dtype
    scale: float
    zero_point: int
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    is_dynamic: bool = False

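# Example (a minimal sketch): ops like sigmoid have a known output range, so their
# output can be annotated with fixed quantization parameters instead of observed ones:
#
#   sigmoid_qspec = FixedQParamsQuantizationSpec(
#       dtype=torch.uint8,
#       scale=1.0 / 256.0,
#       zero_point=0,
#       quant_min=0,
#       quant_max=255,
#       qscheme=torch.per_tensor_affine,
#   )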

"""
We refer to other points of quantization in the graph as either an input edge
or an output value:
an input edge is the connection between an input node and the node consuming
the input, so it is a Tuple[Node, Node];
an output value is an fx Node
"""
EdgeOrNode = Union[Tuple[Node, Node], Node]
EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer"
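# For example, given `conv = conv2d(x, w)` in a traced graph (node names are
# illustrative), the input edge for `x` is `(x_node, conv_node)`, while the
# output value of the convolution is `conv_node` itself.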


@dataclass(eq=True, frozen=True)
class SharedQuantizationSpec(QuantizationSpecBase):
    """
    Quantization spec for the Tensors whose quantization parameters are shared with other Tensors
    """

    # the edge or node to share observer or fake quant instances with
    edge_or_node: EdgeOrNode

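# Example (a minimal sketch; `x_node` and `add_node` are hypothetical nodes in a
# traced graph): for `add = x + y`, the second input can reuse the observer of the
# first input edge so both inputs get identical qparams:
#
#   y_qspec = SharedQuantizationSpec((x_node, add_node))  # share with the x input edge
#   out_qspec = SharedQuantizationSpec(add_node)          # or share with add's output value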

@dataclass(eq=True, frozen=True)
class DerivedQuantizationSpec(QuantizationSpecBase):
    """Quantization spec for the Tensors whose quantization parameters are derived from other Tensors"""

    derived_from: List[EdgeOrNode]
    derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]]
    dtype: torch.dtype
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None
    is_dynamic: bool = False

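# Example (a minimal sketch; the node names are hypothetical): a conv bias is
# commonly quantized to int32 with scale = act_scale * weight_scale, derived from
# the observers of the activation and weight input edges:
#
#   def derive_bias_qparams(obs_or_fqs):
#       act_obs, weight_obs = obs_or_fqs
#       act_scale, _ = act_obs.calculate_qparams()
#       weight_scale, _ = weight_obs.calculate_qparams()
#       scale = act_scale * weight_scale
#       return scale, torch.zeros_like(scale, dtype=torch.int32)
#
#   bias_qspec = DerivedQuantizationSpec(
#       derived_from=[(act_node, conv_node), (weight_node, conv_node)],
#       derive_qparams_fn=derive_bias_qparams,
#       dtype=torch.int32,
#       quant_min=-(2**31),
#       quant_max=2**31 - 1,
#       qscheme=torch.per_tensor_symmetric,
#   )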

@dataclass
class QuantizationAnnotation:
    """Describes how the input arguments and output of a node should be quantized,
    expressed as QuantizationSpecs; this corresponds to how a Tensor in the
    operator Graph is observed (PTQ) or fake quantized (QAT)
    """

    # a map from a torch.fx.Node (an input of the annotated node) to its QuantizationSpecBase
    input_qspec_map: Dict[Node, Optional[QuantizationSpecBase]] = field(
        default_factory=dict
    )

    # How the output of this node is quantized, expressed as a QuantizationSpec
    # TODO: change the value to QuantizationSpec in a separate PR
    output_qspec: Optional[QuantizationSpecBase] = None

    # For a Node node1 and an edge (node1, node2): since both observe the same
    # Tensor, we may want to implicitly share observers; this flag allows users to
    # turn off this behavior for the output of the node
    allow_implicit_sharing: bool = True

    # whether the node is annotated or not
    _annotated: bool = False

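# Example (a minimal sketch; node names and the qspecs are hypothetical, and the
# "quantization_annotation" meta key follows the pattern used by quantizers built
# on this API): annotating a conv node's inputs and output:
#
#   conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
#       input_qspec_map={
#           act_node: act_qspec,
#           weight_node: weight_qspec,
#           bias_node: bias_qspec,
#       },
#       output_qspec=act_qspec,
#       _annotated=True,
#   )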

class Quantizer(ABC):
    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Allows user-defined transforms to run before the graph is annotated,
        so a quantizer can support quantizing parts of the model that would
        otherwise not be quantizable. For example, a quantizer can
        a) decompose a compound operator like scaled dot product attention
        into bmm and softmax, if it knows how to quantize bmm/softmax but not sdpa,
        or b) transform scalars to tensors so that the scalars can be quantized.

        Note: this is an optional method
        """
        return model

    # annotate nodes in the graph with observer or fake quant constructors
    # to convey the desired way of quantization
    @abstractmethod
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        pass

    # validate that the annotated graph is supported by the backend
    @abstractmethod
    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
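

# Example (a minimal sketch of a backend quantizer; the matching logic and the
# `act_qspec`/`weight_qspec` specs are illustrative, not a real backend):
#
#   class MyBackendQuantizer(Quantizer):
#       def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
#           for node in model.graph.nodes:
#               if node.target == torch.ops.aten.conv2d.default:
#                   node.meta["quantization_annotation"] = QuantizationAnnotation(
#                       input_qspec_map={
#                           node.args[0]: act_qspec,
#                           node.args[1]: weight_qspec,
#                       },
#                       output_qspec=act_qspec,
#                       _annotated=True,
#                   )
#           return model
#
#       def validate(self, model: torch.fx.GraphModule) -> None:
#           pass  # optionally check that all annotations are supported by the backend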