xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/fx/custom_config.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4from dataclasses import dataclass
5from typing import Any, Dict, List, Optional, Tuple, Type
6
7from torch.ao.quantization import QConfigMapping
8from torch.ao.quantization.backend_config import BackendConfig
9from torch.ao.quantization.quant_type import (
10    _get_quant_type_to_str,
11    _quant_type_from_str,
12    QuantType,
13)
14
15
# Public names exported by this module.
__all__ = [
    "ConvertCustomConfig",
    "FuseCustomConfig",
    "PrepareCustomConfig",
    "StandaloneModuleConfigEntry",
]


# Keys of the legacy dictionary format read by the ``from_dict`` class methods
# and written by the ``to_dict`` methods of the config classes in this module.
# TODO: replace all usages with these constants
# Used by PrepareCustomConfig: standalone submodules identified by name / by class.
STANDALONE_MODULE_NAME_DICT_KEY = "standalone_module_name"
STANDALONE_MODULE_CLASS_DICT_KEY = "standalone_module_class"
# Used by PrepareCustomConfig: quantization mode -> {float class: observed class}.
FLOAT_TO_OBSERVED_DICT_KEY = "float_to_observed_custom_module_class"
# Used by ConvertCustomConfig: quantization mode -> {observed class: quantized class}.
OBSERVED_TO_QUANTIZED_DICT_KEY = "observed_to_quantized_custom_module_class"
# Used by PrepareCustomConfig: modules that are not symbolically traceable.
NON_TRACEABLE_MODULE_NAME_DICT_KEY = "non_traceable_module_name"
NON_TRACEABLE_MODULE_CLASS_DICT_KEY = "non_traceable_module_class"
# Used by PrepareCustomConfig: indexes of graph inputs / outputs to quantize.
INPUT_QUANTIZED_INDEXES_DICT_KEY = "input_quantized_idxs"
OUTPUT_QUANTIZED_INDEXES_DICT_KEY = "output_quantized_idxs"
# Used by PrepareCustomConfig, ConvertCustomConfig and FuseCustomConfig.
PRESERVED_ATTRIBUTES_DICT_KEY = "preserved_attributes"
34
35
@dataclass
class StandaloneModuleConfigEntry:
    """Settings for one standalone submodule registered on a ``PrepareCustomConfig``.

    ``PrepareCustomConfig.set_standalone_module_name`` and
    ``PrepareCustomConfig.set_standalone_module_class`` build these entries,
    passing their arguments positionally in field order.
    """

    # qconfig_mapping for the prepare function called in the submodule,
    # None means use qconfig from parent qconfig_mapping
    qconfig_mapping: Optional[QConfigMapping]
    # Example inputs for the standalone submodule; presumably used when the
    # submodule itself is traced/prepared — confirm at the call site.
    example_inputs: Tuple[Any, ...]
    # Custom config for preparing the submodule. Per the PrepareCustomConfig
    # setter docs, None means an empty PrepareCustomConfig is used.
    prepare_custom_config: Optional[PrepareCustomConfig]
    # Backend config for the submodule. Per the PrepareCustomConfig setter
    # docs, None means the parent backend_config is used.
    backend_config: Optional[BackendConfig]
