xref: /aosp_15_r20/external/executorch/backends/vulkan/_passes/tag_memory_meta_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7import logging
8from copy import deepcopy
9from typing import Set
10
11import executorch.backends.vulkan.utils as utils
12
13import torch
14
15from executorch.backends.vulkan.op_registry import get_op_features, has_impl
16
17from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
18    VkMemoryLayout,
19    VkStorageType,
20)
21
22from executorch.exir.dialects._ops import ops as exir_ops
23
24from executorch.exir.pass_base import ExportPass, PassResult
25
26from torch._subclasses.fake_tensor import FakeTensor
27
28from torch.fx.passes.tools_common import NodeList
29from torch.fx.passes.utils.fuser_utils import topo_sort
30
# Use a module-scoped logger rather than logging.getLogger(""): configuring the
# root logger here would force INFO verbosity onto the entire application, not
# just this pass.
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
33
34
def set_memory_metadata(
    node: torch.fx.Node, storage: VkStorageType, layout: VkMemoryLayout
) -> None:
    """Record the chosen storage type and memory layout on ``node``'s tensor spec."""
    annotations = (
        ("vk_storage_type", storage),
        ("vk_memory_layout", layout),
    )
    for attr_name, attr_value in annotations:
        utils.set_node_spec_attr(node, attr_name, attr_value)
40
41
def insert_transition_node(
    graph_module: torch.fx.GraphModule,
    node: torch.fx.Node,
    arg: torch.fx.Node,
    storage: VkStorageType,
    layout: VkMemoryLayout,
) -> None:
    """
    Insert a clone node to copy the original tensor to a tensor with the desired storage
    type and memory layout.
    """
    graph = graph_module.graph
    with graph.inserting_before(node):
        # The clone op acts as the transition: it materializes a copy of `arg`
        # with the requested (storage, layout) memory settings.
        transition = graph.create_node(
            "call_function",
            exir_ops.edge.aten.clone.default,
            (arg,),
        )
        transition.meta["val"] = arg.meta["val"]
        # Copy the spec so the transition output is a distinct, non-constant tensor.
        spec_copy = deepcopy(arg.meta["spec"])
        spec_copy.const = False
        transition.meta["spec"] = spec_copy
        set_memory_metadata(transition, storage, layout)
        # Only reroute the use of `arg` by `node`; other users keep the original.
        arg.replace_all_uses_with(transition, lambda user: user == node)
