# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import IntEnum
from typing import Optional, Set, Tuple

import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)

from executorch.exir.tensor import TensorSpec

from torch._export.utils import is_buffer, is_param

from torch._subclasses.fake_tensor import FakeTensor

from torch.export import ExportedProgram

##
## Node type determination
##


def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Return True when ``node`` is an FX node whose ``op`` is ``"get_attr"``.

    Objects that are not ``torch.fx.Node`` instances are rejected rather than
    raising.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"


def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Return True if ``node.name`` is a key of the program's
    ``graph_signature.inputs_to_lifted_tensor_constants`` mapping.
    """
    lifted_constants = program.graph_signature.inputs_to_lifted_tensor_constants
    return node.name in lifted_constants


def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check whether ``node`` holds parameter-like data in the exported program.

    A node qualifies if it is a ``get_attr`` node, a parameter, a buffer, or a
    lifted tensor constant. The checks run in that order and stop at the first
    match.
    """
    if is_get_attr_node(node):
        return True
    signature_checks = (is_param, is_buffer, is_constant)
    return any(check(program, node) for check in signature_checks)


def is_symint_node(node: torch.fx.Node) -> bool:
    """
    Report whether ``node`` produces a SymInt, judged by the type of
    ``node.meta["val"]``. Nodes without a ``"val"`` entry yield False.
    """
    return isinstance(node.meta.get("val"), torch.SymInt)
def is_tensor_node(node: torch.fx.Node) -> bool:
    """
    Report whether ``node`` yields a tensor, or a list/tuple made up solely of tensors.

    Nodes carrying a ``"spec"`` metadata entry are accepted straight away. Otherwise
    the decision is based on the fake value stored under ``"val"``.
    """
    meta = node.meta

    # SpecPropPass tags every tensor-valued node with a "spec" entry, so its
    # presence alone is sufficient.
    if "spec" in meta:
        return True

    if "val" not in meta:
        return False

    value = meta["val"]
    if isinstance(value, (list, tuple)):
        # Every element must be a FakeTensor; an empty collection passes
        # vacuously because all() over nothing is True.
        return all(isinstance(item, FakeTensor) for item in value)

    return isinstance(value, FakeTensor)


##
## Memory Layout, Storage Type Determination
##

# Extents of a 3D image/texture as a 3-tuple of ints; required_image_extents
# shows how tensor sizes map onto each position.
ImageExtents = Tuple[int, int, int]

# Default per-axis texture limits, compared element-wise against ImageExtents.
DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048)
# Default GPU buffer limit (128 * 2**20); within_buffer_limit compares each
# tensor's numel() against a limit of this kind.
DEFAULT_BUFFER_LIMIT = 128 * 1024 * 1024
class PackedDim(IntEnum):
    # Identifies one of the three packable tensor dimensions; the members
    # mirror the TENSOR_{WIDTH,HEIGHT,CHANNELS}_PACKED memory layouts.
    WIDTH = 0
    HEIGHT = 1
    CHANNELS = 2


all_packed_dims: Set[PackedDim] = {
    PackedDim.WIDTH,
    PackedDim.HEIGHT,
    PackedDim.CHANNELS,
}

all_storage_types: Set[VkStorageType] = {
    VkStorageType.BUFFER,
    VkStorageType.TEXTURE_3D,
}

all_memory_layouts: Set[VkMemoryLayout] = {
    VkMemoryLayout.TENSOR_WIDTH_PACKED,
    VkMemoryLayout.TENSOR_HEIGHT_PACKED,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED,
}


def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> bool:
    """
    Checks whether the tensors produced by the given node can fit within the device's
    GPU buffer limit, which represents the maximum number of elements that can be stored
    in a GPU buffer.

    Args:
        node: A node for which ``is_tensor_node`` holds.
        buffer_limit: Element-count limit; each tensor's ``numel()`` must be strictly
            below it.

    Returns:
        True if every tensor produced by ``node`` has fewer than ``buffer_limit``
        elements, otherwise False.

    Raises:
        RuntimeError: If ``node.meta["val"]`` is neither a FakeTensor nor a list/tuple.
    """
    assert is_tensor_node(node)

    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return val.numel() < buffer_limit
    elif isinstance(val, (list, tuple)):
        return all(x.numel() < buffer_limit for x in val)
    else:
        raise RuntimeError(f"Cannot get numel for val of type {type(val)}")


