# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import IntEnum
from typing import Optional, Set, Tuple

import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)

from executorch.exir.tensor import TensorSpec

from torch._export.utils import is_buffer, is_param

from torch._subclasses.fake_tensor import FakeTensor

from torch.export import ExportedProgram

##
## Node type determination
##


def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given object is an FX node whose op is ``get_attr``.
    """
    return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Returns true if the node's name is a key of the program graph signature's
    ``inputs_to_lifted_tensor_constants`` mapping, i.e. the node is a lifted
    tensor constant input.
    """
    return node.name in program.graph_signature.inputs_to_lifted_tensor_constants


def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check if the given node is a parameter within the exported program.

    A node counts as a parameter if it is a ``get_attr`` node, a parameter, a
    buffer, or a lifted tensor constant of ``program``.
    """
    return (
        is_get_attr_node(node)
        or is_param(program, node)
        or is_buffer(program, node)
        or is_constant(program, node)
    )


def is_symint_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node produces a SymInt value.

    Nodes without a ``"val"`` metadata entry are never considered SymInt nodes.
    """
    return "val" in node.meta and isinstance(node.meta["val"], torch.SymInt)


def is_tensor_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node produces a tensor value, or a collection of tensor values.
    """
    # All nodes with tensor values are tagged by the SpecPropPass transform
    if "spec" in node.meta:
        return True

    if "val" not in node.meta:
        return False

    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return True

    if isinstance(val, (list, tuple)):
        # Every element must be a tensor. Note that all() is True for an empty
        # collection, so an empty list/tuple is treated as a tensor node.
        return all(isinstance(x, FakeTensor) for x in val)

    return False


##
## Memory Layout, Storage Type Determination
##

# Image extents in the order returned by required_image_extents():
# (width, height, channels * batch), each measured after packing.
ImageExtents = Tuple[int, int, int]

# Default per-axis limits, compared element-wise against ImageExtents by
# extents_are_valid().
DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048)
# Default maximum number of elements in a GPU buffer (128 * 2**20); compared
# against tensor numel() by within_buffer_limit().
DEFAULT_BUFFER_LIMIT = 128 * (1024 * 1024)


class PackedDim(IntEnum):
    # Candidate packed dimensions; presumably mirrors the TENSOR_*_PACKED
    # memory layouts below — confirm against the runtime's definition.
    WIDTH = 0
    HEIGHT = 1
    CHANNELS = 2


all_packed_dims: Set[PackedDim] = {
    PackedDim.WIDTH,
    PackedDim.HEIGHT,
    PackedDim.CHANNELS,
}

all_storage_types: Set[VkStorageType] = {
    VkStorageType.BUFFER,
    VkStorageType.TEXTURE_3D,
}

all_memory_layouts: Set[VkMemoryLayout] = {
    VkMemoryLayout.TENSOR_WIDTH_PACKED,
    VkMemoryLayout.TENSOR_HEIGHT_PACKED,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED,
}


def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> bool:
    """
    Checks whether the tensors produced by the given node can fit within the device's
    GPU buffer limit, which represents the maximum number of elements that can be stored
    in a GPU buffer.

    For a node producing a list/tuple of tensors, every tensor must fit.
    The comparison is strict: a tensor with exactly ``buffer_limit`` elements
    is rejected.

    Raises:
        AssertionError: if the node is not a tensor node.
        RuntimeError: if ``node.meta["val"]`` is neither a FakeTensor nor a
            list/tuple of tensors.
    """
    assert is_tensor_node(node)

    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return val.numel() < buffer_limit
    elif isinstance(val, (list, tuple)):
        return all(x.numel() < buffer_limit for x in val)
    else:
        raise RuntimeError(f"Cannot get numel for val of type {type(node.meta['val'])}")


def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
    """
    Calculate the image extents that will be used to represent a tensor with the given sizes
    and memory layout in the Vulkan Delegate.

    Missing trailing dimensions default to 1. The dimension selected by ``layout``
    is divided by 4, rounded up (presumably four elements are packed per texel
    along that dimension). The batch dimension is folded into the third extent,
    so the result is ``(width, height, channels * batch)``.

    Raises:
        RuntimeError: if ``layout`` is not width, height or channels packed.
    """
    width = sizes[-1] if len(sizes) >= 1 else 1
    height = sizes[-2] if len(sizes) >= 2 else 1
    channels = sizes[-3] if len(sizes) >= 3 else 1
    # NOTE(review): for rank > 4 only sizes[0] is used as the batch and the
    # dimensions between sizes[0] and sizes[-3] are ignored — confirm rank <= 4
    # is guaranteed by callers.
    batch = sizes[0] if len(sizes) >= 4 else 1

    if layout == VkMemoryLayout.TENSOR_WIDTH_PACKED:
        width = (width + 3) // 4
    elif layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED:
        height = (height + 3) // 4
    elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED:
        channels = (channels + 3) // 4
    else:
        raise RuntimeError(f"Unsupported memory layout {layout}")

    return width, height, channels * batch


def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool:
    """
    Returns true if every extent is less than or equal to the limit on the same axis.
    """
    return all(extent <= limit for extent, limit in zip(extents, limits))


def valid_texture_memory_layouts(
    tensor_sizes: torch.Size, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given tensor sizes, determine the set of memory layouts which will produce a texture
    that can fit within the specified device limits.
    """
    return {
        layout
        for layout in all_memory_layouts
        if extents_are_valid(
            required_image_extents(tensor_sizes, layout), texture_limits
        )
    }


