# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple

import executorch.backends.vulkan.utils as utils

import torch

from executorch.backends.vulkan.op_registry import (
    get_op_features,
    has_impl,
    OpFeatures,
    vulkan_supported_ops,
)

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)
from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend

from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.dialects._ops import ops as exir_ops

from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

from torch.fx.passes.operator_support import OperatorSupportBase

# pyre-ignore
ops_not_to_decompose = [
    torch.ops.aten.upsample_nearest2d.vec,
]

logger: logging.Logger = logging.getLogger("")
logger.setLevel(logging.INFO)


class VulkanSupportedOperators(OperatorSupportBase):
    def __init__(
        self,
        texture_limits: utils.ImageExtents,
        buffer_limit: int,
        require_dynamic_shape: bool = False,
    ) -> None:
        super().__init__()
        self.texture_limits: utils.ImageExtents = texture_limits
        self.buffer_limit = buffer_limit
        self.require_dynamic_shapes = require_dynamic_shape
        # The tensor dim limit is to guard against tensors with one or more
        # large dimensions, which cannot be represented by an image texture due
        # to the texture axis limits.
        self.tensor_dim_limit = 16384

    def op_node_is_compatible(
        self, node: torch.fx.Node, features: Optional[OpFeatures] = None
    ) -> Tuple[bool, str]:
        """
        Check if a given node is compatible with the Vulkan delegate's implementation
        of the operator called by the node. Each tensor argument participating in the
        operator call must be representable by a (storage type, memory layout)
        combination that the operator implementation supports.
        """
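        # Illustrative example (layout names are representative, not exhaustive): if
        # this node's output could use either a width-packed or a channels-packed
        # texture layout, but one tensor argument only fits the width-packed layout,
        # then the intersection (width-packed) must be supported by the op
        # implementation; if the intersection is empty, buffer storage must be
        # supported instead, and all tensors must be within the buffer limit.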
        target = node.target
        # Account for custom operators
        if node.target == torch.ops.higher_order.auto_functionalized:
            first_arg = node.args[0]
            assert isinstance(first_arg, torch._ops.OpOverload)
            target = first_arg.name()

        # Extract the features for the node's operator, if no override was provided
        if features is None:
            if not has_impl(target):
                return False, "no operator implementation"
            features = get_op_features(target)

        valid_texture_layouts = utils.possible_node_memory_layouts(
            node, self.texture_limits
        )

        can_use_buffers = utils.within_buffer_limit(node, self.buffer_limit)
        for i, arg in enumerate(node.args):
            if (
                isinstance(arg, torch.fx.Node)
                and utils.is_tensor_node(arg)
                and i not in features.skip_limits_check
            ):
                arg_texture_layouts = utils.possible_node_memory_layouts(
                    arg, self.texture_limits
                )
                valid_texture_layouts = valid_texture_layouts.intersection(
                    arg_texture_layouts
                )
                can_use_buffers = can_use_buffers and utils.within_buffer_limit(
                    arg, self.buffer_limit
                )

        # If there are no valid texture memory layouts, then buffer storage must be
        # supported by the operator implementation.
        if len(valid_texture_layouts) == 0:
            if not can_use_buffers:
                return (
                    False,
                    f"op requires buffers that exceed the buffer limit ({self.buffer_limit})",
                )

            compatible = VkStorageType.BUFFER in features.supported_storage_types()
            reason = "op is compatible"
            if not compatible:
                reason = "op requires buffer storage, which is not supported by the op impl"
            return compatible, reason

        op_available_layouts = features.supported_memory_layouts(
            VkStorageType.TEXTURE_3D
        )

        is_compatible = any(
            layout in op_available_layouts for layout in valid_texture_layouts
        )
        if not is_compatible:
            return False, "Required texture memory layout not supported"

        return is_compatible, "Op is compatible"

    def node_is_compatible(
        self, node: torch.fx.Node, features: Optional[OpFeatures] = None
    ) -> Tuple[bool, str]:
        # TODO(ssjia) support symbolic ints
        if utils.is_symint_node(node):
            return False, "symint node not supported yet"
        elif utils.is_tensor_node(node):
            return self.op_node_is_compatible(node, features=features)

        return False, f"Unsupported node type: {node.format_node()}"

    def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]:
        """
        Detect if a node is a permute/transpose that precedes a call to a `mm` or
        `addmm` operator. This node can be fused with the `mm` or `addmm` to produce a
        `linear` operator.

        This function returns two bool values:
        1. The first indicates if this node can be fused into a linear node
        2. The second indicates if the overall linear op can be executed with Vulkan

        The node will be partitioned only if both are true.
        """
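        # Illustrative pattern (hypothetical node names), roughly what tracing
        # `torch.nn.functional.linear(x, w)` can lower to in the edge dialect:
        #
        #   w_t = aten.t_copy.default(w)           <- the node being checked
        #   out = aten.addmm.default(b, x, w_t)    <- its single user
        #
        # The permute/transpose is only claimed when the downstream mm/addmm node
        # is itself compatible with the Vulkan delegate.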
        if node.target not in [
            exir_ops.edge.aten.t_copy.default,
            exir_ops.edge.aten.permute_copy.default,
        ]:
            return False, False

        if len(node.users) != 1:
            return False, False

        first_user = list(node.users.keys())[0]
        if first_user.target in [
            exir_ops.edge.aten.mm.default,
            exir_ops.edge.aten.addmm.default,
        ]:
            # Only mark this node if the target linear op is valid
            if self.node_is_compatible(first_user)[0]:
                return True, True
            else:
                return True, False

        return False, False

    def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]:
        """
        Scalar tensors are usually converted to scalar values in the graph via
        `scalar_tensor[0].item()` in Python, which translates to a chain of
        `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph.
        This function marks the entire chain as supported by the Vulkan delegate.

        Later, within vulkan_preprocess there will be a graph transform which replaces
        the chain with passing in the scalar tensor directly.

        Similar to the `is_linear_permute` function, this function has 2 return values.
        """
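        # Illustrative chain (hypothetical node names) produced by
        # `scalar_tensor[0].item()`:
        #
        #   sel = aten.select_copy.int(scalar_tensor, 0, 0)
        #   val = aten._local_scalar_dense.default(sel)
        #
        # Both nodes are claimed when every user of `val` is itself compatible.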
        if node.target == exir_ops.edge.aten.select_copy.int:
            if len(node.users) != 1:
                return False, False
            # pyre-ignore
            if node.args[0].meta["val"].numel() != 1:
                return False, False

            local_scalar_dense = list(node.users.keys())[0]
            if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default:
                return False, False

            return self.is_in_local_scalar_dense_chain(local_scalar_dense)

        if node.target == torch.ops.aten._local_scalar_dense.default:
            return True, all(self.node_is_compatible(user)[0] for user in node.users)

        return False, False

    def log_skip(self, node: torch.fx.Node, reason: str) -> None:
        if node.op == "call_function":
            logger.info(
                f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}"
            )

    def is_node_supported(
        self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        r = self._is_node_supported(node)
        return r

    def _is_node_supported(self, node: torch.fx.Node) -> bool:
        target = node.target
        if node.target == torch.ops.higher_order.auto_functionalized:
            first_arg = node.args[0]
            assert isinstance(first_arg, torch._ops.OpOverload)
            target = first_arg.name()

        is_linear_permute, target_linear_is_compatible = self.is_linear_permute(node)
        if is_linear_permute and target_linear_is_compatible:
            return True
        elif is_linear_permute:
            # Skip so that the permute can be fused into a linear by another backend
            self.log_skip(node, "permute node of incompatible linear node")
            return False

        is_in_local_scalar_dense_chain, dst_node_is_compatible = (
            self.is_in_local_scalar_dense_chain(node)
        )
        if is_in_local_scalar_dense_chain and dst_node_is_compatible:
            return True
        elif is_in_local_scalar_dense_chain:
            self.log_skip(node, "local scalar dense of incompatible op node")
            return False

        if target not in vulkan_supported_ops:
            self.log_skip(node, "no operator implementation")
            return False

        features = vulkan_supported_ops[target]

        if not features.check_node_fn(node):
            self.log_skip(node, "op args not supported")
            return False

        if self.require_dynamic_shapes and not features.resize_fn:
            self.log_skip(node, "no dynamic shape support")
            return False

        is_compatible, reason = self.node_is_compatible(node, features=features)
        if not is_compatible:
            self.log_skip(node, reason)

        return is_compatible


def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]:
    compile_specs = []

    for key, value in compile_options.items():
        if isinstance(value, (VkStorageType, VkMemoryLayout)):
            value_bytes = int(value).to_bytes(4, byteorder="little")
            compile_specs.append(CompileSpec(key, value_bytes))

        if isinstance(value, bool):
            value_bytes = value.to_bytes(1, byteorder="little")
            compile_specs.append(CompileSpec(key, value_bytes))

        if key == "texture_limits":
            compile_specs.append(
                CompileSpec(
                    "texture_limits_x", int(value[0]).to_bytes(4, byteorder="little")
                )
            )
            compile_specs.append(
                CompileSpec(
                    "texture_limits_y", int(value[1]).to_bytes(4, byteorder="little")
                )
            )
            compile_specs.append(
                CompileSpec(
                    "texture_limits_z", int(value[2]).to_bytes(4, byteorder="little")
                )
            )

        # Unhandled options are ignored

    return compile_specs
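
# Usage sketch (illustrative): "storage_type_override" below is a hypothetical key
# name; only the value types and the "texture_limits" key determine how options are
# serialized into CompileSpecs.
#
#   parse_compile_options({
#       "storage_type_override": VkStorageType.TEXTURE_3D,  # -> 4-byte little-endian spec
#       "require_dynamic_shapes": True,                      # bool -> 1-byte spec
#       "texture_limits": (16384, 16384, 2048),              # -> three 4-byte specs
#   })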


@final
class VulkanPartitioner(Partitioner):
    def __init__(
        self,
        compile_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.options: Dict[str, Any] = {}
        if compile_options is not None:
            self.options = compile_options

        compile_spec = parse_compile_options(self.options)
        self.delegation_spec = DelegationSpec(VulkanBackend.__name__, compile_spec)

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        return (ops_not_to_decompose, None)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        # Run the CapabilityBasedPartitioner to propose the largest possible
        # subgraphs of supported nodes, then tag each proposed partition for delegation
        partition_tags = {}

        texture_limits: utils.ImageExtents = self.options.get(
            "texture_limits", utils.DEFAULT_TEXTURE_LIMITS
        )
        buffer_limit: int = self.options.get("buffer_limit", utils.DEFAULT_BUFFER_LIMIT)
        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            VulkanSupportedOperators(
                texture_limits,
                buffer_limit,
                require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
            ),
            allows_single_node_partition=True,
        )
        partition_list = capability_partitioner.propose_partitions()
        for partition in partition_list:
            for node in partition.nodes:
                tag = f"tag{partition.id}"
                node.meta["delegation_tag"] = tag
                partition_tags[tag] = self.delegation_spec

        pl = len(partition_list)
        if pl == 0:
            logger.warning("No Vulkan subgraphs can be partitioned!")
        else:
            logger.info(f"Found {pl} Vulkan subgraphs to be partitioned.")

        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
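
# Minimal usage sketch (assumes the standard ExecuTorch lowering flow; exact entry
# points may differ between releases):
#
#   from executorch.exir import to_edge_transform_and_lower
#
#   ep = torch.export.export(model, example_inputs)
#   edge = to_edge_transform_and_lower(ep, partitioner=[VulkanPartitioner()])
#   executorch_program = edge.to_executorch()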