#
#  Copyright (c) 2023 Apple Inc. All rights reserved.
#  Provided subject to the LICENSE file in the top level directory.
#

import ctypes
import logging

from typing import Dict, List, Optional, Tuple, Union

import torch

from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    Buffer,
    MPSCast,
    MPSDataType,
    MPSGraph,
    MPSNode,
    MPSNodeUnion,
    MPSTensor,
)

from executorch.backends.apple.mps.utils.mps_utils import (
    edge_dtype_to_mps_dtype,
    get_input_node,
    get_param_tensor,
    get_scalar_val,
    is_parameter,
)

from executorch.backends.transforms import get_shape
from executorch.exir.sym_util import eval_shape

from torch.export.exported_program import ExportedProgram


class NodeVisitor:
    """
    Node visitor pattern for visiting nodes in an edge IR graph and
    serializing them using the mps serialization schema.
    """

    # This dict lives on the class and is never rebound per instance, so every
    # NodeVisitor (and subclass) instance reads and writes the same mapping.
    # Besides fx nodes, it is also keyed by constant-tensor hashes and by
    # scalar keys (see define_constant / define_scalar).
    # NOTE(review): presumably shared on purpose so that all visitors agree on
    # ids - confirm how it is reset between independently serialized graphs.
    _tensor_to_id: Dict[torch.fx.Node, int] = {}
    _convert_model_to_fp16: bool = True

    def __init__(
        self, exported_program: ExportedProgram, convert_model_to_fp16: bool = True
    ):
        self._exported_program = exported_program
        self._convert_model_to_fp16 = convert_model_to_fp16

    @property
    def tensor_to_id(self) -> Dict[torch.fx.Node, int]:
        return self._tensor_to_id

    @property
    def convert_model_to_fp16(self) -> bool:
        return self._convert_model_to_fp16

    @property
    def exported_program(self) -> ExportedProgram:
        return self._exported_program

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        raise NotImplementedError("NodeVisitor must be extended!")

    def define_tensor(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
        mps_data_type: Optional[MPSDataType] = None,
    ) -> int:
        """Defines a tensor value into the MPSGraph serialization schema

        Args:
            node (torch.fx.Node): EdgeIR tensor to define into mps_graph
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer
            mps_data_type (MPSDataType): Data type to serialize with. When
                None, it is derived from node.meta["val"].dtype.

        Returns:
            int: Id of the tensor (its index in mps_graph.mps_values), or -1
                when node is None.
        """

        if node is None:
            return -1

        if node in self.tensor_to_id:
            return self.tensor_to_id[node]

        # Get a unique id for the node.
        tensor_id = self.get_serialized_id(node, mps_graph)
        cb_size, constant_buffer, mps_data_type = self.get_serialized_buffer(
            node, mps_graph, tensor_id, mps_data_type
        )
        dims = get_shape(node)

        logging.debug(
            "Serializing: %s, data type: %s, dims: %s",
            node,
            node.meta["val"].dtype,
            dims,
        )
        mps_tensor = MPSTensor(
            datatype=mps_data_type,
            num_dims=len(dims),
            dims=dims,
            constant_buffer_size=cb_size,
            constant_buffer=constant_buffer,
        )

        # The id handed out above is the index this append gives the tensor.
        mps_graph.mps_values.append(mps_tensor)
        return tensor_id

    def define_tensor_list(
        self, node: torch.fx.Node, mps_graph: MPSGraph
    ) -> Union[List[int], int]:
        """Defines one tensor per element of a multi-valued node.

        Each element of node.meta["val"] is serialized as its own MPSTensor
        with an empty constant buffer.

        Args:
            node (torch.fx.Node): EdgeIR node whose meta["val"] is a sequence
                of tensors
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer

        Returns:
            List[int]: Ids of the defined tensors, in element order. The
                sentinel -1 (an int, not a list) is returned when node is None.
        """
        if node is None:
            return -1

        if node in self.tensor_to_id:
            return self.tensor_to_id[node]

        ids: List[int] = []
        self.tensor_to_id[node] = ids
        for i, tensor in enumerate(node.meta["val"]):
            ids.append(len(mps_graph.mps_values))

            dims = eval_shape(tensor.shape)
            mps_data_type = edge_dtype_to_mps_dtype(tensor.dtype)
            logging.debug(
                "Serializing: [%s]: %s, data type: %s, dims: %s",
                i,
                node,
                tensor.dtype,
                dims,
            )

            mps_tensor = MPSTensor(
                datatype=mps_data_type,
                num_dims=len(dims),
                dims=dims,
                constant_buffer_size=0,
                constant_buffer=Buffer(storage=b""),
            )
            logging.debug(" Serialized tensor: %s", mps_tensor)
            mps_graph.mps_values.append(mps_tensor)
        return ids

    def hash_tensor(self, tensor):
        """Returns a hash used to de-duplicate constant tensors.

        The dtype and shape are hashed together with the values. Hashing the
        flattened values alone would make e.g. a (4,) and a (2, 2) tensor, or
        an int and a float tensor holding the same numbers, share one id and
        therefore one serialized MPSTensor with the wrong dims / data type.
        """
        return hash(
            (tensor.dtype, tuple(tensor.shape), tuple(tensor.reshape(-1).tolist()))
        )

    def define_constant(
        self,
        constant_tensor: torch.Tensor,
        mps_graph: MPSGraph,
    ):
        """Defines a constant tensor into the MPSGraph serialization schema

        Constants that hash equal (see hash_tensor) are serialized once and
        share the same id.

        Args:
            constant_tensor (torch.Tensor): Tensor whose data is embedded in
                the serialized graph
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer

        Returns:
            int: Id of the constant (its index in mps_graph.mps_values).
        """
        constant_tensor = constant_tensor.contiguous()
        tensor_hash = self.hash_tensor(constant_tensor)
        if tensor_hash in self.tensor_to_id:
            return self.tensor_to_id[tensor_hash]

        constant_id = self.get_serialized_id(constant_tensor, mps_graph, tensor_hash)

        mps_data_type = edge_dtype_to_mps_dtype(constant_tensor.dtype)
        constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data(
            constant_tensor, mps_graph, mps_data_type, constant_id
        )
        dims = list(constant_tensor.shape)

        mps_tensor = MPSTensor(
            datatype=mps_data_type,
            num_dims=len(dims),
            dims=dims,
            constant_buffer_size=constant_buffer_size,
            constant_buffer=constant_buffer,
        )

        mps_graph.mps_values.append(mps_tensor)
        return constant_id

    def define_scalar(
        self,
        val: Union[float, int],
        mps_data_type: MPSDataType,
        mps_graph: MPSGraph,
    ):
        """Defines a scalar value into the MPSGraph serialization schema

        The scalar is serialized as a constant tensor with dims [1].

        Args:
            val (Union[float, int]): Scalar value to serialize
            mps_data_type (MPSDataType): Data type to serialize the scalar with
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer

        Returns:
            int: Id of the scalar (its index in mps_graph.mps_values).
        """
        assert isinstance(val, (int, float))

        # Key on the Python type and the requested data type as well as the
        # value: 1, 1.0 and True compare (and hash) equal, so keying on the
        # bare value would hand an int op the id of a float constant.
        scalar_key = (type(val), val, mps_data_type)
        if scalar_key in self.tensor_to_id:
            return self.tensor_to_id[scalar_key]

        scalar_id = self.get_serialized_id(val, mps_graph, scalar_key)

        # NOTE(review): torch.tensor(val) uses torch's default dtype for a
        # Python scalar; only the float32 -> float16 case is converted by
        # get_serialized_data - confirm other data types are handled by the
        # runtime.
        tensor = torch.tensor(val)
        constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data(
            tensor, mps_graph, mps_data_type, scalar_id
        )

        mps_tensor = MPSTensor(
            datatype=mps_data_type,
            num_dims=1,
            dims=[1],
            constant_buffer_size=constant_buffer_size,
            constant_buffer=constant_buffer,
        )

        mps_graph.mps_values.append(mps_tensor)
        return scalar_id

    def get_serialized_buffer(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
        node_id: int,
        mps_data_type: Optional[MPSDataType] = None,
    ) -> Tuple[int, Buffer, MPSDataType]:
        """
        If the node is a parameter of the exported program, serialize its
        data; otherwise return an empty buffer.

        Args:
            node (torch.fx.Node): Node to look up in the exported program
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer
            node_id (int): Id already assigned to the node
            mps_data_type (MPSDataType): Data type to serialize with. When
                None, it is derived from node.meta["val"].dtype.

        Returns:
            Tuple[int, Buffer, MPSDataType]: Size of the buffer in bytes, the
                buffer itself and the data type that was actually used.
        """
        mps_data_type = (
            self.get_serialized_dtype(node) if mps_data_type is None else mps_data_type
        )

        # Check if this node is a lifted parameter
        if not is_parameter(self.exported_program, node):
            return 0, Buffer(storage=b""), mps_data_type

        tensor = get_param_tensor(self.exported_program, node)
        assert tensor is not None and isinstance(tensor, torch.Tensor)
        tensor = tensor.contiguous()

        return self.get_serialized_data(tensor, mps_graph, mps_data_type, node_id)

    def get_serialized_data(
        self,
        tensor: torch.Tensor,
        mps_graph: MPSGraph,
        mps_data_type: MPSDataType,
        id: int,
    ) -> Tuple[int, Buffer, MPSDataType]:
        """Serializes the bytes of a constant tensor.

        Args:
            tensor (torch.Tensor): Tensor holding the constant data
            mps_graph (MPSGraph): Graph whose constant_ids list records `id`
            mps_data_type (MPSDataType): Requested serialization data type
            id (int): Id of the tensor being serialized

        Returns:
            Tuple[int, Buffer, MPSDataType]: Size of the buffer in bytes, the
                buffer itself and the data type that was actually used
                (float32 becomes float16 when fp16 conversion is enabled).

        Raises:
            RuntimeError: If an int4 tensor stored as int8 is not 2-D.
        """
        if (
            self.convert_model_to_fp16
            and mps_data_type == MPSDataType.mps_data_type_float32
        ):
            tensor = tensor.half()
            mps_data_type = MPSDataType.mps_data_type_float16

        if id not in mps_graph.constant_ids:
            mps_graph.constant_ids.append(id)

        if (
            mps_data_type is MPSDataType.mps_data_type_int4
            and tensor.dtype is torch.int8
        ):
            if tensor.dim() != 2:
                raise RuntimeError(f"Unexpected tensor shape {tensor.shape}")

            # Pack each pair of adjacent columns into one byte: the even
            # column goes into the high nibble, the odd column into the low one.
            tensor = tensor.to(dtype=torch.int32)
            tensor = (((tensor[::, ::2] & 0x0F) << 4) | (tensor[::, 1::2] & 0x0F)).to(
                torch.uint8
            )
            # The repacking op is executed on the "mps" device and the result
            # is brought back to the CPU as raw bytes.
            tensor = (
                torch._convert_weight_to_int4pack(tensor.to("mps"), 2)
                .cpu()
                .view(dtype=torch.uint8)
            )

        # Copy the whole underlying storage of the tensor into a bytes object.
        storage_nbytes = tensor.untyped_storage().nbytes()
        array_type = ctypes.c_char * storage_nbytes
        array = ctypes.cast(
            tensor.untyped_storage().data_ptr(),
            ctypes.POINTER(array_type),
        ).contents
        buffer = Buffer(storage=bytes(array))

        return storage_nbytes, buffer, mps_data_type

    def get_serialized_id(
        self, node: Union[torch.fx.Node, float, int], mps_graph: MPSGraph, hash=None
    ) -> int:
        """
        Map a tensor to a unique id. If the tensor was already mapped, return
        the existent id.

        Args:
            node (Union[torch.fx.Node, float]): Object to map; used as the
                lookup key when no `hash` is given
            mps_graph (MPSGraph): Graph whose current number of values
                becomes the new id
            hash: Optional alternative key. When given, the new id is stored
                under it instead of under `node`.

        Returns:
            int: The existing id, or len(mps_graph.mps_values) for a new entry.
        """
        if hash is not None and hash in self.tensor_to_id:
            return self.tensor_to_id[hash]
        elif node in self.tensor_to_id:
            return self.tensor_to_id[node]

        # The caller is expected to append the value to mps_graph.mps_values
        # right after, so that the id matches its index.
        new_id = len(mps_graph.mps_values)
        if hash is not None:
            self.tensor_to_id[hash] = new_id
        else:
            self.tensor_to_id[node] = new_id

        return new_id

    def torch_dtype_to_mps_dtype(self, torch_dtype: torch.dtype) -> MPSDataType:
        return edge_dtype_to_mps_dtype(torch_dtype)

    def get_serialized_dtype(
        self,
        node: torch.fx.Node,
    ) -> MPSDataType:
        return self.torch_dtype_to_mps_dtype(node.meta["val"].dtype)

    def create_tertiary_node(
        self, node: torch.fx.Node, mps_graph: MPSGraph, tertiary_op: MPSNodeUnion
    ) -> MPSNode:
        """Builds an MPSNode for an op with three tensor inputs and one output."""
        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        input2_id = self.define_tensor(get_input_node(node, 1), mps_graph)
        input3_id = self.define_tensor(get_input_node(node, 2), mps_graph)
        output_id = self.define_tensor(node, mps_graph)
        return MPSNode(
            mpsnode_union=tertiary_op(
                input1_id=input1_id,
                input2_id=input2_id,
                input3_id=input3_id,
                output_id=output_id,
            )
        )

    def create_binary_node(
        self, node: torch.fx.Node, mps_graph: MPSGraph, binary_op: MPSNodeUnion
    ) -> MPSNode:
        """Builds an MPSNode for an op with two inputs and one output.

        The second input may be either a tensor node or a Python scalar.
        """
        input1_node = get_input_node(node, 0)
        input1_id = self.define_tensor(input1_node, mps_graph)

        # Handle both tensor and scalar variants of the op.
        # In case of scalar ops, manually define a constant and serialize it in the FlatBuffer.
        if isinstance(node.args[1], torch.fx.Node):
            # Second argument is a node.
            input2_id = self.define_tensor(get_input_node(node, 1), mps_graph)
        else:
            # Second argument is a scalar.
            scalar_val = get_scalar_val(node, 1)
            if input1_node.meta["val"].dtype == torch.float32:
                scalar_val = float(scalar_val)
            input2_id = self.define_scalar(
                scalar_val, self.get_serialized_dtype(input1_node), mps_graph
            )

        output_id = self.define_tensor(node, mps_graph)
        return MPSNode(
            mpsnode_union=binary_op(
                input1_id=input1_id, input2_id=input2_id, output_id=output_id
            )
        )

    def create_unary_node(
        self, node: torch.fx.Node, mps_graph: MPSGraph, unary_op: MPSNodeUnion
    ) -> MPSNode:
        """Builds an MPSNode for an op with one tensor input and one output."""
        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        output_id = self.define_tensor(node, mps_graph)
        return MPSNode(mpsnode_union=unary_op(input1_id=input1_id, output_id=output_id))


