# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams

from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    SharedQuantizationSpec,
)


@dataclass
class PartitionAnchors:
    """
    All fields except output are lists of (node, args_index) pairs, where node is
    from the given partition and node.args[args_index] is an input to the
    partition. Assumes a single output.

    The quantizer uses inputs, weights, and biases for quantization annotation.
    The others field holds tensor inputs that aren't quantized, and the literals
    field is used for other types of input values as well as for handling default
    parameters.
    """

    inputs: List[Tuple[fx.Node, int]] = field(default_factory=list)
    weights: List[Tuple[fx.Node, int]] = field(default_factory=list)
    biases: List[
        Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]]
    ] = field(default_factory=list)
    others: List[Tuple[fx.Node, int]] = field(default_factory=list)
    literals: List[Tuple[fx.Node, int]] = field(default_factory=list)
    output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field(
        default_factory=list
    )
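
# For example, for a call aten.linear(x, weight, bias), LinearPattern below
# anchors inputs=[(linear_node, 0)], weights=[(linear_node, 1)], and
# biases=[(linear_node, 2, bias_qspec)]; each args_index selects the
# corresponding entry of linear_node.args.
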
class QuantizationPattern(ABC):
    @abstractmethod
    def partition_types(self) -> List[OpOverload]:
        """
        List of types to be passed to find_sequential_partitions_aten.
        """
        pass

    @abstractmethod
    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Optional[PartitionAnchors]:
        pass

    @abstractmethod
    def replacement_op(self) -> OpOverload:
        """
        Operator (most likely a custom one) that this partition should be fused
        into in the backend. Refer to the QuantFusion pass for examples.
        """
        pass


class AddmmPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.addmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        addmm_node = fused_partition[0].nodes[-1]

        # For addmm(bias, mat1, mat2), the bias qparams are derived from the
        # activation (args[1]) and weight (args[2]) observers.
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (addmm_node.args[1], addmm_node),
                (addmm_node.args[2], addmm_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        return PartitionAnchors(
            inputs=[(addmm_node, 1)],
            weights=[(addmm_node, 2)],
            biases=[(addmm_node, 0, bias_qspec)],
            output=[(addmm_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_linear.default


class BmmPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.bmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        bmm_node = fused_partition[0].nodes[-1]

        return PartitionAnchors(
            inputs=[(bmm_node, 0), (bmm_node, 1)],
            weights=[],
            biases=[],
            output=[(bmm_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_matmul.default


class Conv1dPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.conv1d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        conv1d_node = fused_partition[0].nodes[-1]

        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (conv1d_node.args[0], conv1d_node),
                (conv1d_node.args[1], conv1d_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # Keep bias empty if not supplied.
        bias = []
        if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
            bias = [(conv1d_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(conv1d_node, 0)],
            weights=[(conv1d_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(conv1d_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_conv.default


class Conv2dPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.conv2d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        conv2d_node = fused_partition[0].nodes[-1]

        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (conv2d_node.args[0], conv2d_node),
                (conv2d_node.args[1], conv2d_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # Keep bias empty if not supplied.
        bias = []
        if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
            bias = [(conv2d_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(conv2d_node, 0)],
            weights=[(conv2d_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(conv2d_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_conv.default
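
# A minimal sketch of the contract a `derive_qparams_fn` must satisfy. The body
# below is an assumption about how `get_bias_qparams` (imported above) behaves,
# namely that the bias scale is derived as activation_scale * weight_scale with
# a zero zero-point, matching the int32 per-tensor specs built by the patterns
# in this file. It is included for illustration only and is not used anywhere.
from torch.ao.quantization import ObserverOrFakeQuantize


def _bias_qparams_sketch(
    obs_or_fqs: List[ObserverOrFakeQuantize],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # One observer/fake-quantize per `derived_from` edge, in order:
    # activation first, then weight.
    act_scale, _ = obs_or_fqs[0].calculate_qparams()
    weight_scale, _ = obs_or_fqs[1].calculate_qparams()
    bias_scale = act_scale * weight_scale
    bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
    return bias_scale, bias_zero_point
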
class LayerNormPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.layer_norm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        layer_norm_node = fused_partition[0].nodes[-1]

        others = [(layer_norm_node, 1)]

        # Add weights if supplied.
        if len(layer_norm_node.args) > 2 and layer_norm_node.args[2]:
            others.append((layer_norm_node, 2))

        # Add bias if supplied.
        if len(layer_norm_node.args) > 3 and layer_norm_node.args[3]:
            others.append((layer_norm_node, 3))

        # Weights are used in quantized mode by our kernel, so they are passed
        # in as others here along with the normalized shape.
        return PartitionAnchors(
            inputs=[(layer_norm_node, 0)],
            weights=[],
            biases=[],
            # Ordering: normalized_shape, weights, bias
            others=others,
            output=[(layer_norm_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_layer_norm.default


class LinearPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        linear_node = fused_partition[0].nodes[-1]

        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (linear_node.args[0], linear_node),
                (linear_node.args[1], linear_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # Keep bias empty if not supplied.
        bias = []
        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
            bias = [(linear_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(linear_node, 0)],
            weights=[(linear_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(linear_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_linear.default


class MatmulPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.matmul.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        matmul_node = fused_partition[0].nodes[-1]

        return PartitionAnchors(
            inputs=[(matmul_node, 0), (matmul_node, 1)],
            weights=[],
            biases=[],
            output=[(matmul_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_matmul.default


# Base class for ReLU patterns, since ReLU can appear as either of two aten ops.
class ReluBasePattern(QuantizationPattern):
    @abstractmethod
    def partition_types(self) -> List[OpOverload]:
        pass

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        relu_node = fused_partition[0].nodes[-1]

        return PartitionAnchors(
            inputs=[(relu_node, 0)],
            weights=[],
            biases=[],
            output=[(relu_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_relu.default


# Regular relu op
class ReluPattern0(ReluBasePattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.relu.default]


# Alternate (in-place) relu op
class ReluPattern1(ReluBasePattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.relu_.default]
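
# Usage sketch: these patterns are consumed by the Cadence quantizer, which
# pairs each pattern with a quantization config for annotation; the QuantFusion
# pass later fuses each annotated partition into its replacement_op. The import
# path and constructor signature below are assumptions for illustration; refer
# to the quantizer module in this backend for the actual wiring.
#
#     from executorch.backends.cadence.aot.quantizer.quantizer import (
#         CadenceAtenQuantizer,
#     )
#
#     # static_qconfig is a QuantizationConfig defined by the quantizer module.
#     quantizer = CadenceAtenQuantizer(LinearPattern(), static_qconfig)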