xref: /aosp_15_r20/external/executorch/backends/xnnpack/partition/config/gemm_configs.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from itertools import chain
from typing import cast, List, Optional, Tuple

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
    XNNPartitionerConfig,
)
from executorch.backends.xnnpack.utils.quant_utils import (
    extract_qdq_affine_op_args_for_decomposed_ops,
    is_affine_qdq,
    is_dequant,
    is_dynamic_qdq,
    is_per_channel,
    is_per_channel_group,
    is_qparam,
    is_quant,
)
from executorch.backends.xnnpack.utils.utils import (
    get_input_node,
    is_getitem,
    is_node,
    is_param_node,
)
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
    format_target_name,
)
from executorch.exir.backend.utils import WhyNoPartition
from torch.export import ExportedProgram
from torch.fx.passes.utils.source_matcher_utils import (
    get_source_partitions,
    SourcePartition,
)

logger = logging.getLogger(__name__)
why = WhyNoPartition(logger=logger)


class GEMMConfig(XNNPartitionerConfig):
    """
    GEMM-like ops such as Convolution, Addmm, and Linear mostly behave in the same
    way: each has a weight, bias, and activation node. The only difference between
    these ops is that the weight, bias, and activation sit at different indices of
    the node's arguments. This class generalizes the logic needed to partition these
    different ops.
    """

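    # For reference, the argument layout assumed by the subclasses below, using
    # aten.linear(input, weight, bias) as the canonical example:
    #   act_idx=0 (input), weight_idx=1 (weight), bias_idx=2 (bias).
    # Convolution uses the same layout, while addmm(bias, act, weight) permutes it.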
    def __init__(self, weight_idx, bias_idx, act_idx, fused_acts, **kwargs):
        super().__init__(**kwargs)
        self.weight_idx = weight_idx
        self.bias_idx = bias_idx
        self.act_idx = act_idx
        self.fused_acts = fused_acts

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        if not self.check_common_constraints(node, ep):
            # short circuit if we don't pass common constraints
            return False

        is_valid, _ = self.get_deps(node, ep)
        if not is_valid:
            why(node, "Failed to get valid dependent nodes.")
        return is_valid

    def get_node_and_deps(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        partition = [node]
        _, deps = self.get_deps(node, ep)
        partition.extend(deps)

        return partition

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return None

    def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
        weight = get_input_node(node, self.weight_idx)

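        # Decision tree: no dequant feeding the weight -> fp32; dequant weight plus
        # a dynamic q/dq chain on the activation -> dynamic quant; otherwise static.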
        if not is_dequant(weight):
            return ConfigPrecisionType.FP32

        activation = get_input_node(node, self.act_idx)
        if is_dynamic_qdq(activation):
            return ConfigPrecisionType.DYNAMIC_QUANT

        return ConfigPrecisionType.STATIC_QUANT

    def get_deps(
        self,
        node: torch.fx.Node,
        ep: ExportedProgram,
    ) -> Tuple[bool, List[torch.fx.Node]]:
        """
        Gets all dependencies for this gemm partition. Returns a tuple of
        a bool indicating if the deps are valid and a list of all the
        dep nodes
        """
        precision = self._detect_precision(node)
        if precision not in self.supported_precision_types():
            # detected precision but it is either disabled or not supported
            return (False, [])

        valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
        valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
        valid_act, act_deps = self._get_act_deps(node, ep, precision)
        valid_output, output_deps = self._get_output_deps(node, ep, precision)

        valid_deps = valid_bias and valid_weight and valid_act and valid_output
        deps = list(chain(bias_deps, weight_deps, act_deps, output_deps))

        return valid_deps, deps

    def _get_weight_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        gemm_deps = []
        if precision == ConfigPrecisionType.FP32:
            # First find the weight
            weight_node = get_input_node(node, self.weight_idx)
            if not is_param_node(ep, weight_node):
                return (False, [])  # weight must be a static param
            gemm_deps.append(weight_node)

            return (True, gemm_deps)
        else:
            # Quantized Weight deps
            dequant_node = get_input_node(node, self.weight_idx)
            if not is_dequant(dequant_node):
                return False, []
            gemm_deps.append(dequant_node)
            weight = get_input_node(dequant_node, 0)
            if not is_param_node(ep, weight):
                return False, []
            gemm_deps.append(weight)

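            # Per-channel / per-channel-group weights carry their scale and
            # zero-point tensors as extra inputs on the dequant node; pull
            # those in as well so they land in the partition.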
            if is_per_channel(dequant_node) or is_per_channel_group(dequant_node):
                if len(dequant_node.all_input_nodes) < 2:
                    # Channel-quantized weights are expected to have scale/zp nodes
                    return False, []

                gemm_deps.extend(dequant_node.all_input_nodes[1:3])
            return (True, gemm_deps)

    def _get_output_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        gemm_deps = []
        if precision == ConfigPrecisionType.STATIC_QUANT:
            # Look for fused activations and tail end quant node
            node_users = list(node.users.keys())
            if len(node_users) != 1:
                # Expect quantized node to have a single output (fused act or quant)
                return False, []

            # Check if the quantized pattern has a fused activation
            n_output = node_users[0]
            if (
                n_output.op == "call_function"
                and format_target_name(n_output.target.__name__) in self.fused_acts
            ):
                gemm_deps.append(n_output)
                fused_out_users = list(n_output.users.keys())
                if len(fused_out_users) == 1:
                    n_output = fused_out_users[0]

            if not is_quant(n_output):
                # Expected gemm_node --> fused_act (optional) --> quant
                return (False, [])
            gemm_deps.append(n_output)
        elif precision == ConfigPrecisionType.FP32:
            # Look for fused activations only, and partition with fp32 op
            node_users = list(node.users.keys())
            if len(node_users) == 1:
                n_output = node_users[0]
                if (
                    n_output.op == "call_function"
                    and format_target_name(n_output.target.__name__) in self.fused_acts
                ):
                    gemm_deps.append(n_output)

        # FP32 and Dynamic Quant have no output dependencies
        return (True, gemm_deps)

    def _get_bias_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        gemm_deps = []
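        # Note: the truthiness check below means configs whose bias_idx is 0 or None
        # (e.g. AddmmConfig, MMConfig) skip bias dependency collection here.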
        if len(node.all_input_nodes) > 2 and self.bias_idx:
            bias_node = get_input_node(node, self.bias_idx)
            if bias_node:
                if not is_param_node(ep, bias_node):
                    return (False, [])  # bias node must be a static param
                gemm_deps.append(bias_node)

        return (True, gemm_deps)

    def _get_act_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        gemm_deps = []
        if precision == ConfigPrecisionType.FP32:
            return (True, [])
        else:
            dq_input = get_input_node(node, self.act_idx)
            if not is_dequant(dq_input):
                # Expected quantized input to be a dequant node
                return False, []
            gemm_deps.append(dq_input)
            if precision == ConfigPrecisionType.STATIC_QUANT:
                # if static quant we are done after finding first dq_input
                return (True, gemm_deps)

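            # Dynamic quant: walk the chain
            #   choose_qparams -> getitem(scale) / getitem(zero_point) -> quant -> dequant
            # so that every node in the pattern lands in the partition.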
            # q input node
            q_input = get_input_node(dq_input, 0)
            if not is_quant(q_input):
                return (False, [])

            gemm_deps.append(q_input)
            q_input_args = q_input.args
            if is_affine_qdq(q_input):
                q_input_args = extract_qdq_affine_op_args_for_decomposed_ops(q_input)
            if not (is_node(q_input_args[1]) and is_node(q_input_args[2])):
                # expected to find getitem node from choose qparam
                return (False, [])

            getitem1 = q_input_args[1]
            getitem2 = q_input_args[2]

            if not (is_getitem(getitem1) and is_getitem(getitem2)):
                # expected getitem node from choose qparam
                return (False, [])

            gemm_deps.extend([getitem1, getitem2])
            choose_qparam = get_input_node(getitem1, 0)
            if not is_qparam(choose_qparam):
                # expected to find choose_qparam node
                return (False, [])
            gemm_deps.append(choose_qparam)
            return (True, gemm_deps)


class LinearConfig(GEMMConfig):
    target_name = "linear.default"

    def __init__(self, **kwargs):
        super().__init__(
            weight_idx=1,
            bias_idx=2,
            act_idx=0,
            fused_acts=["relu.default", "hardtanh.default"],
            **kwargs,
        )

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        return torch.ops.aten.linear.default

    def _get_weight_deps(
        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
    ) -> Tuple[bool, List[torch.fx.Node]]:
        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
            # if force_fp32_dynamic_linear is on and we detected this as fp32, then we
            # do not partition the weight node
            return (True, [])

        return super()._get_weight_deps(node, ep, precision)

    def supported_precision_types(self):
        return [
            ConfigPrecisionType.DYNAMIC_QUANT,
            ConfigPrecisionType.FP32,
            ConfigPrecisionType.STATIC_QUANT,
        ]


class ConvolutionConfig(GEMMConfig):
    target_name = "convolution.default"

    def __init__(self, **kwargs):
        super().__init__(
            weight_idx=1,
            bias_idx=2,
            act_idx=0,
            fused_acts=["relu.default", "hardtanh.default"],
            **kwargs,
        )

    def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
        """
        Currently there is no support for 3D convolution or transposed convolution.
        """
        if not super().check_constraints(node, ep):
            return False

        conv_stride = cast(List[int], node.args[3])
        if len(conv_stride) > 2:
            why(node, "Only support 1D + 2D Conv")
            return False  # Only support 1D + 2D Conv

        transposed = cast(bool, node.args[6])
        if transposed:
            why(node, "Transposed Conv is not supported")
            return False  # Currently don't support transposed conv

        return True

    def supported_precision_types(self):
        return [
            ConfigPrecisionType.FP32,
            ConfigPrecisionType.STATIC_QUANT,
        ]


class AddmmConfig(GEMMConfig):
    """
    Handles the legacy form of addmm partitioning, which includes partitioning
    via source partitions.
    """

    target_name = "addmm.default"

    def __init__(self, **kwargs):
        super().__init__(
            weight_idx=2,
            bias_idx=0,
            act_idx=1,
            fused_acts=["relu.default", "hardtanh.default"],
            **kwargs,
        )
        self.src_partitions = None
        self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]

    def get_deps(
        self,
        node: torch.fx.Node,
        ep: ExportedProgram,
    ) -> Tuple[bool, List[torch.fx.Node]]:
        """
        Gets all dependencies for this gemm partition. Returns a tuple of
        a bool indicating if the deps are valid and a list of all the
        dep nodes. This also handles the case where the addmm node came
        from a linear source partition.
        """
        if self.src_partitions is None:
            # Cache src partitions so we don't have to recompute them every time
            self.src_partitions = get_source_partitions(ep.graph, self.linear_modules)

        # src_partition is None if node is not in source partition,
        # otherwise gives us the linear source partition it belongs to
        src_partition = None
        for partition_list in self.src_partitions.values():
            for partition in partition_list:
                if node in partition.nodes:
                    src_partition = partition

        if src_partition:
            # if addmm belongs to a linear src partition, then partition the
            # src partition and get its deps
            return self.get_deps_from_src_partition(node, ep, src_partition)

        return super().get_deps(node, ep)

    def get_deps_from_src_partition(
        self, node: torch.fx.Node, ep: ExportedProgram, src_partition: SourcePartition
    ):
        """
        Gets all the dependencies for the src partition. This is done by simulating the
        linear node from the src partition: we find the associated weight, activation,
        and bias in the linear src partition and plug those in as the addmm node's args.
        We also take the users of the src partition's output node as the addmm node's
        users. Then we run GEMMConfig's get_deps method on this faked linear node.
        After getting the deps, we restore the addmm node's original args and users.
        """

        def find_partition_args(input_node):
            while (
                len(input_node.all_input_nodes) != 0
                and input_node not in src_partition.input_nodes
            ):
                input_node = input_node.all_input_nodes[0]
            return input_node

        old_args, old_users = node.args, node.users

        fake_args = []
        for arg in node.args:
            # map addmm's args to the source partition's inputs
            # basically simulating what the args of the linear node would be
            fake_args.append(find_partition_args(arg))

        # validate source partition
        if (
            # bias must be in source partition
            (self.bias_idx and fake_args[self.bias_idx] not in src_partition.nodes)
            # activation input must be an input node to partition
            or fake_args[self.act_idx] not in src_partition.input_nodes
            # weight can either be in the nodes or input_nodes
            or fake_args[self.weight_idx]
            not in (src_partition.nodes + src_partition.input_nodes)
            # there can only be a single output node in partition
            or len(src_partition.output_nodes) != 1
        ):
            return (False, [])

        # map addmm's args to the source partition linear's inputs and users
        node.args = tuple(fake_args)
        node.users = src_partition.output_nodes[0].users
        valid_deps, deps = super().get_deps(node, ep)

        # Reset addmm node back to old args and users
        node.args = old_args
        node.users = old_users

        return valid_deps, list(set(deps) | set(src_partition.nodes))

    def supported_precision_types(self):
        return [
            ConfigPrecisionType.FP32,
            ConfigPrecisionType.STATIC_QUANT,
            ConfigPrecisionType.DYNAMIC_QUANT,
        ]


class MMConfig(AddmmConfig):
    target_name = "mm.default"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.bias_idx = None
        self.weight_idx = 1
        self.act_idx = 0
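        # aten.mm(input, weight) takes no bias, so only the activation (arg 0)
        # and weight (arg 1) indices apply here.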

    def supported_precision_types(self):
        return [
            ConfigPrecisionType.FP32,
            ConfigPrecisionType.STATIC_QUANT,
            ConfigPrecisionType.DYNAMIC_QUANT,
        ]
