# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams

from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    SharedQuantizationSpec,
)


@dataclass
class PartitionAnchors:
26    """
27    All fields except output are lists of (node, args_index) pair, where node is from
28    the given partition and node.args[args_index] is an input to the partition. Assumes
29    a single output.
30
31    Quantizer uses inputs, weights and biases for quantization annotation. The others
32    field contains tensor inputs that aren't quantized, and the literals fields contains
33    is used for other types of input values as well as handling default parameters.
34    """

    inputs: List[Tuple[fx.Node, int]] = field(default_factory=list)
    weights: List[Tuple[fx.Node, int]] = field(default_factory=list)
    biases: List[
        Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]]
    ] = field(default_factory=list)
    others: List[Tuple[fx.Node, int]] = field(default_factory=list)
    literals: List[Tuple[fx.Node, int]] = field(default_factory=list)
    output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field(
        default_factory=list
    )

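# Example: for a linear node `n = aten.linear(x, w, b)`, a pattern would anchor
# each argument by its position in n.args (an illustrative sketch mirroring
# LinearPattern below; `n`, `x`, `w`, `b` are placeholder names):
#
#     PartitionAnchors(
#         inputs=[(n, 0)],   # n.args[0] is the activation
#         weights=[(n, 1)],  # n.args[1] is the weight
#         biases=[(n, 2)],   # n.args[2] is the bias
#         output=[(n,)],
#     )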

class QuantizationPattern(ABC):
    @abstractmethod
    def partition_types(self) -> List[OpOverload]:
        """
        List of types to be passed to find_sequential_partitions_aten.
        """
        pass

    @abstractmethod
    def get_anchors(
        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Optional[PartitionAnchors]:
        pass

    @abstractmethod
    def replacement_op(self) -> OpOverload:
        """
        Operator (most likely a custom one) that this partition should be fused into in
        the backend. Refer to the QuantFusion pass for examples.
        """
        pass

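# How these hooks fit together (an illustrative sketch, not the actual driver;
# find_sequential_partitions_aten lives in the utils module for this quantizer,
# and the annotation step is elided):
#
#     pattern = AddmmPattern()
#     fused_partitions = find_sequential_partitions_aten(
#         gm, pattern.partition_types()
#     )
#     for fused in fused_partitions:
#         anchors = pattern.get_anchors(gm, fused)
#         # ...annotate anchors.inputs / weights / biases / output; QuantFusion
#         # later rewrites the partition into pattern.replacement_op()...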

class AddmmPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.addmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        addmm_node = fused_partition[0].nodes[-1]

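        # aten.addmm(bias, mat1, mat2) computes bias + mat1 @ mat2, so args[0] is the
        # bias, args[1] the activation, and args[2] the weight; the bias qspec below
        # derives its quantization parameters from the activation/weight pair.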
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (addmm_node.args[1], addmm_node),
                (addmm_node.args[2], addmm_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        return PartitionAnchors(
            inputs=[(addmm_node, 1)],
            weights=[(addmm_node, 2)],
            biases=[(addmm_node, 0, bias_qspec)],
            output=[(addmm_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_linear.default


class BmmPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.bmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        bmm_node = fused_partition[0].nodes[-1]

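        # bmm takes no weight or bias; both matrix operands are anchored as
        # quantized inputs.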
        return PartitionAnchors(
            inputs=[(bmm_node, 0), (bmm_node, 1)],
            weights=[],
            biases=[],
            output=[(bmm_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_matmul.default


class Conv1dPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.conv1d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        conv1d_node = fused_partition[0].nodes[-1]

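        # aten.conv1d(input, weight, bias, ...): the int32 bias qparams below are
        # derived from the activation (args[0]) and weight (args[1]) observers via
        # get_bias_qparams (typically bias_scale = act_scale * weight_scale).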
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (conv1d_node.args[0], conv1d_node),
                (conv1d_node.args[1], conv1d_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # Keep bias empty if not supplied
        bias = []
        if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
            bias = [(conv1d_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(conv1d_node, 0)],
            weights=[(conv1d_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(conv1d_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_conv.default


class Conv2dPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.conv2d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        conv2d_node = fused_partition[0].nodes[-1]

        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (conv2d_node.args[0], conv2d_node),
                (conv2d_node.args[1], conv2d_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # Keep bias empty if not supplied
        bias = []
        if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
            bias = [(conv2d_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(conv2d_node, 0)],
            weights=[(conv2d_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(conv2d_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_conv.default


class LayerNormPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.layer_norm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        layer_norm_node = fused_partition[0].nodes[-1]

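        # aten.layer_norm(input, normalized_shape, weight, bias, eps): args[1] is the
        # normalized shape; args[2] and args[3] are the optional weight and bias.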
        others = [(layer_norm_node, 1)]

        # Add weights if supplied
        if len(layer_norm_node.args) > 2 and layer_norm_node.args[2]:
            others.append((layer_norm_node, 2))

        # Add bias if supplied
        if len(layer_norm_node.args) > 3 and layer_norm_node.args[3]:
            others.append((layer_norm_node, 3))

        # Weights are used in quantized mode by our kernel, so they are
        # passed in as others here along with the normalized shape.
        return PartitionAnchors(
            inputs=[(layer_norm_node, 0)],
            weights=[],
            biases=[],
            # Ordering: normalized_shape, weights, bias
            others=others,
            output=[(layer_norm_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_layer_norm.default


class LinearPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        linear_node = fused_partition[0].nodes[-1]

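        # aten.linear(input, weight, bias=None): bias qparams are derived from the
        # activation/weight pair, exactly as in the conv patterns above.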
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (linear_node.args[0], linear_node),
                (linear_node.args[1], linear_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # Keep bias empty if not supplied
        bias = []
        if len(linear_node.args) > 2:
            bias = [(linear_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(linear_node, 0)],
            weights=[(linear_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(linear_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_linear.default


class MatmulPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.matmul.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        matmul_node = fused_partition[0].nodes[-1]

        return PartitionAnchors(
            inputs=[(matmul_node, 0), (matmul_node, 1)],
            weights=[],
            biases=[],
            output=[(matmul_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_matmul.default


# Base class for the ReLU patterns, since ReLU can appear as two different aten ops
class ReluBasePattern(QuantizationPattern):
    @abstractmethod
    def partition_types(self) -> List[OpOverload]:
        pass

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        relu_node = fused_partition[0].nodes[-1]

        return PartitionAnchors(
            inputs=[(relu_node, 0)],
            weights=[],
            biases=[],
            output=[(relu_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_relu.default


# Regular relu op
class ReluPattern0(ReluBasePattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.relu.default]


# In-place (alternate) relu op
class ReluPattern1(ReluBasePattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.relu_.default]

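
# Putting it together (an illustrative sketch; the actual consumer is the Cadence
# quantizer in this package, which pairs each pattern with a quantization config):
#
#     patterns = [
#         AddmmPattern(),
#         BmmPattern(),
#         Conv1dPattern(),
#         Conv2dPattern(),
#         LayerNormPattern(),
#         LinearPattern(),
#         MatmulPattern(),
#         ReluPattern0(),
#         ReluPattern1(),
#     ]
#     # Each pattern is matched via partition_types(), annotated via get_anchors(),
#     # and later fused into its replacement_op() by the QuantFusion pass.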