# mypy: allow-untyped-defs
from __future__ import annotations

import copy
from typing import List, Set

import torch
import torch.nn.functional as F
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OperatorConfig,
    OperatorPatternType,
    QuantizationConfig,
)


__all__ = [
    "get_embedding_operators_config",
    "EmbeddingQuantizer",
]

def get_embedding_operators_config() -> OperatorConfig:
    """Return an OperatorConfig that applies weight-only, per-channel uint8
    quantization (float qparams) to embedding operators."""
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=torch.per_channel_affine_float_qparams,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(eps=2**-12),
    )
    quantization_config = QuantizationConfig(None, None, weight_quantization_spec, None)
    ops: List[OperatorPatternType] = [[torch.nn.Embedding]]
    ops.append([F.embedding])
    supported_config_and_operators = OperatorConfig(
        config=quantization_config, operators=ops
    )
    return copy.deepcopy(supported_config_and_operators)
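
# Illustration (a sketch; `cfg` is just a local name used here): the config
# above is weight-only, so the activation and bias specs are left as None, e.g.
#
#     cfg = get_embedding_operators_config()
#     assert cfg.config.weight.dtype == torch.uint8
#     assert cfg.config.weight.qscheme == torch.per_channel_affine_float_qparams
#     assert cfg.config.input_activation is None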


class EmbeddingQuantizer(Quantizer):
    """Weight-only quantizer that annotates ``aten.embedding`` nodes with a
    per-channel uint8 (float qparams) weight quantization spec."""

    def __init__(self) -> None:
        super().__init__()

    @classmethod
    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
        """Return the distinct quantization configs across the supported operators."""
        op_configs: Set[QuantizationConfig] = {
            spec for spec, _ in cls.get_supported_operators()
        }
        return list(op_configs)

    @classmethod
    def get_supported_operator_for_quantization_config(
        cls, quantization_config: QuantizationConfig
    ) -> List[OperatorPatternType]:
        """Return the operator patterns registered for the given quantization config."""
        for config, ops in cls.get_supported_operators():
            # note: this assumes each entry returned by cls.get_supported_operators()
            # corresponds to exactly one config, i.e. we don't have
            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
            # where the first and second entries share the same config but their
            # op lists were not merged
            if config == quantization_config:
                return ops
        return []

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Annotate embedding ops in the model graph; only the global spec is handled for now."""
        self._annotate_embedding_ops(model.graph)
        return model

    def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None:
        embedding_config: OperatorConfig = get_embedding_operators_config()
        for node in graph.nodes:
            # Keep node-parsing-based annotation here (rather than module
            # partitioners) as an example of an alternate way of annotating.
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.embedding.default
            ):
                if embedding_config.config.weight is None:
                    raise ValueError(
                        "Embedding config must have a valid weight quantization spec."
                    )
                node.meta["quantization_annotation"] = QuantizationAnnotation(
                    input_qspec_map={
                        # args[0] of aten.embedding.default is the weight tensor
                        node.args[0]: embedding_config.config.weight,
                    }
                )

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return [get_embedding_operators_config()]
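
# Usage sketch (illustrative only): `MyEmbeddingModel` and `example_inputs` are
# placeholder names, and the exact export/capture API depends on the PyTorch
# version in use.
#
#     import torch
#     from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#
#     model = MyEmbeddingModel().eval()
#     example_inputs = (torch.randint(0, 10, (1, 8)),)
#     exported = torch.export.export_for_training(model, example_inputs).module()
#
#     prepared = prepare_pt2e(exported, EmbeddingQuantizer())
#     prepared(*example_inputs)  # calibrate observers with sample data
#     quantized = convert_pt2e(prepared)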