# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator
from types import NoneType
from typing import cast, List, Optional, Union

import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)
from executorch.backends.vulkan.utils import (
    is_constant,
    is_get_attr_node,
    is_param_node,
)
from executorch.exir.backend.utils import DelegateMappingBuilder

from executorch.exir.tensor import TensorSpec
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
from torch.export import ExportedProgram
from torch.fx import Node

# Python scalars that can be serialized as a single Vulkan graph value.
_ScalarType = Union[bool, int, float]
# Every kind of operator argument that get_or_create_value_for() is typed to accept.
_Argument = Union[
    Node, NoneType, _ScalarType, TensorSpec, List[_ScalarType], List[Node], str
]

# NOTE(review): getLogger("") appears to resolve to the root logger, so the
# setLevel() call below would affect logging process-wide -- confirm this is
# intended before changing it.
logger: logging.Logger = logging.getLogger("")
logger.setLevel(logging.INFO)


class VkGraphBuilder:
    """
    Walks the FX graph of an ExportedProgram and accumulates the pieces of a
    vk_graph_schema.VkGraph: the operator call chain, the value table, the
    graph input/output value ids and the list of constant tensors.
    """

    def __init__(
        self,
        program: ExportedProgram,
        delegate_mapping_builder: DelegateMappingBuilder,
    ) -> None:
        """
        Args:
            program: The exported program whose graph will be serialized.
            delegate_mapping_builder: Used to obtain a node id for each
                operator call; when falsy, every operator call gets id 0.
        """
        self.program = program
        self.delegate_mapping_builder = delegate_mapping_builder
        self.chain: List[vk_graph_schema.OperatorCall] = []
        self.values: List[vk_graph_schema.VkValue] = []
        self.input_ids: List[int] = []
        self.output_ids: List[int] = []
        self.const_tensors: List[torch.Tensor] = []

        # Mapping from Node to VkValue id
        self.node_to_value_ids: dict[Node, int] = {}

        # For logging
        self.seen_ops = set()

    @staticmethod
    def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
        """
        Translate a torch dtype into the corresponding VkDataType.

        Raises:
            AssertionError: if the dtype has no mapping.
        """
        if torch_dtype == torch.bool:
            return vk_graph_schema.VkDataType.BOOL
        if torch_dtype == torch.uint8:
            return vk_graph_schema.VkDataType.UINT8
        if torch_dtype == torch.int8:
            return vk_graph_schema.VkDataType.INT8
        if torch_dtype == torch.int32:
            return vk_graph_schema.VkDataType.INT32
        if torch_dtype == torch.float16:
            return vk_graph_schema.VkDataType.FLOAT16
        if torch_dtype == torch.float32:
            return vk_graph_schema.VkDataType.FLOAT32
        # Narrowing conversion for index tensor produced by max_poolNd_with_indices.
        if torch_dtype == torch.int64:
            return vk_graph_schema.VkDataType.INT32
        raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

    def get_constant(self, node: Node) -> Optional[torch.Tensor]:
        """
        Returns the constant associated with the given node in the exported program.
        Returns None if the node is not a constant within the exported program
        """
        if not is_constant(self.program, node):
            return None

        constant_name = self.program.graph_signature.inputs_to_lifted_tensor_constants[
            node.name
        ]
        # Missing names yield None, matching the "not a constant" result.
        return self.program.constants.get(constant_name)

    def get_param_tensor(self, node: Node) -> torch.Tensor:
        """
        Resolve the tensor backing a parameter, buffer, lifted constant or
        get_attr node.

        Raises:
            RuntimeError: if node is None or is not one of the supported kinds.
            AssertionError: if the node kind is supported but no tensor was found.
        """
        if node is None:
            raise RuntimeError("node is None")

        if is_param(self.program, node):
            tensor = get_param(self.program, node)
        elif is_buffer(self.program, node):
            tensor = get_buffer(self.program, node)
        elif is_constant(self.program, node):
            tensor = self.get_constant(node)
        elif is_get_attr_node(node):
            # This is a hack to support both lifted and unlifted graph
            try:
                tensor = getattr(node.graph.owning_module, node.target)
            except AttributeError:
                tensor = getattr(self.program.graph_module, node.target)
        else:
            raise RuntimeError(f"unsupported param type, {node.op}.")

        # Raised explicitly (rather than via `assert`) so the check survives -O.
        if tensor is None:
            raise AssertionError(f"Could not resolve a tensor for node {node.name}")
        return tensor

    def maybe_add_constant_tensor(self, node: Node) -> int:
        """
        If node is a param node, append its tensor to const_tensors and return
        its index there; otherwise return -1.
        """
        if not is_param_node(self.program, node):
            return -1

        constant_id = len(self.const_tensors)
        self.const_tensors.append(self.get_param_tensor(node))
        return constant_id

    def create_node_value(self, node: Node) -> int:
        """
        Create the value (SymInt, tensor or value list) produced by node,
        record it in node_to_value_ids and return its id.

        Raises:
            RuntimeError: if node.meta["spec"] is not a TensorSpec, list or tuple.
        """
        # If the node has been marked as a scalar tensor, create a SymInt instead of a tensor
        if node.meta.get("vkdg_is_scalar_tensor", False):
            new_id = self.create_symint_value()
            self.node_to_value_ids[node] = new_id
            return new_id

        spec = node.meta.get("spec")
        if isinstance(spec, TensorSpec):
            constant_id = self.maybe_add_constant_tensor(node)
            new_id = self.create_tensor_value(spec, constant_id)
        elif isinstance(spec, (list, tuple)):
            # pyre-ignore[6]: pyre having hard time to infer Node type inside
            # the container.
            new_id = self.create_value_list_value(spec)
        else:
            raise RuntimeError(f"Cannot create value for spec of type {type(spec)}")

        self.node_to_value_ids[node] = new_id
        return new_id

    def create_null_value(self) -> int:
        """Append a Null value and return its id."""
        new_id = len(self.values)
        self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Null()))
        return new_id

    def create_scalar_value(self, scalar: _ScalarType) -> int:
        """
        Append a Bool, Int or Double value for scalar and return its id.

        Raises:
            RuntimeError: if scalar is not a bool, int or float. Previously this
                case returned an id without appending anything, so the id
                ended up referring to whichever value was appended next.
        """
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(scalar, bool):
            payload = vk_graph_schema.Bool(scalar)
        elif isinstance(scalar, int):
            payload = vk_graph_schema.Int(scalar)
        elif isinstance(scalar, float):
            payload = vk_graph_schema.Double(scalar)
        else:
            raise RuntimeError(
                f"Cannot create value for scalar of type {type(scalar)}"
            )

        new_id = len(self.values)
        self.values.append(vk_graph_schema.VkValue(payload))
        return new_id

    def create_symint_value(self) -> int:
        """Append a SymInt value initialized to 0 and return its id."""
        new_id = len(self.values)
        self.values.append(vk_graph_schema.VkValue(vk_graph_schema.SymInt(0)))
        return new_id

    def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
        """
        Append a VkTensor value described by spec and return its id.

        Args:
            spec: Supplies dtype, shape and mem_obj_id, plus the optional
                vk_storage_type / vk_memory_layout attributes.
            constant_id: Index into const_tensors, or -1 (the default).
        """
        # Negative id indicates that this tensor will have its own dedicated memory.
        mem_obj_id = -1
        if spec.mem_obj_id is not None:
            mem_obj_id = spec.mem_obj_id

        # Fall back to the schema defaults when the spec carries no override.
        storage_type = getattr(spec, "vk_storage_type", VkStorageType.DEFAULT_STORAGE)
        memory_layout = getattr(
            spec, "vk_memory_layout", VkMemoryLayout.DEFAULT_LAYOUT
        )

        new_id = len(self.values)
        self.values.append(
            vk_graph_schema.VkValue(
                value=vk_graph_schema.VkTensor(
                    datatype=self.get_vk_datatype(spec.dtype),
                    dims=spec.shape,
                    constant_id=constant_id,
                    mem_obj_id=mem_obj_id,
                    storage_type=storage_type,
                    memory_layout=memory_layout,
                )
            )
        )
        return new_id

    def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
        """
        Append a BoolList, IntList or DoubleList for arg and return its id.
        The list kind is chosen from the type of the first element; an empty
        list is stored as an empty IntList.

        Raises:
            RuntimeError: if the first element is not a bool, int or float.
                Previously this case returned an id without appending
                anything, leaving a dangling value id.
        """
        if len(arg) == 0:
            payload = vk_graph_schema.IntList(items=[])
        # bool must be tested before int because bool is a subclass of int.
        elif isinstance(arg[0], bool):
            payload = vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg])
        elif isinstance(arg[0], int):
            payload = vk_graph_schema.IntList(items=[cast(int, e) for e in arg])
        elif isinstance(arg[0], float):
            payload = vk_graph_schema.DoubleList(items=[cast(float, e) for e in arg])
        else:
            raise RuntimeError(
                f"Cannot create value for list with elements of type {type(arg[0])}"
            )

        new_id = len(self.values)
        self.values.append(vk_graph_schema.VkValue(payload))
        return new_id

    def create_value_list_value(self, arg: tuple | list) -> int:
        """
        Append a ValueList whose items are the value ids of the elements of arg
        and return the id of the list itself.
        """
        # Element values must be created first: doing so may append to
        # self.values, so the list's own id is only known after the append.
        item_ids = [self.get_or_create_value_for(e) for e in arg]
        self.values.append(
            vk_graph_schema.VkValue(vk_graph_schema.ValueList(items=item_ids))
        )
        return len(self.values) - 1

    def create_string_value(self, string: str) -> int:
        """Append a String value and return its id."""
        new_id = len(self.values)
        self.values.append(
            vk_graph_schema.VkValue(vk_graph_schema.String(string_val=string))
        )
        return new_id

    def get_or_create_value_for(self, arg: _Argument) -> int:
        """
        Return the value id for an operator argument, creating a new value
        unless arg is a Node that has already been processed.

        Raises:
            RuntimeError: if arg is of a type that cannot be serialized.
        """
        if isinstance(arg, Node):
            # If the Node has already been processed, return the existing id.
            if arg in self.node_to_value_ids:
                return self.node_to_value_ids[arg]
            return self.create_node_value(arg)
        elif isinstance(
            arg,
            (NoneType, torch.device, torch.dtype, torch.layout, torch.memory_format),
        ):
            # These argument kinds are all serialized as a Null value.
            return self.create_null_value()
        elif isinstance(arg, _ScalarType):
            return self.create_scalar_value(arg)
        elif isinstance(arg, TensorSpec):
            return self.create_tensor_value(arg)
        elif isinstance(arg, list) and (
            len(arg) == 0 or isinstance(arg[0], _ScalarType)
        ):
            # pyre-ignore[6]
            return self.create_scalar_list_value(arg)
        elif isinstance(arg, list) and isinstance(arg[0], Node):
            return self.create_value_list_value(arg)
        elif isinstance(arg, torch.fx.immutable_collections.immutable_list):
            # pyre-ignore[6]
            return self.create_value_list_value(arg)
        elif isinstance(arg, str):
            return self.create_string_value(arg)
        else:
            raise RuntimeError(f"Cannot create value for arg of type {type(arg)}")

    def process_placeholder_node(self, node: Node) -> None:
        """
        Create a value for a placeholder node and, unless it is a param node,
        register that value as a graph input.
        """
        # ignores any tensors that don't get used in any ops
        if len(node.users) == 0:
            return None

        ids = self.create_node_value(node)
        if not is_param_node(self.program, node):
            if isinstance(ids, int):
                self.input_ids.append(ids)
            else:
                self.input_ids += ids

    def process_getitem_node(self, node: Node) -> None:
        """
        Map a getitem node onto the value id stored at the requested position
        of its source ValueList; no new value is created.
        """
        # Find ValueList id from the collection node.
        collection_node = node.all_input_nodes[0]
        list_id = self.node_to_value_ids[collection_node]

        # Extract the target Value id from ValueList.
        valuelist_id = node.args[1]
        value_id = self.values[list_id].value.items[valuelist_id]

        # Map Node to Value id.
        self.node_to_value_ids[node] = value_id

    def process_call_function_node(self, node) -> None:
        """
        Append an OperatorCall for node to the chain. Its args are the value
        ids of every argument in the operator schema, followed by the id of
        the value created for the node's output.
        """
        operator_call_args = []

        self.seen_ops.add(node.target)

        for i, schema_arg in enumerate(node.target._schema.arguments):
            # Resolution order: positional arg, then keyword arg, then the
            # default value declared by the schema.
            if not schema_arg.kwarg_only and i < len(node.args):
                function_arg = node.args[i]
            elif schema_arg.name in node.kwargs:
                function_arg = node.kwargs[schema_arg.name]
            else:
                function_arg = schema_arg.default_value

            # Create a Value for each function argument. If the argument has been
            # previously encountered, then use the existing Value id.
            operator_call_args.append(self.get_or_create_value_for(function_arg))

        # Add output node
        operator_call_args.append(self.create_node_value(node))
        operator_node_id = (
            0
            if not self.delegate_mapping_builder
            else self.delegate_mapping_builder.insert_delegate_mapping_entry(node)
        )
        self.chain.append(
            vk_graph_schema.OperatorCall(
                node_id=operator_node_id,  # pyre-ignore[6]: this is going to be an int
                name=node.target.__name__,
                args=operator_call_args,
            ),
        )

    def process_getattr_node(self, node: Node) -> None:
        """Create a value for a get_attr node."""
        self.create_node_value(node)

    def process_output_node(self, node: Node) -> None:
        """
        Register the value id of every input of the output node as a graph
        output.

        Raises:
            AssertionError: if an input node has no value id yet.
        """
        for out_node in node.all_input_nodes:
            if out_node not in self.node_to_value_ids:
                raise AssertionError(
                    "Cannot find input to output node in node_to_value_ids. This means "
                    "the output node is being serialized before its corresponding "
                    "internal node which is not allowed."
                )
            self.output_ids.append(self.node_to_value_ids[out_node])

    def process_node(self, node: Node, call_node_debug_hdl: int) -> None:
        """
        Dispatch node to the handler for its op kind.

        Args:
            node: The FX node to serialize.
            call_node_debug_hdl: Stored in node.meta["debug_handle"] for
                call_function nodes other than operator.getitem.

        Raises:
            AssertionError: if node.op is not a supported op kind.
        """
        if node.op == "placeholder":
            self.process_placeholder_node(node)
        elif node.op == "call_function":
            if node.target == operator.getitem:
                self.process_getitem_node(node)
            else:
                node.meta["debug_handle"] = call_node_debug_hdl
                self.process_call_function_node(node)
        elif node.op == "get_attr":
            self.process_getattr_node(node)
        elif node.op == "output":
            self.process_output_node(node)
        else:
            raise AssertionError(f"Unsupported node op: {node.op}")

    def build_graph(self) -> vk_graph_schema.VkGraph:
        """
        Process every node of the program's graph in order and return the
        resulting VkGraph. The returned graph's `constants` and `shaders`
        lists are empty; collected constant tensors stay in self.const_tensors.
        """
        # The debug handle is the node's position in the graph, counted over
        # all nodes and not just call_function nodes.
        for call_node_debug_hdl, node in enumerate(
            self.program.graph_module.graph.nodes
        ):
            self.process_node(node, call_node_debug_hdl)

        logger.info("Operators included in this Vulkan partition: ")
        for op in self.seen_ops:
            logger.info(" %s", op.__name__)

        return vk_graph_schema.VkGraph(
            version="0",
            chain=self.chain,
            values=self.values,
            input_ids=self.input_ids,
            output_ids=self.output_ids,
            constants=[],
            shaders=[],
        )