def possible_node_memory_layouts(
    node: torch.fx.Node, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given a node, determine the set of memory layouts which can be used to represent all
    tensors involved in the computation.

    For a node producing a list/tuple of tensors, the result is the intersection
    of each tensor's valid layouts, so every returned layout fits every tensor.
    An empty list/tuple yields all memory layouts. Any other ``"val"`` type
    yields an empty set.
    """
    assert is_tensor_node(node)
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return valid_texture_memory_layouts(val.shape, texture_limits)

    if isinstance(val, (list, tuple)):
        # A layout is only usable if it fits *all* tensors. The previous union
        # admitted layouts that fit only some of them.
        valid_layouts = set(all_memory_layouts)
        for fake_tensor in val:
            valid_layouts &= valid_texture_memory_layouts(
                fake_tensor.shape, texture_limits
            )
        return valid_layouts

    return set()


##
## TensorSpec Utils
##


def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
    """
    Set ``attr`` to ``value`` on the TensorSpec(s) stored in ``node.meta["spec"]``.

    When the spec is a list/tuple, the attribute is set on every element.

    Raises:
        AssertionError: if the node has no spec, or a list/tuple element is not
            a TensorSpec.
        RuntimeError: if the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]
    if isinstance(spec, TensorSpec):
        setattr(spec, attr, value)
    elif isinstance(spec, (list, tuple)):
        for s in spec:
            assert isinstance(s, TensorSpec)
            setattr(s, attr, value)
    else:
        raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")


def get_node_spec_attr(node: torch.fx.Node, attr: str, return_first: bool = True):
    """
    Read ``attr`` from the TensorSpec(s) stored in ``node.meta["spec"]``.

    Returns:
        For a single TensorSpec, the attribute value, or None if it is missing.
        For a list/tuple of specs with ``return_first=True``, the first spec's
        value (None if missing or if the collection is empty). With
        ``return_first=False``, a list holding one value (or None) per spec.

    Raises:
        AssertionError: if the node has no spec.
        RuntimeError: if the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]
    if isinstance(spec, TensorSpec):
        return getattr(spec, attr, None)
    elif isinstance(spec, (list, tuple)):
        if return_first:
            # The attribute must be looked up on the first element. Checking
            # hasattr on the list/tuple container (as before) always fails for
            # a plain list/tuple, so this always returned None.
            return getattr(spec[0], attr, None) if spec else None
        return [getattr(s, attr, None) for s in spec]
    else:
        raise RuntimeError(f"Cannot get attr for spec of type {type(spec)}")


def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]:
    """
    Return the ``vk_storage_type`` of the node's (first) TensorSpec, or None if unset.
    """
    return get_node_spec_attr(node, "vk_storage_type")


def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]:
    """
    Return the ``vk_memory_layout`` of the node's (first) TensorSpec, or None if unset.
    """
    return get_node_spec_attr(node, "vk_memory_layout")