xref: /aosp_15_r20/external/executorch/backends/apple/mps/mps_preprocess.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5import logging
6from typing import ClassVar, Dict, final, List, Tuple
7
8import torch
9
10from executorch.backends.apple.mps.operators.node_visitor import (
11    get_node_visitors,
12    NodeVisitor,
13    process_output_node,
14    process_placeholder_nodes,
15)
16
17from executorch.backends.apple.mps.serialization.mps_graph_schema import (
18    Buffer,
19    DataSegment,
20    MPSGraph,
21    MPSTensor,
22    OpType,
23)
24
25from executorch.backends.apple.mps.serialization.mps_graph_serialize import (
26    convert_to_flatbuffer,
27)
28from executorch.exir._serialize._program import Cord
29
30from executorch.exir.backend.backend_details import (
31    BackendDetails,
32    CompileSpec,
33    PreprocessResult,
34)
35from torch.export.exported_program import ExportedProgram
36
# Layout of every log record: "[LEVEL timestamp file:line] message".
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
# Import-time side effect: configures the root logger at INFO level with FORMAT.
# NOTE: logging.basicConfig does nothing when the root logger already has
# handlers, so an application that configured logging earlier keeps its setup.
logging.basicConfig(level=logging.INFO, format=FORMAT)
39
40
@final
class MPSBackend(BackendDetails):
    """Backend that turns an EdgeIR ``ExportedProgram`` into an MPS delegate blob.

    The bytes returned by ``preprocess`` are laid out as::

        header | serialized MPSGraph | zero padding | constant segment

    The header is ``EXPECTED_LENGTH`` bytes long: four zero bytes, the magic
    ``EXPECTED_MAGIC``, then the constant-segment offset and the
    constant-segment size, each stored as an 8-byte little-endian integer.
    """

    @staticmethod
    def slice_len_max(s):
        """Return ``(s.stop - s.start) // step``, clamped to a minimum of 1.

        ``s.start`` and ``s.stop`` must both be set (asserted); a missing
        ``s.step`` is treated as 1.
        """
        assert s.start is not None
        assert s.stop is not None
        stride = 1 if s.step is None else s.step
        span = s.stop - s.start
        return max(span // stride, 1)

    # Byte ranges of the individual header fields.
    MAGIC_IX: ClassVar[slice] = slice(4, 8)
    DATA_SEGMENT_OFFSET_IX: ClassVar[slice] = slice(8, 16)
    DATA_SEGMENT_SIZE_IX: ClassVar[slice] = slice(16, 24)

    # Magic bytes stored in the header (at MAGIC_IX).
    EXPECTED_MAGIC: ClassVar[bytes] = b"MP00"
    # Total header length in bytes: 4 leading bytes plus the three fields above.
    EXPECTED_LENGTH: ClassVar[int] = (
        4
        + slice_len_max(MAGIC_IX)
        + slice_len_max(DATA_SEGMENT_OFFSET_IX)
        + slice_len_max(DATA_SEGMENT_SIZE_IX)
    )

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """Lower ``edge_program`` to the serialized MPS payload.

        Processing order:

        1. Graph inputs are handled first (in ``forward(*args)`` order); each
           input placeholder gets a unique id that is appended to the
           ``input_ids`` array of the FlatBuffer schema.
        2. Every node of the graph is dispatched on ``node.op``. For each
           ``call_function`` node an equivalent MPS node is created, from
           which the MPSGraph is built at runtime. Constants encountered on
           the way are added to the schema and their ids appended to
           ``constant_ids``.
        3. ``output`` nodes are processed and their ids appended to
           ``output_ids``.

        Afterwards the constant buffers are pulled out of the graph into a
        separate segment, and header, graph and segment are concatenated with
        16-byte alignment padding.

        Args:
            edge_program: The exported program to lower.
            compile_specs: Backend options. A spec with key ``"use_fp16"``
                toggles FP16 conversion through the first byte of its value;
                FP16 conversion is enabled when no such spec is given.

        Returns:
            A ``PreprocessResult`` whose ``processed_bytes`` hold the payload.

        Raises:
            RuntimeError: If the graph contains a node whose ``op`` has no
                handler, or a ``call_function`` target without a node visitor.
        """
        mps_graph = MPSGraph(
            version="0",
            mps_nodes=[],
            mps_values=[],
            input_ids=[],
            output_ids=[],
            constant_ids=[],
            graph_type=OpType.mps_graph,
            constant_segment=DataSegment(0, 0),
        )

        # FP16 is the default; the last "use_fp16" compile spec wins.
        convert_model_to_fp16 = True
        for spec in compile_specs:
            if spec.key == "use_fp16":
                convert_model_to_fp16 = bool(bytes(spec.value)[0])

        logging.debug(f"Convert model to FP16: {convert_model_to_fp16}")

        node_visitors = get_node_visitors(edge_program, convert_model_to_fp16)
        if logging.root.level <= logging.DEBUG:
            edge_program.graph.print_tabular()

        process_placeholder_nodes(
            edge_program,
            edge_program.graph_module,
            mps_graph,
            node_visitors["placeholder"],
        )

        # One handler per supported torch.fx node kind.
        dispatch = {
            "call_function": MPSBackend.handle_call_function,
            "placeholder": MPSBackend.handle_placeholder,
            "output": MPSBackend.handle_output,
            "get_attr": MPSBackend.handle_get_attr,
        }

        for node in edge_program.graph_module.graph.nodes:
            handler = dispatch.get(node.op)
            if handler is None:
                raise RuntimeError(f"{node.op} is not supported in MPS")
            handler(edge_program, node_visitors, node, mps_graph)

        segment_data, mps_graph = _extract_constant_segment(mps_graph)
        if logging.root.level <= logging.DEBUG:
            pretty_print(mps_graph)

        # Zero-pad the constant segment up to a multiple of 16 bytes.
        segment_padding = _padding_required(len(segment_data), 16)
        if segment_padding:
            segment_data.append(bytes(segment_padding))

        graph_bytes = convert_to_flatbuffer(mps_graph)

        # The constant segment starts after header + graph, rounded up to the
        # next multiple of 16.
        unaligned_offset: int = MPSBackend.EXPECTED_LENGTH + len(graph_bytes)
        graph_padding_length = _padding_required(unaligned_offset, 16)
        data_segment_offset: int = unaligned_offset + graph_padding_length
        data_segment_size = len(segment_data)

        offset_field = data_segment_offset.to_bytes(8, byteorder="little")
        size_field = data_segment_size.to_bytes(8, byteorder="little")
        header: bytes = (
            b"\x00\x00\x00\x00" + MPSBackend.EXPECTED_MAGIC + offset_field + size_field
        )
        assert len(header) == MPSBackend.EXPECTED_LENGTH

        # Assemble: header, graph, alignment padding, then the segment data.
        combined = Cord()
        combined.append(header)
        combined.append(graph_bytes)
        if graph_padding_length:
            combined.append(bytes(graph_padding_length))
        combined.append(segment_data)

        return PreprocessResult(processed_bytes=bytes(combined))

    @staticmethod
    def handle_call_function(
        _: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize one ``call_function`` node through its registered visitor.

        Visitors are looked up by ``node.target.__name__``. A node whose
        ``delegation_tag`` metadata contains ``"metal_kernel"`` switches the
        whole graph's ``graph_type`` to ``OpType.metal_kernel``.

        Raises:
            RuntimeError: If no visitor is registered for the node's target
                (the graph built so far is logged first).
        """
        logging.info(f"Visiting: {node}, {node.target.__name__}")

        tagged_as_metal_kernel = (
            "delegation_tag" in node.meta
            and "metal_kernel" in node.meta["delegation_tag"]
        )
        if tagged_as_metal_kernel:
            logging.info(
                f"Node '{node.target.__name__}' was marked as a Metal kernel by the MPSPartitioner!"
            )
            mps_graph.graph_type = OpType.metal_kernel

        visitor_key = node.target.__name__
        if visitor_key not in node_visitors:
            pretty_print(mps_graph)
            raise RuntimeError(
                f"For {node}, {node.op}:{node.target.__name__} is not supported in MPS delegate"
            )

        node_visitors[visitor_key].define_node(node, mps_graph)

    @staticmethod
    def handle_placeholder(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Do nothing: constants are handled directly when visiting the nodes."""

    @staticmethod
    def handle_output(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Register every node referenced by the ``output`` node's args.

        Each entry of ``node.args`` is iterated, and every node found is
        passed to ``process_output_node`` with the visitor stored under
        ``node.op``.
        """
        for arg_group in node.args:
            for produced in arg_group:
                process_output_node(produced, mps_graph, node_visitors[node.op])

    @staticmethod
    def handle_get_attr(
        edge_program: ExportedProgram,
        node_visitors: Dict[str, NodeVisitor],
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Do nothing: ``get_attr`` nodes are accepted but need no work here."""
221
222
def _padding_required(offset: int, alignment: int) -> int:
    """Return how many bytes must follow `offset` to reach a multiple of `alignment`.

    The result is 0 when `offset` is already aligned.
    """
    # -offset mod alignment is exactly the distance up to the next multiple.
    return (-offset) % alignment
229
230
def _extract_constant_segment(mps_graph: MPSGraph) -> Tuple[Cord, MPSGraph]:
    """Move all constant buffers out of `mps_graph` into one segment.

    Every value with a positive `constant_buffer_size` has its buffer storage
    appended to the returned `Cord`; its `constant_buffer` is replaced by an
    empty `Buffer` and its `segment_offset` is set to the running offset into
    the segment. The graph is modified in place and returned with the segment.

    The start of the segment data is not aligned here; the caller has to take
    care of that.
    """
    segment_data = Cord()
    next_offset = 0
    for tensor in mps_graph.mps_values:
        if tensor.constant_buffer_size <= 0:
            # Not a constant: nothing to move.
            continue

        # The buffer is already force aligned, so no padding is added here.
        segment_data.append(tensor.constant_buffer.storage)

        # Drop the inline copy and record where the data now lives.
        tensor.constant_buffer = Buffer(storage=b"")
        tensor.segment_offset = next_offset
        next_offset += tensor.constant_buffer_size

    return segment_data, mps_graph
249
250
def tensor_to_str(mps_tensor: MPSTensor):
    """Return a one-line summary of `mps_tensor`.

    The summary lists datatype, num_dims, dims, constant_buffer_size and
    segment_offset; the constant buffer contents are left out.
    """
    shown_fields = (
        "datatype",
        "num_dims",
        "dims",
        "constant_buffer_size",
        "segment_offset",
    )
    # `!s` forces str() so values with a custom __format__ render as before.
    rendered = ", ".join(
        f"{field}={getattr(mps_tensor, field)!s}" for field in shown_fields
    )
    return f"MPSTensor({rendered})"
261
262
def pretty_print(mps_graph: MPSGraph):
    """Log a readable dump of `mps_graph` at INFO level.

    Emits the version, every node and value with its index, the input,
    constant and output ids, and finally the constant segment descriptor.
    """
    logging.info("Serialized MPSGraph:")
    logging.info(f" Version: {mps_graph.version}")

    logging.info(" MPS nodes: ")
    for index, mps_node in enumerate(mps_graph.mps_nodes):
        logging.info(f"   [{index}]: {mps_node}")

    logging.info(" MPS values: ")
    for index, mps_value in enumerate(mps_graph.mps_values):
        logging.info(f"   [{index}]: {tensor_to_str(mps_value)}")

    logging.info(" Input ids:")
    for input_id in mps_graph.input_ids:
        logging.info(f"   {input_id}")

    logging.info(" Constant ids:")
    for const_id in mps_graph.constant_ids:
        logging.info(f"   {const_id}")

    logging.info(" Output ids:")
    for output_id in mps_graph.output_ids:
        logging.info(f"   {output_id}")

    logging.info(f" Constant segment: {mps_graph.constant_segment}")
282