# torch/ao/quantization/quantizer/composable_quantizer.py
from __future__ import annotations

from typing import Dict, List, TYPE_CHECKING

from .quantizer import QuantizationAnnotation, Quantizer


if TYPE_CHECKING:
    import torch
    from torch.fx import Node

__all__ = [
    "ComposableQuantizer",
]


class ComposableQuantizer(Quantizer):
    """
    ComposableQuantizer allows users to combine more than one quantizer into a single quantizer.
    This allows users to quantize a model with multiple quantizers. E.g., embedding quantization
    may be supported by one quantizer while linear layers and other ops might be supported by
    another quantizer.

    ComposableQuantizer is initialized with a list of `Quantizer` instances.
    The order of the composition matters since that is the order in which the quantizers will be
    applied.
    Example:
    ```
    embedding_quantizer = EmbeddingQuantizer()
    linear_quantizer = MyLinearQuantizer()
    xnnpack_quantizer = XNNPACKQuantizer()  # to handle ops not quantized by previous two quantizers
    composed_quantizer = ComposableQuantizer([embedding_quantizer, linear_quantizer, xnnpack_quantizer])
    prepared_m = prepare_pt2e(model, composed_quantizer)
    ```
    """

    def __init__(self, quantizers: List[Quantizer]):
        super().__init__()
        self.quantizers = quantizers
        self._graph_annotations: Dict[Node, QuantizationAnnotation] = {}

    def _record_and_validate_annotations(
        self, gm: torch.fx.GraphModule, quantizer: Quantizer
    ) -> None:
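        """Record each node's QuantizationAnnotation and verify that `quantizer`
        neither replaced nor removed annotations produced by earlier quantizers
        (detected by comparing annotation object ids)."""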
        for n in gm.graph.nodes:
            if "quantization_annotation" in n.meta:
                # check if the annotation has been changed by
                # comparing QuantizationAnnotation object id
                if n in self._graph_annotations and (
                    id(self._graph_annotations[n])
                    != id(n.meta["quantization_annotation"])
                ):
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has changed annotations on node {n}"
                    )
                else:
                    self._graph_annotations[n] = n.meta["quantization_annotation"]
            else:
                if n in self._graph_annotations:
                    raise RuntimeError(
                        f"Quantizer {quantizer.__class__.__name__} has removed annotations on node {n}"
                    )

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Annotate the model with each quantizer in composition order (global
        spec handling only for now), validating annotations after each pass."""
        for quantizer in self.quantizers:
            quantizer.annotate(model)
            self._record_and_validate_annotations(model, quantizer)
        return model

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
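        """Run each quantizer's transform_for_annotation pass in composition order."""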
        for quantizer in self.quantizers:
            model = quantizer.transform_for_annotation(model)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
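        # No-op: the composed quantizer performs no extra model validation itself.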
        pass
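

# Illustrative usage sketch (not part of the original module): a minimal,
# hedged example of composing quantizers for the PT2E flow, expanding on the
# class docstring. EmbeddingQuantizer, XNNPACKQuantizer,
# get_symmetric_quantization_config, and prepare_pt2e are existing
# torch.ao.quantization APIs; the toy model `_DemoModel` is made up here, and
# the export call assumes a PyTorch version that provides
# torch.export.export_for_training (the capture API has changed across
# releases).
if __name__ == "__main__":
    import torch
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e
    from torch.ao.quantization.quantizer.embedding_quantizer import (
        EmbeddingQuantizer,
    )
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
        XNNPACKQuantizer,
    )

    class _DemoModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.emb = torch.nn.Embedding(10, 4)
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(self.emb(x))

    m = _DemoModel().eval()
    example_inputs = (torch.randint(0, 10, (1, 3)),)

    # Order matters: EmbeddingQuantizer annotates embedding ops first; the
    # XNNPACKQuantizer then handles ops not claimed by the earlier quantizer.
    composed_quantizer = ComposableQuantizer(
        [
            EmbeddingQuantizer(),
            XNNPACKQuantizer().set_global(get_symmetric_quantization_config()),
        ]
    )

    exported = torch.export.export_for_training(m, example_inputs).module()
    prepared_m = prepare_pt2e(exported, composed_quantizer)  # inserts observers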
80