64
65
class TagMemoryMetaPass(ExportPass):
    """
    There are a variety of ways that tensors can be represented in Vulkan. The two main
    descriptors for how a tensor is laid out in memory is:

    1. Storage Type (buffer or texture)
    2. Memory Layout (which dim is packed along a texel / has a stride of 1, etc.)

    Due to the differences between buffers and textures, and the differences between
    different memory layouts, an implementation for an operator may only support a
    specific set of (storage type, memory layout) combinations.

    Furthermore, if an operator implementation supports multiple (storage type, memory
    layout) combinations, there may be a "preferred" setting which results in optimal
    performance.

    This pass is responsible for ensuring that all tensors participating in an operator
    call have a valid/optimal (storage type, memory layout) setting, and insert
    transition operators to transfer input tensors to the correct memory settings when
    necessary.
    """

    def __init__(
        self,
        texture_limits: utils.ImageExtents,
        default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D,
        default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED,
    ):
        super().__init__()
        # Fallback settings used when neither the op registry nor neighboring
        # nodes express a preference.
        self.default_storage: VkStorageType = default_storage_type
        self.default_layout: VkMemoryLayout = default_memory_layout
        # Maximum image extents supported by the target; tensors that cannot
        # fit within these limits cannot use texture storage.
        self.texture_limits = texture_limits

    def propose_node_storage(
        self,
        node: torch.fx.Node,
    ) -> VkStorageType:
        """
        Uses the operator registry to determine the storage type that should be used for
        a given node. The storage type is determined with the following priorities:

        1. In some cases, a tensor involved in the computation may be too large to be
           represented as a texture. If this is the case, the node is "opinionated" and
           buffer representation must be used.
        2. If the operator called by the node indicates an optimal storage type, or only
           supports a single storage type, use that storage type. If either is true,
           then the node is considered to be opinionated as well. If multiple storage
           types are supported and no preferred storage type is indicated, then the
           node is not opinionated; go to the next step.
        3. If the node's arguments already have memory metadata annotations, then
           preserve the settings of the first argument. Otherwise, proceed to the next
           step.
        4. Recursively search the node's uses to see if any subsequent uses are
           opinionated; inherit the settings of the first opinionated node. If no
           opinionated user can be found, then proceed to the last step.
        5. Use the default storage type setting.
        """
        # The node may have an input/output tensor that is too big to be stored in a
        # texture. In this case, buffer storage must be used. Note that the partitioner
        # has already checked for the fact that buffer storage is supported by the
        # operator.
        if len(utils.possible_node_memory_layouts(node, self.texture_limits)) == 0:
            return VkStorageType.BUFFER

        valid_storage_types: Set[VkStorageType] = utils.all_storage_types

        # pyre-ignore
        if has_impl(node.target):
            # pyre-ignore
            features = get_op_features(node.target)
            valid_storage_types = features.supported_storage_types()
            storage = features.propose_storage_type()
            if storage is not None:
                return storage

        # Inherit the storage type of the first already-annotated tensor
        # argument, provided the operator supports it. Use .get() so nodes
        # without a "val" entry are skipped instead of raising KeyError.
        for arg in node.args:
            if isinstance(arg, torch.fx.Node) and isinstance(
                arg.meta.get("val"), FakeTensor
            ):
                storage = utils.get_node_storage_type(arg)
                if storage is not None and storage in valid_storage_types:
                    return storage

        # If no storage type has been resolved yet, assume the optimal storage type of
        # the first opinionated user. This search is recursive.
        # NOTE(review): propose_node_storage never returns None (it always falls
        # back to a default), so in practice this adopts the proposal of the
        # *first* user rather than scanning all users for an opinionated one.
        for user in node.users:
            optimal_storage = self.propose_node_storage(user)
            if optimal_storage is not None:
                return optimal_storage

        # Fall back to the default storage type, or to any supported storage
        # type when the operator does not support the default.
        if self.default_storage in valid_storage_types:
            return self.default_storage
        else:
            return next(iter(valid_storage_types))

    def propose_node_layout(
        self,
        node: torch.fx.Node,
        storage: VkStorageType,
    ) -> VkMemoryLayout:
        """
        Performs the same steps as propose_node_storage, but detects the memory layout
        that should be used for the specific storage type. The same prioritization logic
        is applied.
        """
        valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts
        # pyre-ignore
        if has_impl(node.target):
            # pyre-ignore
            features = get_op_features(node.target)
            valid_layouts = features.supported_memory_layouts(storage)
            layout = features.propose_memory_layout(storage)
            if layout is not None:
                return layout

        # Inherit the layout of the first already-annotated tensor argument,
        # provided the operator supports it for the chosen storage type.
        for arg in node.args:
            if isinstance(arg, torch.fx.Node) and isinstance(
                arg.meta.get("val"), FakeTensor
            ):
                layout = utils.get_node_memory_layout(arg)
                if layout is not None and layout in valid_layouts:
                    return layout

        # If no memory layout has been resolved yet, assume the optimal layout of
        # the first opinionated user. This search is recursive.
        # NOTE(review): as with propose_node_storage, this never observes None
        # and therefore adopts the first user's proposal.
        for user in node.users:
            optimal_layout = self.propose_node_layout(user, storage)
            if optimal_layout is not None:
                return optimal_layout

        # As a last resort, return the default memory layout, or any supported
        # layout when the operator does not support the default.
        if self.default_layout in valid_layouts:
            return self.default_layout
        else:
            return next(iter(valid_layouts))

    def should_annotate(self, node) -> bool:
        """
        Return True if ``node`` is an fx Node producing a single (fake) tensor
        whose memory metadata should be annotated by this pass.
        """
        if not isinstance(node, torch.fx.Node):
            return False

        # Use .get() so nodes without a "val" entry are treated as non-tensor
        # nodes instead of raising KeyError.
        if not isinstance(node.meta.get("val"), FakeTensor):
            return False

        # Storage type and memory layout for tensorref will be determined at runtime
        # so there's no use in setting those attributes ahead of time.
        if node.meta.get("vkdg_tensorref", False):
            return False

        return True

    def should_delay_annotation(self, node: torch.fx.Node) -> bool:
        """Return True if annotation of ``node`` should be deferred."""
        # For prepack nodes, delay setting the storage type and memory layout as long as
        # possible. This is to minimize the number of transitions, since it can be
        # difficult to predict what storage type and memory layout should be used at the
        # time the prepack node is observed.
        return node.target == exir_ops.edge.et_vk.prepack.default

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Annotate every eligible node with a (storage type, memory layout)
        setting, and insert transition (clone) nodes for arguments whose
        existing settings differ from what the consuming operator requires.
        """
        sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))

        for node in sorted_nodes:
            if not self.should_annotate(node) or self.should_delay_annotation(node):
                continue

            storage = self.propose_node_storage(node)
            layout = self.propose_node_layout(node, storage)

            set_memory_metadata(node, storage, layout)

            inserting_transitions_for_node = False
            for i, arg in enumerate(node.args):
                if not self.should_annotate(arg):
                    continue

                assert isinstance(arg, torch.fx.Node)

                arg_storage = utils.get_node_storage_type(arg)
                arg_layout = utils.get_node_memory_layout(arg)

                # Un-annotated args (e.g. prepack nodes whose annotation was
                # delayed) simply adopt this node's settings; no transition is
                # required for them.
                if arg_storage is None:
                    utils.set_node_spec_attr(arg, "vk_storage_type", storage)
                    arg_storage = storage
                if arg_layout is None:
                    utils.set_node_spec_attr(arg, "vk_memory_layout", layout)
                    arg_layout = layout

                if arg_storage == storage and arg_layout == layout:
                    continue

                # Log the header line only once per node, before the first
                # transition is inserted.
                if not inserting_transitions_for_node:
                    inserting_transitions_for_node = True
                    logger.info(
                        f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
                    )

                insert_transition_node(graph_module, node, arg, storage, layout)

                logger.info(
                    f"   args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
                )

        return PassResult(graph_module, True)
268