# Source: /aosp_15_r20/external/pytorch/docs/source/scripts/build_quantization_configs.py
# (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""
This script generates the default values of the native quantization backend
configs and writes them to a text file for use in the documentation.
"""
5
6import os.path
7
8import torch
9from torch.ao.quantization.backend_config import get_native_backend_config_dict
10from torch.ao.quantization.backend_config.utils import (
11    entry_to_pretty_str,
12    remove_boolean_dispatch_from_name,
13)
14
15
# Output directory for the generated documentation artifacts. Joining ".." onto
# the script path and resolving it yields the directory containing this script,
# so the output directory sits next to it.
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
    os.path.realpath(os.path.join(__file__, "..")), "quantization_backend_configs"
)

# Create the output directory if it does not exist yet. ``exist_ok=True``
# avoids the check-then-create race of ``os.path.exists`` + ``os.mkdir``
# (e.g. two doc builds running at once) and also creates missing parents.
os.makedirs(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, exist_ok=True)

# Text file that the generated backend config listing is written to.
output_path = os.path.join(
    QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, "default_backend_config.txt"
)
27
# Query and format all configs *before* opening the output file. Opening with
# "w" truncates immediately, so doing the work inside the ``with`` block would
# leave an empty or partial file behind if anything below raised.
native_backend_config_dict = get_native_backend_config_dict()

configs = native_backend_config_dict["configs"]


def _sort_key_func(entry):
    """Return the sort key used to order backend config entries in the docs.

    The key is derived from ``entry["pattern"]``: tuple patterns are reduced
    to their (innermost) last element, the result is normalized with
    ``remove_boolean_dispatch_from_name``, converted to a string via
    ``torch.typename`` if needed, lower-cased, stripped of underscores, and
    cut down to its final dot-separated component.
    """
    pattern = entry["pattern"]
    # Repeatedly take the last element until the pattern is no longer a tuple.
    while isinstance(pattern, tuple):
        pattern = pattern[-1]

    pattern = remove_boolean_dispatch_from_name(pattern)
    if not isinstance(pattern, str):
        # methods are already strings
        pattern = torch.typename(pattern)

    # we want
    #
    #   torch.nn.modules.pooling.AdaptiveAvgPool1d
    #
    # and
    #
    #   torch._VariableFunctionsClass.adaptive_avg_pool1d
    #
    # to be next to each other, so convert to all lower case
    # and remove the underscores, and compare the last part
    # of the string
    pattern_str_normalized = pattern.lower().replace("_", "")
    return pattern_str_normalized.split(".")[-1]


configs.sort(key=_sort_key_func)

# One pretty-printed entry per config, separated by ",\n" (no trailing newline).
entries = ",\n".join(entry_to_pretty_str(entry) for entry in configs)

# Explicit encoding so the output does not depend on the build machine's locale.
with open(output_path, "w", encoding="utf-8") as f:
    f.write(entries)
65