1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import logging 8from copy import deepcopy 9from typing import Set 10 11import executorch.backends.vulkan.utils as utils 12 13import torch 14 15from executorch.backends.vulkan.op_registry import get_op_features, has_impl 16 17from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( 18 VkMemoryLayout, 19 VkStorageType, 20) 21 22from executorch.exir.dialects._ops import ops as exir_ops 23 24from executorch.exir.pass_base import ExportPass, PassResult 25 26from torch._subclasses.fake_tensor import FakeTensor 27 28from torch.fx.passes.tools_common import NodeList 29from torch.fx.passes.utils.fuser_utils import topo_sort 30 31logger: logging.Logger = logging.getLogger("") 32logger.setLevel(logging.INFO) 33 34 35def set_memory_metadata( 36 node: torch.fx.Node, storage: VkStorageType, layout: VkMemoryLayout 37) -> None: 38 utils.set_node_spec_attr(node, "vk_storage_type", storage) 39 utils.set_node_spec_attr(node, "vk_memory_layout", layout) 40 41 42def insert_transition_node( 43 graph_module: torch.fx.GraphModule, 44 node: torch.fx.Node, 45 arg: torch.fx.Node, 46 storage: VkStorageType, 47 layout: VkMemoryLayout, 48) -> None: 49 """ 50 Insert a clone node to copy the original tensor to a tensor with the desired storage 51 type and memory layout. 52 """ 53 with graph_module.graph.inserting_before(node): 54 clone_node = graph_module.graph.create_node( 55 "call_function", 56 exir_ops.edge.aten.clone.default, 57 (arg,), 58 ) 59 clone_node.meta["val"] = arg.meta["val"] 60 clone_node.meta["spec"] = deepcopy(arg.meta["spec"]) 61 clone_node.meta["spec"].const = False 62 set_memory_metadata(clone_node, storage, layout) 63 arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y) 64 65 66class TagMemoryMetaPass(ExportPass): 67 """ 68 There are a variety of ways that tensors can be represented in Vulkan. The two main 69 descriptors for how a tensor is laid out in memory is: 70 71 1. Storage Type (buffer or texture) 72 2. Memory Layout (which dim is packed along a texel / has a stride of 1, etc.) 73 74 Due to the differences between buffers and textures, and the differences between 75 different memory layouts, an implementation for an operator may only support a 76 specific set of (storage type, memory layout) combinations. 77 78 Furthermore, if an operator implementation supports multiple (storage type, memory 79 layout) combinations, there may be a "preferred" setting which results in optimal 80 performance. 81 82 This pass is responsible for ensuring that all tensors participating in an operator 83 call have a valid/optimal (storage type, memory layout) setting, and insert 84 transition operators to transfer input tensors to the correct memory settings when 85 necessary. 86 """ 87 88 def __init__( 89 self, 90 texture_limits: utils.ImageExtents, 91 default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D, 92 default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED, 93 ): 94 super().__init__() 95 self.default_storage: VkStorageType = default_storage_type 96 self.default_layout: VkMemoryLayout = default_memory_layout 97 self.texture_limits = texture_limits 98 99 def propose_node_storage( 100 self, 101 node: torch.fx.Node, 102 ) -> VkStorageType: 103 """ 104 Uses the operator registry to determine the storage type that should be used for 105 a given node. The storage type is determined with the following priorities: 106 1. In some cases, a tensor involved in the computation may be too large to be 107 represented as a texture. If this is the case, the node is "opinionated" and 108 buffer representation must be used. 109 1. If the operator called by the node indicates an optimal storage type, or only 110 supports a single storage type, use that storage type. If either is true, 111 then the node is considered to be opinionated as well. If multiple storage 112 and no preferred storage type is indicated, then the node is not opinionated; 113 go to the next step. 114 2. If the node's arguments already have memory metadata annotations, then 115 preserve the settings of the first argument. Otherwise, proceed to the next 116 step. 117 3. Recursively search the node's uses to see if any subsequent uses are 118 opinionated; inherit the settings of the first opinionated node. If no 119 opinionated user can be found, then proceed to the last step. 120 4. Use the default storage type setting. 121 """ 122 # The node may have an input/output tensor that is too big to be stored in a 123 # texture. In this case, buffer storage must be used. Note that the partitioner 124 # has already checked for the fact that buffer storage is supported by the 125 # operator. 126 if len(utils.possible_node_memory_layouts(node, self.texture_limits)) == 0: 127 return VkStorageType.BUFFER 128 129 valid_storage_types: Set[VkStorageType] = utils.all_storage_types 130 131 # pyre-ignore 132 if has_impl(node.target): 133 # pyre-ignore 134 features = get_op_features(node.target) 135 valid_storage_types = features.supported_storage_types() 136 storage = features.propose_storage_type() 137 if storage is not None: 138 return storage 139 140 for arg in node.args: 141 if isinstance(arg, torch.fx.Node) and isinstance( 142 arg.meta["val"], FakeTensor 143 ): 144 storage = utils.get_node_storage_type(arg) 145 if storage is not None and storage in valid_storage_types: 146 return storage 147 148 # If no storage type has been resolved yet, assume the optimal storage type of 149 # the first opinionated user. This search is recursive. 150 for user in node.users: 151 optimal_storage = self.propose_node_storage(user) 152 if optimal_storage is not None: 153 return optimal_storage 154 155 if self.default_storage in valid_storage_types: 156 return self.default_storage 157 else: 158 return next(iter(valid_storage_types)) 159 160 def propose_node_layout( 161 self, 162 node: torch.fx.Node, 163 storage: VkStorageType, 164 ) -> VkMemoryLayout: 165 """ 166 Performs the same steps as propose_node_storage, but detects the memory layout 167 that should be used for the specific storage type. The same prioritization logic 168 is applied. 169 """ 170 valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts 171 # pyre-ignore 172 if has_impl(node.target): 173 # pyre-ignore 174 features = get_op_features(node.target) 175 valid_layouts = features.supported_memory_layouts(storage) 176 layout = features.propose_memory_layout(storage) 177 if layout is not None: 178 return layout 179 180 for arg in node.args: 181 if isinstance(arg, torch.fx.Node) and isinstance( 182 arg.meta["val"], FakeTensor 183 ): 184 layout = utils.get_node_memory_layout(arg) 185 if layout is not None and layout in valid_layouts: 186 return layout 187 188 # If no storage type has been resolved yet, assume the optimal storage type of 189 # the first opinionated user. This search is recursive. 190 for user in node.users: 191 optimal_storage = self.propose_node_layout(user, storage) 192 if optimal_storage is not None: 193 return optimal_storage 194 195 # As a last resort, return the default storage type that should be used. 196 if self.default_layout in valid_layouts: 197 return self.default_layout 198 else: 199 return next(iter(valid_layouts)) 200 201 def should_annotate(self, node) -> bool: 202 if not isinstance(node, torch.fx.Node): 203 return False 204 205 if not isinstance(node.meta["val"], FakeTensor): 206 return False 207 208 # Storage type and memory layout for tensorref will be determined at runtime 209 # so there's no use in setting those attributes ahead of time. 210 if node.meta.get("vkdg_tensorref", False): 211 return False 212 213 return True 214 215 def should_delay_annotation(self, node: torch.fx.Node) -> bool: 216 # For prepack nodes, delay setting the storage type and memory layout as long as 217 # possible. This is to minimize the number of transitions, since it can be 218 # difficult to predict what storage type and memory layout should be used at the 219 # time the prepack node is observed. 220 return node.target == exir_ops.edge.et_vk.prepack.default 221 222 # noqa 223 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: 224 sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes)) 225 226 for node in sorted_nodes: 227 if not self.should_annotate(node) or self.should_delay_annotation(node): 228 continue 229 230 storage = self.propose_node_storage(node) 231 layout = self.propose_node_layout(node, storage) 232 233 set_memory_metadata(node, storage, layout) 234 235 inserting_transitions_for_node = False 236 for i, arg in enumerate(node.args): 237 if not self.should_annotate(arg): 238 continue 239 240 assert isinstance(arg, torch.fx.Node) 241 242 arg_storage = utils.get_node_storage_type(arg) 243 arg_layout = utils.get_node_memory_layout(arg) 244 245 if arg_storage is None: 246 utils.set_node_spec_attr(arg, "vk_storage_type", storage) 247 arg_storage = storage 248 if arg_layout is None: 249 utils.set_node_spec_attr(arg, "vk_memory_layout", layout) 250 arg_layout = layout 251 252 if arg_storage == storage and arg_layout == layout: 253 continue 254 255 if not inserting_transitions_for_node: 256 inserting_transitions_for_node = True 257 logger.info( 258 f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:" 259 ) 260 261 insert_transition_node(graph_module, node, arg, storage, layout) 262 263 logger.info( 264 f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})" 265 ) 266 267 return PassResult(graph_module, True) 268