# This will hold mapping of all node names to the visitor class.
_node_visitor_dict = {}


def register_node_visitor(visitor):
    """Registers a NodeVisitor subclass under its `target` (or each target in a list)."""
    assert (
        isinstance(visitor, type)
        and issubclass(visitor, NodeVisitor)
        and hasattr(visitor, "target")
    ), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
    if isinstance(visitor.target, list):
        for elem in visitor.target:
            _node_visitor_dict[elem] = visitor
    else:
        _node_visitor_dict[visitor.target] = visitor


def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
    """
    Create a new class instance at runtime, and put them in a dict

    Every registered visitor class is instantiated with *args. A plain
    NodeVisitor instance is added under both "placeholder" and "output".
    """
    node_visitors = {}
    for target, visitor in _node_visitor_dict.items():
        assert callable(
            visitor
        ), f"Expecting a callable class, but got {visitor} of type {type(visitor)}"
        node_visitors[target] = visitor(*args)

    placeholder_output_visitor = NodeVisitor(*args)
    node_visitors["placeholder"] = placeholder_output_visitor
    node_visitors["output"] = placeholder_output_visitor
    return node_visitors


def process_placeholder_nodes(
    exported_program: ExportedProgram,
    edge_graph_module: torch.fx.GraphModule,
    mps_graph: MPSGraph,
    placeholder_visitor: NodeVisitor,
) -> None:
    """Defines the runtime inputs of the graph and records their ids.

    When fp16 conversion is enabled, a cast to float16 is appended for every
    float32 input.
    """
    # Visit the placeholder nodes in the same order they are passed to the
    # forward function - forward(*args). When lifted graphs are being used,
    # parameters/buffers are lifted as placeholders and the order of the args
    # is not matching anymore with the original graph. We can retrieve the
    # original order by parsing all the placeholder nodes, and check if they are
    # constant tensors.
    #
    # Constant tensors will be bundled directly in the FlatBuffer and they won't be
    # provided by ExecuTorch during runtime.

    for node in edge_graph_module.graph.nodes:
        if node.op == "placeholder" and not is_parameter(
            exp_prog=exported_program, node=node
        ):
            if node.meta["val"] is None:
                continue

            input_id = placeholder_visitor.define_tensor(node, mps_graph)
            mps_graph.input_ids.append(input_id)

            if (
                placeholder_visitor.convert_model_to_fp16
                and node.meta["val"].dtype == torch.float32
            ):
                # The cast reads and writes the same id.
                mps_node = MPSNode(
                    mpsnode_union=MPSCast(
                        input1_id=input_id,
                        output_id=input_id,
                        dtype=MPSDataType.mps_data_type_float16,
                    )
                )
                mps_graph.mps_nodes.append(mps_node)


def process_output_node(
    output_node,
    mps_graph: MPSGraph,
    output_visitor: NodeVisitor,
) -> None:
    """Defines an output of the graph and records its id.

    When fp16 conversion is enabled, a cast to float32 is appended for every
    output whose original dtype is float32.
    """
    output_id = output_visitor.define_tensor(output_node, mps_graph)
    mps_graph.output_ids.append(output_id)

    if (
        output_visitor.convert_model_to_fp16
        and output_node.meta["val"].dtype == torch.float32
    ):
        # The cast reads and writes the same id.
        mps_node = MPSNode(
            mpsnode_union=MPSCast(
                input1_id=output_id,
                output_id=output_id,
                dtype=MPSDataType.mps_data_type_float32,
            )
        )
        mps_graph.mps_nodes.append(mps_node)