# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List

import torch
from executorch.backends.example.example_operators.ops import module_to_annotator
from torch import fx
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OperatorConfig


def get_uint8_tensor_spec(observer_or_fake_quant_ctr) -> QuantizationSpec:
    """Build a static, per-tensor affine uint8 quantization spec.

    Args:
        observer_or_fake_quant_ctr: Observer or fake-quant constructor used
            to collect statistics for this tensor (e.g. ``HistogramObserver``).

    Returns:
        A ``QuantizationSpec`` covering the full uint8 range [0, 255].
    """
    return QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
    )


@dataclass
class ExampleQuantConfig:
    """Quantization specs for the tensor roles of an annotated pattern."""

    # Spec applied to the pattern's activation input.
    input_quant_spec: QuantizationSpec
    # Spec applied to the pattern's activation output.
    output_quant_spec: QuantizationSpec
    # Spec applied to weight tensors.
    weight_quant_spec: QuantizationSpec
    # Spec applied to bias tensors (may be None; see default_static_config).
    bias_quant_spec: QuantizationSpec


default_static_config = ExampleQuantConfig(
    get_uint8_tensor_spec(HistogramObserver),
    get_uint8_tensor_spec(HistogramObserver),
    get_uint8_tensor_spec(MinMaxObserver),
    # pyre-fixme[6]: Incompatible parameter type [6]: In call `ExampleQuantConfig.__init__`, for 4th positional argument, expected `QuantizationSpec` but got `None`.
    None,  # bias quantization can be configured here or done in a pass later on.
)


def check_for_outside_users(partitions) -> bool:
    """
    Make sure that all the users of this partition are within the delegatable
    subgraph, except the last partition. If we quantize partitions that have
    users outside this subgraph then delegation of this partition to the
    backend will not be possible.
    """
    for source_partition in partitions[:-1]:
        # Every non-final partition must expose exactly one output node with
        # exactly one user; anything else means data escapes the fused
        # subgraph, so the caller should skip this candidate.
        if len(source_partition.output_nodes) != 1:
            return True
        if len(source_partition.output_nodes[0].users) != 1:
            return True
    return False


class ExampleQuantizer(Quantizer):
    """Example PT2E quantizer that annotates known fused module patterns."""

    def __init__(self, quantizer_supported_modules=None, quant_config=None):
        """
        Args:
            quantizer_supported_modules: Optional iterable of module patterns
                to annotate; every entry must be a key of
                ``module_to_annotator``. Defaults to all registered patterns.
            quant_config: Optional ``ExampleQuantConfig``; defaults to
                ``default_static_config``.

        Raises:
            ValueError: If a requested module pattern has no registered
                annotator.
        """
        super().__init__()
        if quantizer_supported_modules is not None:
            # Validate before assigning so a bad argument cannot leave a
            # partially-initialized quantizer. (Was `assert 0, ...`, which is
            # stripped under `python -O`; raise explicitly instead.)
            for module in quantizer_supported_modules:
                if module not in module_to_annotator:
                    raise ValueError(f"{module} is not supported by this quantizer")
            self.quantizer_supported_modules = quantizer_supported_modules
        else:
            self.quantizer_supported_modules = module_to_annotator.keys()
        if quant_config is not None:
            self.quant_config = quant_config
        else:
            self.quant_config = default_static_config

    def annotate(self, model):
        """Annotate ``model`` in place with quantization specs for every
        supported fused pattern, then return it."""
        for supported_modules in self.quantizer_supported_modules:
            fused_partitions = find_sequential_partitions(
                model,
                list(supported_modules),
            )

            for partitions in fused_partitions:
                # Skip candidates whose intermediate values leak outside the
                # subgraph -- they could not be delegated after quantization.
                if check_for_outside_users(partitions):
                    continue

                annotator = module_to_annotator[supported_modules].annotate_handle
                annotator(partitions, self.quant_config)

        return model

    def validate(self, model: fx.GraphModule) -> None:
        # No additional validation is performed by this example quantizer.
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        # This example quantizer does not advertise operator configs.
        return []