#
# Copyright (c) 2023 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#
import logging
from typing import ClassVar, Dict, final, List, Tuple

import torch

from executorch.backends.apple.mps.operators.node_visitor import (
    get_node_visitors,
    NodeVisitor,
    process_output_node,
    process_placeholder_nodes,
)

from executorch.backends.apple.mps.serialization.mps_graph_schema import (
    Buffer,
    DataSegment,
    MPSGraph,
    MPSTensor,
    OpType,
)

from executorch.backends.apple.mps.serialization.mps_graph_serialize import (
    convert_to_flatbuffer,
)
from executorch.exir._serialize._program import Cord

from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    PreprocessResult,
)
from torch.export.exported_program import ExportedProgram

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


@final
class MPSBackend(BackendDetails):
    """Ahead-of-time backend that lowers an EdgeIR ``ExportedProgram`` into the
    serialized payload (header + MPSGraph FlatBuffer + constant segment)
    consumed by the MPS delegate at runtime."""

    @staticmethod
    def slice_len_max(s):
        """Return the number of bytes addressed by slice ``s``, at least 1.

        ``s`` must have explicit ``start`` and ``stop``; a missing ``step``
        is treated as 1.
        """
        assert s.start is not None
        assert s.stop is not None
        step = s.step if s.step is not None else 1
        return max((s.stop - s.start) // step, 1)

    # Byte ranges of the fields inside the serialized payload header.
    MAGIC_IX: ClassVar[slice] = slice(4, 8)
    DATA_SEGMENT_OFFSET_IX: ClassVar[slice] = slice(8, 16)
    DATA_SEGMENT_SIZE_IX: ClassVar[slice] = slice(16, 24)

    # magic bytes that should be at the beginning of the header
    EXPECTED_MAGIC: ClassVar[bytes] = b"MP00"
    # The length of the header in bytes
    EXPECTED_LENGTH: ClassVar[int] = (
        4
        + slice_len_max(MAGIC_IX)
        + slice_len_max(DATA_SEGMENT_OFFSET_IX)
        + slice_len_max(DATA_SEGMENT_SIZE_IX)
    )

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """Serialize ``edge_program`` into the MPS delegate payload.

        The EdgeIR nodes are processed in the following order:
        1. Process first the input feeds to the graph (in the same
           order as args from forward(*args)), and generate a unique
           id for each input placeholder. Each input id is appended to
           `input_ids` array from the FlatBuffer schema.
        2. Process the nodes of the graph (e.g `call_function`). For each
           EdgeIR node, create an equivalent MPS node in the FlatBuffer,
           based on which the MPSGraph is constructed at runtime. During
           this process, any visited constant in the EdgeIR is added to the
           final MPS FlatBuffer schema. Each constant id is appended to the
           `constant_ids` FlatBuffer schema.
        3. After all the inputs, nodes and constants are added to the
           FlatBuffer graph, process the `output` nodes and add their id to
           the `output_ids` array in the schema.

        Raises:
            RuntimeError: if the graph contains a node op with no handler.
        """
        mps_graph = MPSGraph(
            version="0",
            mps_nodes=[],
            mps_values=[],
            input_ids=[],
            output_ids=[],
            constant_ids=[],
            graph_type=OpType.mps_graph,
            constant_segment=DataSegment(0, 0),
        )

        # FP16 conversion is enabled unless a "use_fp16" compile spec
        # opts out; the first byte of the spec value encodes the flag.
        convert_model_to_fp16 = True
        for spec in compile_specs:
            if spec.key == "use_fp16":
                convert_model_to_fp16 = bool(bytes(spec.value)[0])

        logging.debug(f"Convert model to FP16: {convert_model_to_fp16}")

        node_visitors = get_node_visitors(edge_program, convert_model_to_fp16)
        if logging.DEBUG >= logging.root.level:
            edge_program.graph.print_tabular()

        process_placeholder_nodes(
            edge_program,
            edge_program.graph_module,
            mps_graph,
            node_visitors["placeholder"],
        )

        # Dispatch table: one handler per supported FX node op.
        op_handler = {
            "call_function": MPSBackend.handle_call_function,
            "placeholder": MPSBackend.handle_placeholder,
            "output": MPSBackend.handle_output,
            "get_attr": MPSBackend.handle_get_attr,
        }

        for node in edge_program.graph_module.graph.nodes:
            if node.op not in op_handler:
                raise RuntimeError(f"{node.op} is not supported in MPS")
            op_handler[node.op](edge_program, node_visitors, node, mps_graph)

        segment_data, mps_graph = _extract_constant_segment(mps_graph)
        if logging.DEBUG >= logging.root.level:
            pretty_print(mps_graph)

        # Add to aggregate segments cord with padding.
        padding_length = _padding_required(len(segment_data), 16)
        if padding_length > 0:
            segment_data.append(b"\x00" * padding_length)

        # Combine mps_graph with segment data
        combined = Cord()
        graph_bytes = convert_to_flatbuffer(mps_graph)

        # The constant segment starts after the header and the
        # (16-byte aligned) serialized graph.
        data_segment_offset: int = MPSBackend.EXPECTED_LENGTH + len(graph_bytes)
        graph_padding_length = _padding_required(data_segment_offset, 16)
        data_segment_offset += graph_padding_length
        data_segment_size = len(segment_data)

        # Header layout: 4 reserved bytes, magic, little-endian segment
        # offset and size (see the *_IX ClassVars above).
        data: bytes = (
            b"\x00\x00\x00\x00"
            + MPSBackend.EXPECTED_MAGIC
            + data_segment_offset.to_bytes(8, byteorder="little")
            + data_segment_size.to_bytes(8, byteorder="little")
        )
        assert len(data) == MPSBackend.EXPECTED_LENGTH

        combined.append(data)
        combined.append(graph_bytes)

        if graph_padding_length > 0:
            combined.append(b"\x00" * graph_padding_length)
        # Append the segment data to the end of the mps graph
        combined.append(segment_data)

        return PreprocessResult(processed_bytes=bytes(combined))

    @staticmethod
    def handle_call_function(
        _: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize one ``call_function`` node via its registered visitor.

        Raises:
            RuntimeError: if no visitor is registered for the node's target.
        """
        logging.info(f"Visiting: {node}, {node.target.__name__}")

        # A node tagged as a Metal kernel by the partitioner switches the
        # whole graph into metal-kernel mode.
        if (
            "delegation_tag" in node.meta
            and "metal_kernel" in node.meta["delegation_tag"]
        ):
            logging.info(
                f"Node '{node.target.__name__}' was marked as a Metal kernel by the MPSPartitioner!"
            )
            mps_graph.graph_type = OpType.metal_kernel

        if node.target.__name__ in node_visitors:
            node_visitors[node.target.__name__].define_node(node, mps_graph)
        else:
            pretty_print(mps_graph)
            raise RuntimeError(
                f"For {node}, {node.op}:{node.target.__name__} is not supported in MPS delegate"
            )

    @staticmethod
    def handle_placeholder(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """No-op: placeholders are serialized up front by
        ``process_placeholder_nodes`` and constants while visiting nodes."""
        # Constants are handled directly when visiting the nodes.
        pass

    @staticmethod
    def handle_output(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Record the id of every graph output in ``mps_graph.output_ids``."""
        # node.args is a tuple of output-node collections; flatten it.
        for output_nodes in node.args:
            for output_node in output_nodes:
                process_output_node(output_node, mps_graph, node_visitors[node.op])

    @staticmethod
    def handle_get_attr(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """No-op: attribute accesses need no serialization of their own."""
        pass


def _padding_required(offset: int, alignment: int) -> int:
    """Returns the padding required to align `offset` to `alignment`."""
    remainder: int = offset % alignment
    return alignment - remainder if remainder != 0 else 0


def _extract_constant_segment(mps_graph: MPSGraph) -> Tuple[Cord, MPSGraph]:
    """Extracts the constant segment from the MPSGraph and returns the updated MPSGraph along with the segment data."""
    # Note that the beginning of the segment data is not aligned. Need to handle out of this call.
    segment_data = Cord()
    offset = 0
    for tensor in mps_graph.mps_values:
        if tensor.constant_buffer_size > 0:
            # Notice that buffer is already force aligned so we don't need to pad it
            segment_data.append(tensor.constant_buffer.storage)

            # Reset buffer to empty
            tensor.constant_buffer = Buffer(storage=b"")
            # Update segment offset
            tensor.segment_offset = offset
            offset += tensor.constant_buffer_size

    return segment_data, mps_graph


def tensor_to_str(mps_tensor: MPSTensor) -> str:
    """Return a compact one-line summary of an MPSTensor for debug logging."""
    return (
        "MPSTensor("
        f"datatype={mps_tensor.datatype}, "
        f"num_dims={mps_tensor.num_dims}, "
        f"dims={mps_tensor.dims}, "
        f"constant_buffer_size={mps_tensor.constant_buffer_size}, "
        f"segment_offset={mps_tensor.segment_offset}"
        ")"
    )


def pretty_print(mps_graph: MPSGraph) -> None:
    """Log a human-readable dump of a serialized MPSGraph at INFO level."""
    logging.info("Serialized MPSGraph:")
    logging.info(f" Version: {mps_graph.version}")
    logging.info(" MPS nodes: ")
    for i, mps_node in enumerate(mps_graph.mps_nodes):
        logging.info(f" [{i}]: {mps_node}")
    logging.info(" MPS values: ")
    for i, mps_value in enumerate(mps_graph.mps_values):
        logging.info(f" [{i}]: {tensor_to_str(mps_value)}")
    logging.info(" Input ids:")
    for in_id in mps_graph.input_ids:
        logging.info(f" {in_id}")
    logging.info(" Constant ids:")
    for constant_id in mps_graph.constant_ids:
        logging.info(f" {constant_id}")
    logging.info(" Output ids:")
    for out_id in mps_graph.output_ids:
        logging.info(f" {out_id}")
    logging.info(f" Constant segment: {mps_graph.constant_segment}")