44
45
class PrepareCustomConfig:
    """
    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and
    :func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`.

    Each ``set_*`` method updates this config in place and returns ``self``, so calls
    can be chained.

    Example usage::

        prepare_custom_config = (
            PrepareCustomConfig()
            .set_standalone_module_name(
                "module1", qconfig_mapping, example_inputs,
                child_prepare_custom_config, backend_config,
            )
            .set_standalone_module_class(
                MyStandaloneModule, qconfig_mapping, example_inputs,
                child_prepare_custom_config, backend_config,
            )
            .set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule)
            .set_non_traceable_module_names(["module2", "module3"])
            .set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2])
            .set_input_quantized_indexes([0])
            .set_output_quantized_indexes([0])
            .set_preserved_attributes(["attr1", "attr2"])
        )
    """

    def __init__(self) -> None:
        # NOTE: attribute assignment order determines the field order shown by __repr__.
        # Standalone submodules keyed by module name.
        self.standalone_module_names: Dict[str, StandaloneModuleConfigEntry] = {}
        # Standalone submodules keyed by module class.
        self.standalone_module_classes: Dict[Type, StandaloneModuleConfigEntry] = {}
        # QuantType -> {float custom module class: observed custom module class}
        self.float_to_observed_mapping: Dict[QuantType, Dict[Type, Type]] = {}
        self.non_traceable_module_names: List[str] = []
        self.non_traceable_module_classes: List[Type] = []
        self.input_quantized_indexes: List[int] = []
        self.output_quantized_indexes: List[int] = []
        self.preserved_attributes: List[str] = []

    def __repr__(self):
        # Only attributes that currently hold at least one element are shown.
        populated = {
            attr: value for attr, value in vars(self).items() if len(value) > 0
        }
        return f"PrepareCustomConfig({populated})"

    def set_standalone_module_name(
        self,
        module_name: str,
        qconfig_mapping: Optional[QConfigMapping],
        example_inputs: Tuple[Any, ...],
        prepare_custom_config: Optional[PrepareCustomConfig],
        backend_config: Optional[BackendConfig],
    ) -> PrepareCustomConfig:
        """
        Register a standalone module, looked up by ``module_name``, together with the
        settings used to prepare it.

        A None ``qconfig_mapping`` falls back to the parent ``qconfig_mapping``.
        A None ``prepare_custom_config`` means an empty ``PrepareCustomConfig``.
        A None ``backend_config`` falls back to the parent ``backend_config``.
        Registering the same name again replaces the earlier entry.
        """
        entry = StandaloneModuleConfigEntry(
            qconfig_mapping=qconfig_mapping,
            example_inputs=example_inputs,
            prepare_custom_config=prepare_custom_config,
            backend_config=backend_config,
        )
        self.standalone_module_names[module_name] = entry
        return self

    def set_standalone_module_class(
        self,
        module_class: Type,
        qconfig_mapping: Optional[QConfigMapping],
        example_inputs: Tuple[Any, ...],
        prepare_custom_config: Optional[PrepareCustomConfig],
        backend_config: Optional[BackendConfig],
    ) -> PrepareCustomConfig:
        """
        Register a standalone module, looked up by ``module_class``, together with the
        settings used to prepare it.

        A None ``qconfig_mapping`` falls back to the parent ``qconfig_mapping``.
        A None ``prepare_custom_config`` means an empty ``PrepareCustomConfig``.
        A None ``backend_config`` falls back to the parent ``backend_config``.
        Registering the same class again replaces the earlier entry.
        """
        entry = StandaloneModuleConfigEntry(
            qconfig_mapping=qconfig_mapping,
            example_inputs=example_inputs,
            prepare_custom_config=prepare_custom_config,
            backend_config=backend_config,
        )
        self.standalone_module_classes[module_class] = entry
        return self

    def set_float_to_observed_mapping(
        self,
        float_class: Type,
        observed_class: Type,
        quant_type: QuantType = QuantType.STATIC,
    ) -> PrepareCustomConfig:
        """
        Map a custom float module class to its custom observed module class.

        ``observed_class`` must provide a ``from_float`` class method that turns the float
        module into the observed module. Only ``QuantType.STATIC`` is supported for now.

        Raises:
            ValueError: if ``quant_type`` is anything other than ``QuantType.STATIC``.
        """
        if quant_type != QuantType.STATIC:
            raise ValueError(
                "set_float_to_observed_mapping is currently only supported for static quantization"
            )
        per_type = self.float_to_observed_mapping.setdefault(quant_type, {})
        per_type[float_class] = observed_class
        return self

    def set_non_traceable_module_names(
        self, module_names: List[str]
    ) -> PrepareCustomConfig:
        """
        Replace the list of module names that cannot be symbolically traced.
        The given list object is stored as-is, not copied.
        """
        self.non_traceable_module_names = module_names
        return self

    def set_non_traceable_module_classes(
        self, module_classes: List[Type]
    ) -> PrepareCustomConfig:
        """
        Replace the list of module classes that cannot be symbolically traced.
        The given list object is stored as-is, not copied.
        """
        self.non_traceable_module_classes = module_classes
        return self

    def set_input_quantized_indexes(self, indexes: List[int]) -> PrepareCustomConfig:
        """
        Replace the indexes of graph inputs that should be quantized; every other
        input is assumed to be fp32.
        """
        self.input_quantized_indexes = indexes
        return self

    def set_output_quantized_indexes(self, indexes: List[int]) -> PrepareCustomConfig:
        """
        Replace the indexes of graph outputs that should be quantized; every other
        output is assumed to be fp32.
        """
        self.output_quantized_indexes = indexes
        return self

    def set_preserved_attributes(self, attributes: List[str]) -> PrepareCustomConfig:
        """
        Replace the names of attributes that must survive on the graph module even
        though the model's ``forward`` method does not use them.
        """
        self.preserved_attributes = attributes
        return self

    # TODO: remove this
    @classmethod
    def from_dict(
        cls, prepare_custom_config_dict: Dict[str, Any]
    ) -> PrepareCustomConfig:
        """
        Build a ``PrepareCustomConfig`` from the legacy dictionary format.

        Every key is optional; a missing key leaves the matching setting empty:

            "standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs,
            child_prepare_custom_config, backend_config) tuples

            "standalone_module_class": a list of (module_class, qconfig_mapping, example_inputs,
            child_prepare_custom_config, backend_config) tuples

            In both standalone lists, qconfig_mapping, child_prepare_custom_config and
            backend_config may each be None, an instance of the matching class, or a plain
            dict that is converted with that class's ``from_dict``.

            "float_to_observed_custom_module_class": a nested dictionary mapping from quantization
            mode to an inner mapping from float module classes to observed module classes, e.g.
            {"static": {FloatCustomModule: ObservedCustomModule}}

            "non_traceable_module_name": a list of module names that are not symbolically traceable
            "non_traceable_module_class": a list of module classes that are not symbolically traceable
            "input_quantized_idxs": a list of indexes of graph inputs that should be quantized
            "output_quantized_idxs": a list of indexes of graph outputs that should be quantized
            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``

        Raises:
            ValueError: if a standalone entry holds a qconfig mapping, child config or backend
                config of an unsupported type, or if "float_to_observed_custom_module_class"
                names a non-static quantization mode.

        This function is primarily for backward compatibility and may be removed in the future.
        """

        def _as_qconfig_mapping(obj: Any, dict_key: str) -> Optional[QConfigMapping]:
            """Return ``obj`` as a QConfigMapping (or None), converting a dict if needed."""
            if obj is None or isinstance(obj, QConfigMapping):
                return obj
            if not isinstance(obj, dict):
                raise ValueError(
                    f"Expected QConfigMapping in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'"
                )
            return QConfigMapping.from_dict(obj)

        def _as_prepare_custom_config(
            obj: Any, dict_key: str
        ) -> Optional[PrepareCustomConfig]:
            """Return ``obj`` as a PrepareCustomConfig (or None), converting a dict if needed."""
            if obj is None or isinstance(obj, PrepareCustomConfig):
                return obj
            if not isinstance(obj, dict):
                raise ValueError(
                    f"Expected PrepareCustomConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'"
                )
            return PrepareCustomConfig.from_dict(obj)

        def _as_backend_config(obj: Any, dict_key: str) -> Optional[BackendConfig]:
            """Return ``obj`` as a BackendConfig (or None), converting a dict if needed."""
            if obj is None or isinstance(obj, BackendConfig):
                return obj
            if not isinstance(obj, dict):
                raise ValueError(
                    f"Expected BackendConfig in prepare_custom_config_dict[\"{dict_key}\"], got '{type(obj)}'"
                )
            return BackendConfig.from_dict(obj)

        conf = cls()

        # Standalone modules: "by name" entries are registered before "by class" entries.
        standalone_registrars = (
            (STANDALONE_MODULE_NAME_DICT_KEY, conf.set_standalone_module_name),
            (STANDALONE_MODULE_CLASS_DICT_KEY, conf.set_standalone_module_class),
        )
        for dict_key, register in standalone_registrars:
            for (
                module_key,
                qconfig_obj,
                example_inputs,
                child_config_obj,
                backend_config_obj,
            ) in prepare_custom_config_dict.get(dict_key, []):
                # Arguments evaluate left to right, so conversion errors surface in
                # qconfig -> child config -> backend config order.
                register(
                    module_key,
                    _as_qconfig_mapping(qconfig_obj, dict_key),
                    example_inputs,
                    _as_prepare_custom_config(child_config_obj, dict_key),
                    _as_backend_config(backend_config_obj, dict_key),
                )

        float_to_observed = prepare_custom_config_dict.get(FLOAT_TO_OBSERVED_DICT_KEY, {})
        for quant_type_name, custom_module_mapping in float_to_observed.items():
            quant_type = _quant_type_from_str(quant_type_name)
            for float_class, observed_class in custom_module_mapping.items():
                conf.set_float_to_observed_mapping(float_class, observed_class, quant_type)

        # Plain list-valued settings; an absent key resets the setting to [].
        list_setters = (
            (NON_TRACEABLE_MODULE_NAME_DICT_KEY, conf.set_non_traceable_module_names),
            (NON_TRACEABLE_MODULE_CLASS_DICT_KEY, conf.set_non_traceable_module_classes),
            (INPUT_QUANTIZED_INDEXES_DICT_KEY, conf.set_input_quantized_indexes),
            (OUTPUT_QUANTIZED_INDEXES_DICT_KEY, conf.set_output_quantized_indexes),
            (PRESERVED_ATTRIBUTES_DICT_KEY, conf.set_preserved_attributes),
        )
        for dict_key, setter in list_setters:
            setter(prepare_custom_config_dict.get(dict_key, []))
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``PrepareCustomConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`.

        Empty settings are omitted. Nested qconfig mappings and child configs are converted
        with their own ``to_dict``. The backend config is emitted as the object itself, which
        ``from_dict`` accepts. Lists and inner mappings are returned without copying.
        """

        def _entry_to_tuple(key: Any, entry: StandaloneModuleConfigEntry):
            if entry.qconfig_mapping:
                qconfig_dict = entry.qconfig_mapping.to_dict()
            else:
                qconfig_dict = None
            if entry.prepare_custom_config:
                child_dict = entry.prepare_custom_config.to_dict()
            else:
                child_dict = None
            return (
                key,
                qconfig_dict,
                entry.example_inputs,
                child_dict,
                entry.backend_config,
            )

        result: Dict[str, Any] = {}
        if self.standalone_module_names:
            result[STANDALONE_MODULE_NAME_DICT_KEY] = [
                _entry_to_tuple(name, entry)
                for name, entry in self.standalone_module_names.items()
            ]
        if self.standalone_module_classes:
            result[STANDALONE_MODULE_CLASS_DICT_KEY] = [
                _entry_to_tuple(klass, entry)
                for klass, entry in self.standalone_module_classes.items()
            ]
        if self.float_to_observed_mapping:
            result[FLOAT_TO_OBSERVED_DICT_KEY] = {
                _get_quant_type_to_str(quant_type): mapping
                for quant_type, mapping in self.float_to_observed_mapping.items()
            }

        list_fields = (
            (NON_TRACEABLE_MODULE_NAME_DICT_KEY, self.non_traceable_module_names),
            (NON_TRACEABLE_MODULE_CLASS_DICT_KEY, self.non_traceable_module_classes),
            (INPUT_QUANTIZED_INDEXES_DICT_KEY, self.input_quantized_indexes),
            (OUTPUT_QUANTIZED_INDEXES_DICT_KEY, self.output_quantized_indexes),
            (PRESERVED_ATTRIBUTES_DICT_KEY, self.preserved_attributes),
        )
        for dict_key, values in list_fields:
            if len(values) > 0:
                result[dict_key] = values
        return result
372
373
class ConvertCustomConfig:
    """
    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`.

    Each ``set_*`` method updates this config in place and returns ``self``, so calls
    can be chained.

    Example usage::

        convert_custom_config = (
            ConvertCustomConfig()
            .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule)
            .set_preserved_attributes(["attr1", "attr2"])
        )
    """

    def __init__(self) -> None:
        # NOTE: attribute assignment order determines the field order shown by __repr__.
        # QuantType -> {observed custom module class: quantized custom module class}
        self.observed_to_quantized_mapping: Dict[QuantType, Dict[Type, Type]] = {}
        self.preserved_attributes: List[str] = []

    def __repr__(self):
        # Only attributes that currently hold at least one element are shown.
        dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0}
        return f"ConvertCustomConfig({dict_nonempty})"

    def set_observed_to_quantized_mapping(
        self,
        observed_class: Type,
        quantized_class: Type,
        quant_type: QuantType = QuantType.STATIC,
    ) -> ConvertCustomConfig:
        """
        Set the mapping from a custom observed module class to a custom quantized module class.

        The quantized module class must have a ``from_observed`` class method that converts the
        observed module into the quantized module. Mappings are stored separately per
        ``quant_type``; registering the same ``observed_class`` again for the same
        ``quant_type`` replaces the earlier mapping.
        """
        per_type = self.observed_to_quantized_mapping.setdefault(quant_type, {})
        per_type[observed_class] = quantized_class
        return self

    def set_preserved_attributes(self, attributes: List[str]) -> ConvertCustomConfig:
        """
        Set the names of the attributes that will persist in the graph module even if they are not used in
        the model's ``forward`` method. The given list object is stored as-is, not copied.
        """
        self.preserved_attributes = attributes
        return self

    # TODO: remove this
    @classmethod
    def from_dict(
        cls, convert_custom_config_dict: Dict[str, Any]
    ) -> ConvertCustomConfig:
        """
        Create a ``ConvertCustomConfig`` from a dictionary with the following items (all optional):

            "observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization
            mode to an inner mapping from observed module classes to quantized module classes, e.g.::

                {
                    "static": {ObservedCustomModule: QuantizedCustomModule},
                    "dynamic": {ObservedCustomModule: QuantizedCustomModule},
                    "weight_only": {ObservedCustomModule: QuantizedCustomModule},
                }

            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``

        This function is primarily for backward compatibility and may be removed in the future.
        """
        conf = cls()
        for quant_type_name, custom_module_mapping in convert_custom_config_dict.get(
            OBSERVED_TO_QUANTIZED_DICT_KEY, {}
        ).items():
            # Mode names such as "static" are translated to QuantType members.
            quant_type = _quant_type_from_str(quant_type_name)
            for observed_class, quantized_class in custom_module_mapping.items():
                conf.set_observed_to_quantized_mapping(
                    observed_class, quantized_class, quant_type
                )
        conf.set_preserved_attributes(
            convert_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])
        )
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``ConvertCustomConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`.

        Empty settings are omitted. Inner mappings and the attribute list are returned
        without copying.
        """
        d: Dict[str, Any] = {}
        for (
            quant_type,
            observed_to_quantized_mapping,
        ) in self.observed_to_quantized_mapping.items():
            if OBSERVED_TO_QUANTIZED_DICT_KEY not in d:
                d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {}
            d[OBSERVED_TO_QUANTIZED_DICT_KEY][
                _get_quant_type_to_str(quant_type)
            ] = observed_to_quantized_mapping
        if len(self.preserved_attributes) > 0:
            d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
        return d
