# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from abc import abstractmethod
from enum import Enum
from typing import List, Optional

import torch
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
    format_target_name,
    PartitionerConfig,
)
from executorch.exir.backend.utils import WhyNoPartition
from torch.export import ExportedProgram

logger = logging.getLogger(__name__)
why = WhyNoPartition(logger=logger)


class ConfigPrecisionType(Enum):
    FP32 = 1
    STATIC_QUANT = 2
    DYNAMIC_QUANT = 3
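
# An illustrative sketch of how these precision types are typically used to
# filter configs when constructing the partitioner (hypothetical usage, not
# defined in this file):
#   XnnpackPartitioner(config_precisions=ConfigPrecisionType.DYNAMIC_QUANT)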


class XNNPartitionerConfig(PartitionerConfig):
    """
    Base class for all XNNPACK partitioner configs. Wrapping every config in
    this base class lets us apply shared controls across all
    PartitionerConfigs. This base class also adds a property for supported
    precision types, allowing each config to declare the precision types it
    supports and letting users toggle which of them to enable.
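
    Example (an illustrative sketch; ``SomeOpConfig`` stands in for any
    concrete subclass of this base class):

        config = SomeOpConfig()
        # keep only static quantization, assuming SomeOpConfig supports it
        config.set_enabled_precision_types([ConfigPrecisionType.STATIC_QUANT])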
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.enabled_precision_types = self.supported_precision_types()
        # Flag used in GEMMConfig()
        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)

    def get_partition(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        """
        Overriding abstract method get_partition.

        Returns the partitioned nodes from get_node_and_deps, but also labels
        them with the name of the XNNPartitionerConfig class that returned this
        set of nodes. This enforces that every partition returned by an
        XNNPartitioner config is labeled with the config that produced it.
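
        For example, after a config (say, a hypothetical ``AddConfig``)
        partitions a node, one would observe:

            node.meta["xnn_partitioner_config"]  # ["AddConfig"]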
        """
        partitioned_nodes = self.get_node_and_deps(node, ep)
        # label partitioned nodes with the name of the partitioner config
        for partitioned_node in partitioned_nodes:
            if "xnn_partitioner_config" in partitioned_node.meta:
                partitioned_node.meta["xnn_partitioner_config"].append(
                    self.__class__.__name__
                )
            else:
                partitioned_node.meta["xnn_partitioner_config"] = [
                    self.__class__.__name__
                ]

        return partitioned_nodes

    def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
        # By default, if not specified, we do not halt decomposition for this config
        return None

    @abstractmethod
    def supported_precision_types(self) -> List[ConfigPrecisionType]:
        """
        Returns the precision types supported by this partitioner config.
        """
        pass

    @abstractmethod
    def get_node_and_deps(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> List[torch.fx.Node]:
        """
        Takes in a node and its exported program, and returns the node along
        with the dependencies that need to be partitioned together with it.

        Args:
            node: Node to be partitioned
            ep: Exported program of the graph module
        Returns:
            List of nodes that can be partitioned
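
        For example, a config for a convolution op might return the
        convolution node together with the nodes supplying its weight and
        bias, so that they are partitioned as a single unit (illustrative;
        each concrete config defines its own set).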
        """
        pass

    def set_enabled_precision_types(
        self, precision_types: Optional[List[ConfigPrecisionType]]
    ):
        """
        Set the enabled precisions.

        We take the intersection of the precision_types we wish to enable with
        the precision types that this config supports. If the intersection is
        empty, i.e. the config supports none of the precision types we want to
        enable, then we will not partition anything and will return False at
        the common constraints check.
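
        Example (illustrative; assume this config supports only FP32):

            config.set_enabled_precision_types(
                [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
            )
            # config.enabled_precision_types is now [ConfigPrecisionType.FP32]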
        """

        if precision_types:
            enabled_precisions = []
            for precision in precision_types:
                if precision in self.supported_precision_types():
                    enabled_precisions.append(precision)

            self.enabled_precision_types = enabled_precisions

    def check_common_constraints(
        self, node: torch.fx.Node, ep: ExportedProgram
    ) -> bool:
        """
        Checks common XNNPACK constraints.

        Args:
            node (torch.fx.Node): Node to check common constraints against
            ep (ExportedProgram): Exported Program to check constraints against

        Returns:
            True if this node is partitionable, False otherwise
        """
        assert (
            node.op == "call_function"
            and format_target_name(node.target.__name__)  # pyre-ignore
            == self.target_name
        )

        if len(self.enabled_precision_types) == 0:
            why(node, reason="no enabled precision types")
            return False

        has_valid_dtypes = self._check_node_has_valid_dtype(node)
        if not has_valid_dtypes:
            why(node, reason="invalid dtype")
            return False

        return True

    def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
        # Check inputs are valid dtypes
        # Gather all args which are nodes
        args_to_check = []
        for arg in node.args:
            if isinstance(arg, (list, tuple)):
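                # e.g. ops like torch.cat receive their inputs as a list of Nodes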
                for item in arg:
                    if isinstance(item, torch.fx.Node):
                        args_to_check.append(item)

            if isinstance(arg, torch.fx.Node):
                args_to_check.append(arg)

        for arg in args_to_check:
            arg_val = arg.meta.get("val", None)

            if arg_val is None or isinstance(arg_val, tuple):
                continue

            # Being conservative for now, UX >> Perf
            # TODO: We need a pass to scrub these out.
            if not isinstance(arg_val, torch.Tensor):
                return False

            # XNNPACK does not support empty tensors
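            # (i.e. tensors with numel() == 0, e.g. produced by a zero-size slice)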
            if arg_val.numel() == 0:
                return False

            if arg_val.dtype not in valid_dtypes:
                return False

        return True

    def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
        # Check outputs are valid dtypes
        node_val = node.meta.get("val", None)
        if node_val is None:
            return True

        if not isinstance(node_val, tuple):
            node_val = (node_val,)

        for val in node_val:
            if not isinstance(val, torch.Tensor):
                return False

            if val.dtype not in valid_dtypes:
                return False

        return True

    def _check_node_has_valid_dtype(self, node):
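        # Tensor dtypes the XNNPACK delegate can consume here: fp32/fp16 for
        # float compute, int8/qint8 for quantized compute.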
        valid_dtypes = {
            torch.float32,
            torch.float16,
            torch.int8,
            torch.qint8,
        }
        if node.op not in ("placeholder", "call_function", "get_attr"):
            return False

        return self._check_inputs_are_valid_dtypes(
            node, valid_dtypes
        ) and self._check_outputs_are_valid_dtypes(node, valid_dtypes)
212