def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
    """
    Calculate the image extents that will be used to represent a tensor with the given sizes
    and memory layout in the Vulkan Delegate.

    Sizes are read from the innermost dim outwards:
    - the last dim is the width;
    - the second-to-last is the height;
    - the third-to-last is the channels;
    - ``sizes[0]`` is the batch when the tensor has 4 or more dims.

    Missing dims default to 1. The packed dim is divided by 4 (rounded up). The
    returned tuple is ``(width, height, channels * batch)``.

    Raises:
        RuntimeError: If ``layout`` is not width-, height- or channels-packed.
    """
    width = sizes[-1] if len(sizes) >= 1 else 1
    height = sizes[-2] if len(sizes) >= 2 else 1
    channels = sizes[-3] if len(sizes) >= 3 else 1
    # NOTE(review): for tensors with more than 4 dims, the dims between
    # sizes[0] and sizes[-3] are not reflected in the extents — confirm that
    # such tensors never reach this function.
    batch = sizes[0] if len(sizes) >= 4 else 1

    # Four elements share one texel along the packed dim, hence ceil(dim / 4).
    if layout == VkMemoryLayout.TENSOR_WIDTH_PACKED:
        width = (width + 3) // 4
    elif layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED:
        height = (height + 3) // 4
    elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED:
        channels = (channels + 3) // 4
    else:
        raise RuntimeError(f"Unsupported memory layout {layout}")

    return width, height, channels * batch


def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool:
    """
    Return True if every extent is less than or equal to the limit at the same
    position.
    """
    return all(extents[i] <= limits[i] for i in range(len(extents)))


def valid_texture_memory_layouts(
    tensor_sizes: torch.Size, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given tensor sizes, determine the set of memory layouts which will produce a texture
    that can fit within the specified device limits.
    """
    return {
        layout
        for layout in all_memory_layouts
        if extents_are_valid(
            required_image_extents(tensor_sizes, layout), texture_limits
        )
    }


def possible_node_memory_layouts(
    node: torch.fx.Node, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given a node, determine the set of memory layouts which can be used to represent all
    tensors involved in the computation.

    For a single-tensor node this is the set of layouts whose texture fits within
    ``texture_limits``. For a multi-output node a layout is kept only if it fits
    for every output tensor. An empty output collection, or a value that is neither
    a FakeTensor nor a list/tuple, yields an empty set.
    """
    assert is_tensor_node(node)

    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return valid_texture_memory_layouts(val.shape, texture_limits)

    if isinstance(val, (list, tuple)):
        per_tensor_layouts = [
            valid_texture_memory_layouts(fake_tensor.shape, texture_limits)
            for fake_tensor in val
        ]
        if per_tensor_layouts:
            # set_node_spec_attr assigns one value to every spec of a
            # multi-output node, so a layout is only usable if all outputs
            # fit with it: intersect rather than union.
            return set.intersection(*per_tensor_layouts)

    return set()


##
## TensorSpec Utils
##


def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
    """
    Set ``attr`` to ``value`` on the TensorSpec(s) stored in ``node.meta["spec"]``.

    When the spec is a list/tuple, the same value is assigned to every entry.

    Raises:
        RuntimeError: If the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]
    if isinstance(spec, TensorSpec):
        setattr(spec, attr, value)
    elif isinstance(spec, (list, tuple)):
        for s in spec:
            assert isinstance(s, TensorSpec)
            setattr(s, attr, value)
    else:
        raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")


def get_node_spec_attr(node: torch.fx.Node, attr: str, return_first: bool = True):
    """
    Read ``attr`` from the TensorSpec(s) stored in ``node.meta["spec"]``.

    Args:
        node: Node whose metadata contains a ``"spec"`` entry.
        attr: Name of the attribute to read.
        return_first: For list/tuple specs, return only the first spec's value
            (True) or a list with one value per spec (False).

    Returns:
        The attribute value, or ``None`` where a spec lacks the attribute. With
        ``return_first`` set, an empty spec list also yields ``None``.

    Raises:
        RuntimeError: If the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]
    if isinstance(spec, TensorSpec):
        return getattr(spec, attr, None)
    elif isinstance(spec, (list, tuple)):
        if return_first:
            if len(spec) == 0:
                return None
            # Look the attribute up on the first spec itself, not on the
            # list/tuple container, which never carries spec attributes.
            return getattr(spec[0], attr, None)
        else:
            return [getattr(s, attr, None) for s in spec]
    else:
        raise RuntimeError(f"Cannot get attr for spec of type {type(spec)}")


def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]:
    """
    Return the ``vk_storage_type`` recorded on the node's (first) spec, or None
    if it is not set.
    """
    return get_node_spec_attr(node, "vk_storage_type")


def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]:
    """
    Return the ``vk_memory_layout`` recorded on the node's (first) spec, or None
    if it is not set.
    """
    return get_node_spec_attr(node, "vk_memory_layout")