469
470
class FuseCustomConfig:
    """
    Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`.

    ``set_preserved_attributes`` updates this config in place and returns ``self``,
    so calls can be chained.

    Example usage::

        fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"])
    """

    def __init__(self) -> None:
        self.preserved_attributes: List[str] = []

    def __repr__(self):
        # Only attributes that currently hold at least one element are shown.
        dict_nonempty = {k: v for k, v in self.__dict__.items() if len(v) > 0}
        return f"FuseCustomConfig({dict_nonempty})"

    def set_preserved_attributes(self, attributes: List[str]) -> FuseCustomConfig:
        """
        Set the names of the attributes that will persist in the graph module even if they are not used in
        the model's ``forward`` method. The given list object is stored as-is, not copied.
        """
        self.preserved_attributes = attributes
        return self

    # TODO: remove this
    @classmethod
    def from_dict(cls, fuse_custom_config_dict: Dict[str, Any]) -> FuseCustomConfig:
        """
        Create a ``FuseCustomConfig`` from a dictionary with the following items (all optional):

            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``

        This function is primarily for backward compatibility and may be removed in the future.
        """
        conf = cls()
        conf.set_preserved_attributes(
            fuse_custom_config_dict.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])
        )
        return conf

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert this ``FuseCustomConfig`` to a dictionary with the items described in
        :func:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig.from_dict`.

        An empty attribute list is omitted; a non-empty one is returned without copying.
        """
        d: Dict[str, Any] = {}
        if len(self.preserved_attributes) > 0:
            d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes
        return d
520