"""
This script will generate default values of quantization configs.
These are for use in the documentation.

Running it writes ``default_backend_config.txt`` into a
``quantization_backend_configs`` directory located next to this file.
"""

import os.path

import torch
from torch.ao.quantization.backend_config import get_native_backend_config_dict
from torch.ao.quantization.backend_config.utils import (
    entry_to_pretty_str,
    remove_boolean_dispatch_from_name,
)


# Directory that receives the generated text file, located next to this script.
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
    os.path.realpath(os.path.join(__file__, "..")), "quantization_backend_configs"
)

# exist_ok=True replaces a separate "exists?" check followed by mkdir, which
# could fail if the directory appeared between the check and the creation.
os.makedirs(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, exist_ok=True)

output_path = os.path.join(
    QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, "default_backend_config.txt"
)


def _sort_key_func(entry):
    """Return the sort key for one entry of the backend config's "configs" list.

    Args:
        entry: a mapping with a "pattern" item. The pattern may be a
            (possibly nested) tuple, in which case its last element is used.

    Returns:
        The last dot-separated component of the pattern's name, lower-cased
        and with underscores removed.
    """
    pattern = entry["pattern"]
    # For tuple patterns, keep descending into the last element until a
    # non-tuple is reached.
    while isinstance(pattern, tuple):
        pattern = pattern[-1]

    pattern = remove_boolean_dispatch_from_name(pattern)
    if not isinstance(pattern, str):
        # methods are already strings
        pattern = torch.typename(pattern)

    # we want
    #
    #   torch.nn.modules.pooling.AdaptiveAvgPool1d
    #
    # and
    #
    #   torch._VariableFunctionsClass.adaptive_avg_pool1d
    #
    # to be next to each other, so convert to all lower case
    # and remove the underscores, and compare the last part
    # of the string
    pattern_str_normalized = pattern.lower().replace("_", "")
    return pattern_str_normalized.split(".")[-1]


native_backend_config_dict = get_native_backend_config_dict()
configs = native_backend_config_dict["configs"]
configs.sort(key=_sort_key_func)

# Render the complete text before touching the output file, so that a failure
# while building it does not leave a truncated or empty file behind.
entries = ",\n".join(entry_to_pretty_str(entry) for entry in configs)

# An explicit encoding keeps the generated file independent of the locale.
with open(output_path, "w", encoding="utf-8") as f:
    f.write(entries)