# mypy: allow-untyped-defs
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.ao.quantization import ObserverOrFakeQuantize
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.fx import Node


__all__ = [
    "Quantizer",
    "QuantizationSpecBase",
    "QuantizationSpec",
    "FixedQParamsQuantizationSpec",
    "EdgeOrNode",
    "SharedQuantizationSpec",
    "DerivedQuantizationSpec",
    "QuantizationAnnotation",
]


class QuantizationSpecBase(ABC):  # noqa: B024
    """Base class for the different types of quantization specs that allow users
    to specify how to quantize a Tensor (input/output of a Node) in the model.
    """


@dataclass(eq=True, frozen=True)
class QuantizationSpec(QuantizationSpecBase):
    """Quantization spec for common operators that allows users to specify how a
    Tensor should be quantized: dtype, quant_min, quant_max, qscheme, etc.
    """

    dtype: torch.dtype
    # observer or fake_quantize constructor such as
    # MinMaxObserver, PerChannelMinMaxObserver, HistogramObserver etc.,
    # possibly with custom args attached,
    # e.g. MinMaxObserver.with_args(eps=eps)
    observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None
    is_dynamic: bool = False

    def __post_init__(self):
        # TODO: also validate quant_min/quant_max against dtype
        # quant_min must be less than or equal to quant_max
        if (
            self.quant_min is not None
            and self.quant_max is not None
            and self.quant_min > self.quant_max
        ):
            raise ValueError(
                f"quant_min {self.quant_min} must be <= quant_max {self.quant_max}."
            )

        # ch_axis must be less than the number of channels, which cannot be
        # checked here; just check that it is non-negative
        if self.ch_axis is not None and self.ch_axis < 0:
            raise ValueError("ch_axis must be non-negative.")
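
# A minimal usage sketch (illustrative, not part of this module): a per-tensor
# affine int8 activation spec. ``MinMaxObserver`` is a real observer from
# ``torch.ao.quantization.observer``; the variable name ``act_qspec`` and the
# ``eps`` value are hypothetical choices.
#
#     from torch.ao.quantization.observer import MinMaxObserver
#
#     act_qspec = QuantizationSpec(
#         dtype=torch.int8,
#         quant_min=-128,
#         quant_max=127,
#         qscheme=torch.per_tensor_affine,
#         is_dynamic=False,
#         observer_or_fake_quant_ctr=MinMaxObserver.with_args(eps=2**-12),
#     )
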
@dataclass(eq=True, frozen=True)
class FixedQParamsQuantizationSpec(QuantizationSpecBase):
    dtype: torch.dtype
    scale: float
    zero_point: int
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    is_dynamic: bool = False


"""
The way we refer to other points of quantization in the graph is either
an input edge or an output value:
an input edge is the connection between an input node and the node consuming
the input, so it's a Tuple[Node, Node]; an output value is an fx Node.
"""
EdgeOrNode = Union[Tuple[Node, Node], Node]
EdgeOrNode.__module__ = "torch.ao.quantization.quantizer.quantizer"


@dataclass(eq=True, frozen=True)
class SharedQuantizationSpec(QuantizationSpecBase):
    """
    Quantization spec for Tensors whose quantization parameters are shared with
    other Tensors.
    """

    # the edge or node to share observer or fake quant instances with
    edge_or_node: EdgeOrNode


@dataclass(eq=True, frozen=True)
class DerivedQuantizationSpec(QuantizationSpecBase):
    """Quantization spec for Tensors whose quantization parameters are derived
    from other Tensors."""

    derived_from: List[EdgeOrNode]
    derive_qparams_fn: Callable[[List[ObserverOrFakeQuantize]], Tuple[Tensor, Tensor]]
    dtype: torch.dtype
    quant_min: Optional[int] = None
    quant_max: Optional[int] = None
    qscheme: Optional[torch.qscheme] = None
    ch_axis: Optional[int] = None
    is_dynamic: bool = False
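
# A usage sketch (illustrative, not part of this module): deriving conv bias
# qparams from the activation and weight observers, a common pattern in backend
# quantizers. ``act_node``, ``weight_node`` and ``conv_node`` are hypothetical
# fx Nodes from an annotated graph.
#
#     def derive_bias_qparams_fn(obs_or_fqs):
#         # observers arrive in the same order as the ``derived_from`` entries:
#         # activation edge first, then weight edge
#         act_obs, weight_obs = obs_or_fqs
#         act_scale, _ = act_obs.calculate_qparams()
#         weight_scale, _ = weight_obs.calculate_qparams()
#         # bias scale = activation scale * weight scale, zero_point fixed at 0
#         return act_scale * weight_scale, torch.zeros_like(act_scale, dtype=torch.int32)
#
#     bias_qspec = DerivedQuantizationSpec(
#         derived_from=[(act_node, conv_node), (weight_node, conv_node)],
#         derive_qparams_fn=derive_bias_qparams_fn,
#         dtype=torch.int32,
#         quant_min=-(2**31),
#         quant_max=2**31 - 1,
#         qscheme=torch.per_tensor_symmetric,
#     )
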
@dataclass
class QuantizationAnnotation:
    """How the input arguments and output of a node should be quantized,
    expressed as QuantizationSpecs; this corresponds to how a Tensor in the
    operator Graph is observed (PTQ) or fake quantized (QAT).
    """

    # a map from torch.fx.Node to a type of QuantizationSpecBase
    input_qspec_map: Dict[Node, Optional[QuantizationSpecBase]] = field(
        default_factory=dict
    )

    # How the output of this node is quantized, expressed as QuantizationSpec
    # TODO: change the value to QuantizationSpec in a separate PR
    output_qspec: Optional[QuantizationSpecBase] = None

    # For a Node: node1 and an edge: (node1, node2), since they observe the same
    # Tensor, we may want to implicitly share observers; this flag allows users
    # to turn off this behavior for the output of the node
    allow_implicit_sharing: bool = True

    # whether the node is annotated or not
    _annotated: bool = False


class Quantizer(ABC):
    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Allows for user-defined transforms to run before annotating the graph.
        This lets a quantizer quantize parts of the model that would otherwise
        not be quantizable. For example, a quantizer can
        a) decompose a compound operator like scaled dot product attention (sdpa)
           into bmm and softmax, if it knows how to quantize bmm/softmax but not
           sdpa, or
        b) transform scalars to tensors to allow quantizing scalars.

        Note: this is an optional method
        """
        return model

    # annotate nodes in the graph with observer or fake quant constructors
    # to convey the desired way of quantization
    @abstractmethod
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        pass

    # validate that the annotated graph is supported by the backend
    @abstractmethod
    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
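
# A minimal end-to-end sketch (illustrative, not part of this module): a backend
# Quantizer that annotates ``aten.add.Tensor`` nodes, quantizing both inputs and
# sharing the output qparams with the first input via SharedQuantizationSpec.
# ``_INT8_QSPEC`` is a hypothetical module-level QuantizationSpec like the one
# sketched after QuantizationSpec above.
#
#     class MyBackendQuantizer(Quantizer):
#         def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
#             for node in model.graph.nodes:
#                 if node.target == torch.ops.aten.add.Tensor:
#                     in0, in1 = node.args[0], node.args[1]
#                     node.meta["quantization_annotation"] = QuantizationAnnotation(
#                         input_qspec_map={in0: _INT8_QSPEC, in1: _INT8_QSPEC},
#                         # reuse the qparams observed on the (in0, node) edge
#                         output_qspec=SharedQuantizationSpec((in0, node)),
#                         _annotated=True,
#                     )
#             return model
#
#         def validate(self, model: torch.fx.GraphModule) -> None:
#             pass  # no backend constraints checked in this sketch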