# mypy: allow-untyped-defs
from __future__ import annotations

import copy
from typing import List, Set

import torch
import torch.nn.functional as F
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OperatorConfig,
    OperatorPatternType,
    QuantizationConfig,
)


__all__ = [
    "get_embedding_operators_config",
    "EmbeddingQuantizer",
]


def get_embedding_operators_config() -> OperatorConfig:
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=torch.per_channel_affine_float_qparams,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(eps=2**-12),
    )
    quantization_config = QuantizationConfig(None, None, weight_quantization_spec, None)
    ops: List[OperatorPatternType] = [[torch.nn.Embedding]]
    ops.append([F.embedding])
    supported_config_and_operators = OperatorConfig(
        config=quantization_config, operators=ops
    )
    return copy.deepcopy(supported_config_and_operators)


class EmbeddingQuantizer(Quantizer):
    def __init__(self) -> None:
        super().__init__()

    @classmethod
    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
        op_configs: Set[QuantizationConfig] = {
            spec for spec, _ in cls.get_supported_operators()
        }
        return list(op_configs)

    @classmethod
    def get_supported_operator_for_quantization_config(
        cls, quantization_config: QuantizationConfig
    ) -> List[OperatorPatternType]:
        for config, ops in cls.get_supported_operators():
            # Note: this assumes each entry in cls.get_supported_operators()
            # corresponds to exactly one config, i.e. we never have
            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
            # where the first and second entries share a config but keep
            # separate op lists.
            if config == quantization_config:
                return ops
        return []

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """Annotate embedding ops; only the global spec is handled for now."""
        self._annotate_embedding_ops(model.graph)
        return model

    def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None:
        embedding_config: OperatorConfig = get_embedding_operators_config()
        for node in graph.nodes:
            # Keep node-parsing-based annotation (rather than module
            # partitioners) as an example of an alternate way of annotating.
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.embedding.default
            ):
                if embedding_config.config.weight is None:
                    raise ValueError(
                        "Embedding config must have a valid weight quantization spec."
                    )
                # aten.embedding takes the weight as its first argument, so the
                # weight spec is attached to node.args[0].
                node.meta["quantization_annotation"] = QuantizationAnnotation(
                    input_qspec_map={
                        node.args[0]: embedding_config.config.weight,
                    }
                )

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return [get_embedding_operators_config()]
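
# ---------------------------------------------------------------------------
# Qparams sketch (illustrative, not part of this module's API): what the
# weight spec in get_embedding_operators_config() asks the observer to
# compute. With qscheme=per_channel_affine_float_qparams and ch_axis=0, the
# observer derives one (float scale, float zero_point) pair per row of the
# embedding weight, roughly scale = (max - min) / 255, floored by eps. The
# weight shape below is made up for the example.
# ---------------------------------------------------------------------------
def _example_weight_qparams() -> None:  # hypothetical helper, for illustration
    observer = PerChannelMinMaxObserver(
        ch_axis=0,
        dtype=torch.quint8,
        qscheme=torch.per_channel_affine_float_qparams,
        eps=2**-12,
    )
    weight = torch.randn(10, 4)  # (num_embeddings, embedding_dim)
    observer(weight)  # record per-row min/max
    scales, zero_points = observer.calculate_qparams()
    # One scale and one (float) zero point per embedding row.
    assert scales.shape == (10,) and zero_points.shape == (10,)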
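
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module's API): how a quantizer
# like EmbeddingQuantizer typically plugs into the PT2E flow. The model `M`,
# its shapes, and the use of torch.export.export_for_training are assumptions
# for this sketch; older releases captured the graph with
# capture_pre_autograd_graph instead.
# ---------------------------------------------------------------------------
def _example_usage() -> None:  # hypothetical helper, for illustration
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)

        def forward(self, indices: torch.Tensor) -> torch.Tensor:
            return self.emb(indices)

    example_inputs = (torch.randint(0, 10, (2, 3)),)
    # Export to an aten-level FX graph so that annotate() sees
    # torch.ops.aten.embedding.default nodes.
    exported = torch.export.export_for_training(M(), example_inputs).module()
    prepared = prepare_pt2e(exported, EmbeddingQuantizer())  # annotate + observe
    prepared(*example_inputs)  # calibration pass populates the observers
    converted = convert_pt2e(prepared)  # fold qparams into quantized ops
    converted(*example_inputs)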