# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
    AddmmPattern,
    BmmPattern,
    Conv1dPattern,
    Conv2dPattern,
    LayerNormPattern,
    LinearPattern,
    MatmulPattern,
    QuantizationPattern,
    ReluPattern0,
    ReluPattern1,
)
from executorch.backends.cadence.aot.quantizer.utils import (
    find_sequential_partitions_aten,
    is_annotated,
    no_outside_users,
)

from torch import fx

from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OperatorConfig,
    QuantizationAnnotation,
    QuantizationConfig,
    QuantizationSpec,
)


# Default 8-bit affine, per-tensor quantization specs: a histogram observer for
# activations, a min/max observer for weights, and no quantization for biases.
act_qspec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

wgt_qspec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
)

bias_qspec: Optional[QuantizationSpec] = None


class CadenceAtenQuantizer(Quantizer):
    """Annotates a single fused ATen pattern with the given quantization config."""

    def __init__(
        self, pattern: QuantizationPattern, quantization_config: QuantizationConfig
    ) -> None:
        super().__init__()
        self.pattern = pattern
        self.quantization_config = quantization_config

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        fused_partitions = find_sequential_partitions_aten(
            model,
            self.pattern.partition_types(),
        )

        input_act_qspec = self.quantization_config.input_activation
        weight_qspec = self.quantization_config.weight
        bias_qspec = self.quantization_config.bias
        output_act_qspec = self.quantization_config.output_activation

        for fused_partition in fused_partitions:
            # Only annotate partitions whose intermediates are not consumed
            # outside the fused pattern, and which are not already annotated.
            if not no_outside_users(fused_partition):
                continue

            anchors = self.pattern.get_anchors(model, fused_partition)
            if not anchors:
                continue
            if is_annotated(
                [
                    x[0]
                    for x in anchors.inputs
                    + anchors.weights
                    + anchors.biases
                    + anchors.output
                ]
            ):
                continue

            for output, *custom_spec in anchors.output:
                # pyre-ignore[16]: no attribute
                output.meta["quantization_annotation"] = QuantizationAnnotation(
                    # pyre-ignore[6]: incompatible parameter type
                    output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
                    _annotated=True,
                )

            def annotate_inputs(
                inputs: Union[
                    List[Tuple[fx.Node, int]],
                    List[Tuple[fx.Node, int, DerivedQuantizationSpec],],
                ],
                spec: Optional[QuantizationSpec],
            ) -> None:
                for node, idx, *custom_spec in inputs:
                    # pyre-ignore[16]: no attribute
                    annotation = node.meta.get(
                        "quantization_annotation",
                        QuantizationAnnotation(_annotated=True),
                    )
                    # pyre-ignore[16]: no attribute
                    annotation.input_qspec_map[node.args[idx]] = (
                        custom_spec[0] if custom_spec else spec
                    )
                    # pyre-ignore[16]: no attribute
                    node.meta["quantization_annotation"] = annotation

            annotate_inputs(anchors.inputs, input_act_qspec)
            annotate_inputs(anchors.weights, weight_qspec)
            # pyre-ignore[6]: incompatible parameter type
            annotate_inputs(anchors.biases, bias_qspec)
        return model

    def validate(self, model: fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return []


class CadenceQuantizer(ComposableQuantizer):
    """Composes one CadenceAtenQuantizer per supported pattern, sharing a
    static uint8 config unless a custom QuantizationConfig is provided."""

    def __init__(
        self, quantization_config: Optional[QuantizationConfig] = None
    ) -> None:
        static_qconfig = (
            QuantizationConfig(
                act_qspec,
                act_qspec,
                wgt_qspec,
                None,
            )
            if not quantization_config
            else quantization_config
        )

        super().__init__(
            [
                CadenceAtenQuantizer(AddmmPattern(), static_qconfig),
                CadenceAtenQuantizer(BmmPattern(), static_qconfig),
                CadenceAtenQuantizer(Conv1dPattern(), static_qconfig),
                CadenceAtenQuantizer(Conv2dPattern(), static_qconfig),
                CadenceAtenQuantizer(LayerNormPattern(), static_qconfig),
                CadenceAtenQuantizer(LinearPattern(), static_qconfig),
                CadenceAtenQuantizer(MatmulPattern(), static_qconfig),
                CadenceAtenQuantizer(ReluPattern0(), static_qconfig),
                CadenceAtenQuantizer(ReluPattern1(), static_qconfig),
            ]
        )
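

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the upstream API
# surface). It assumes the standard PT2E prepare/convert flow from
# torch.ao.quantization.quantize_pt2e and torch.export.export_for_training,
# whose exact entry points may differ across torch/executorch releases.
# _TinyModel and example_inputs below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

    class _TinyModel(torch.nn.Module):
        """Toy linear + relu model, used only to exercise the quantizer."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(16, 8)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.relu(self.linear(x))

    example_inputs = (torch.randn(1, 16),)
    # Export to an ATen-level graph module before annotation.
    exported = torch.export.export_for_training(_TinyModel(), example_inputs).module()

    quantizer = CadenceQuantizer()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # calibration pass to populate the observers
    quantized = convert_pt2e(prepared)
    print(quantized)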