from __future__ import annotations

from typing import Dict, List, TYPE_CHECKING

from .quantizer import QuantizationAnnotation, Quantizer


if TYPE_CHECKING:
    import torch
    from torch.fx import Node

__all__ = [
    "ComposableQuantizer",
]


class ComposableQuantizer(Quantizer):
    """
    Combine several :class:`Quantizer` instances into a single quantizer.

    This lets a model be quantized by multiple cooperating quantizers, e.g. one
    quantizer that handles embeddings, another that handles linear layers, and a
    catch-all backend quantizer for everything else.

    The constructor takes a list of ``Quantizer`` objects; order matters, since
    the quantizers are applied in exactly that order.

    Example:
    ```
    embedding_quantizer = EmbeddingQuantizer()
    linear_quantizer = MyLinearQuantizer()
    xnnpack_quantizer = XNNPackQuantizer()  # to handle ops not quantized by previous two quantizers
    composed_quantizer = ComposableQuantizer([embedding_quantizer, linear_quantizer, xnnpack_quantizer])
    prepared_m = prepare_pt2e(model, composed_quantizer)
    ```
    """

    def __init__(self, quantizers: List[Quantizer]):
        super().__init__()
        self.quantizers = quantizers
        # Maps each annotated node to the exact QuantizationAnnotation object a
        # prior quantizer attached, so later quantizers can be checked for
        # tampering (see _record_and_validate_annotations).
        self._graph_annotations: Dict[Node, QuantizationAnnotation] = {}

    def _record_and_validate_annotations(
        self, gm: torch.fx.GraphModule, quantizer: Quantizer
    ) -> None:
        """Record new annotations and verify `quantizer` left prior ones intact.

        Raises RuntimeError if `quantizer` replaced (different object identity)
        or removed an annotation that an earlier quantizer had placed.
        """
        for node in gm.graph.nodes:
            previously_seen = node in self._graph_annotations

            if "quantization_annotation" not in node.meta:
                # A node that was annotated before must stay annotated.
                if previously_seen:
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {node}"
                    )
                continue

            current = node.meta["quantization_annotation"]
            if not previously_seen:
                # First time we see an annotation on this node: snapshot it.
                self._graph_annotations[node] = current
            elif self._graph_annotations[node] is not current:
                # Identity comparison: any replacement of the annotation object
                # counts as a change, even if the contents look equal.
                raise RuntimeError(
                    f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {node}"
                )

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Run each quantizer's annotate pass in order (global spec only for now)."""
        for sub_quantizer in self.quantizers:
            sub_quantizer.annotate(model)
            self._record_and_validate_annotations(model, sub_quantizer)
        return model

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Chain every quantizer's pre-annotation transform, in composition order."""
        for sub_quantizer in self.quantizers:
            model = sub_quantizer.transform_for_annotation(model)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        # Validation is delegated to nothing for now; intentionally a no-op.
        pass