xref: /aosp_15_r20/external/executorch/backends/apple/mps/operators/node_visitor.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
6import ctypes
7import logging
8
9from typing import Dict, List, Tuple, Union
10
11import torch
12
13from executorch.backends.apple.mps.serialization.mps_graph_schema import (
14    Buffer,
15    MPSCast,
16    MPSDataType,
17    MPSGraph,
18    MPSNode,
19    MPSNodeUnion,
20    MPSTensor,
21)
22
23from executorch.backends.apple.mps.utils.mps_utils import (
24    edge_dtype_to_mps_dtype,
25    get_input_node,
26    get_param_tensor,
27    get_scalar_val,
28    is_parameter,
29)
30
31from executorch.backends.transforms import get_shape
32from executorch.exir.sym_util import eval_shape
33
34from torch.export.exported_program import ExportedProgram
35
36
class NodeVisitor:
    """
    Node visitor pattern for visiting nodes in an edge IR graph and
    serializing them using the mps serialization schema.

    Every value a visitor defines is appended to ``mps_graph.mps_values``; the
    value's "id" is its index in that list and is what MPS nodes reference.
    """

    # Deliberately a class attribute: every visitor instance (get_node_visitors
    # creates one per registered target) reads and writes this same dict, so a
    # node or constant defined by one visitor resolves to the same id in all
    # the others. Keys are fx nodes, or tuple keys built by define_constant /
    # define_scalar. Values are ids into a single MPSGraph's mps_values (a list
    # of ids for define_tensor_list), so entries only make sense for one graph.
    _tensor_to_id: Dict[object, Union[int, List[int]]] = {}
    _convert_model_to_fp16: bool = True

    def __init__(
        self, exported_program: ExportedProgram, convert_model_to_fp16: bool = True
    ):
        """
        Args:
            exported_program: Program used to look up lifted parameters
                (is_parameter / get_param_tensor) when serializing constants.
            convert_model_to_fp16: When True, float32 constant data is
                serialized as float16 (see get_serialized_data).
        """
        self._exported_program = exported_program
        self._convert_model_to_fp16 = convert_model_to_fp16

    @property
    def tensor_to_id(self) -> Dict[object, Union[int, List[int]]]:
        """Shared map from defined values (nodes / constant keys) to their ids."""
        return self._tensor_to_id

    @property
    def convert_model_to_fp16(self) -> bool:
        return self._convert_model_to_fp16

    @property
    def exported_program(self) -> ExportedProgram:
        return self._exported_program

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize ``node`` into ``mps_graph``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError("NodeVisitor must be extended!")

    def define_tensor(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
        mps_data_type: MPSDataType = None,
    ) -> int:
        """Defines a tensor value into the MPSGraph serialization schema.

        Args:
            node (torch.fx.Node): EdgeIR tensor to define into mps_graph.
                ``None`` is accepted and yields -1.
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer.
            mps_data_type (MPSDataType): Optional data type override; when
                ``None`` it is derived from ``node.meta["val"].dtype``.

        Returns:
            The id of the tensor in ``mps_graph.mps_values``. If ``node`` was
            already defined, the previously recorded entry is returned.
        """
        if node is None:
            return -1

        if node in self.tensor_to_id:
            return self.tensor_to_id[node]

        # Get a unique id for the node.
        tensor_id = self.get_serialized_id(node, mps_graph)
        cb_size, constant_buffer, mps_data_type = self.get_serialized_buffer(
            node, mps_graph, tensor_id, mps_data_type
        )
        dims = get_shape(node)

        logging.debug(
            f"Serializing: {node}, data type: {node.meta['val'].dtype}, dims: {dims}"
        )
        mps_tensor = MPSTensor(
            datatype=mps_data_type,
            num_dims=len(dims),
            dims=dims,
            constant_buffer_size=cb_size,
            constant_buffer=constant_buffer,
        )

        mps_graph.mps_values.append(mps_tensor)
        return tensor_id

    def define_tensor_list(
        self, node: torch.fx.Node, mps_graph: MPSGraph
    ) -> Union[int, List[int]]:
        """Define one non-constant MPSTensor per element of ``node.meta["val"]``.

        Args:
            node (torch.fx.Node): Node whose ``meta["val"]`` is a sequence of
                tensors. ``None`` is accepted and yields -1.
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer.

        Returns:
            The list of ids (one per element, in order); -1 for ``None``; or the
            previously recorded entry if ``node`` was already defined.
        """
        if node is None:
            return -1

        if node in self.tensor_to_id:
            return self.tensor_to_id[node]

        ids: List[int] = []
        # Record the list before filling it, matching the ids appended below.
        self.tensor_to_id[node] = ids
        for i, tensor in enumerate(node.meta["val"]):
            tensor_id = len(mps_graph.mps_values)
            ids.append(tensor_id)

            dims = eval_shape(tensor.shape)
            mps_data_type = edge_dtype_to_mps_dtype(tensor.dtype)
            logging.debug(
                f"Serializing: [{i}]: {node}, data type: {tensor.dtype}, dims: {dims}"
            )

            mps_tensor = MPSTensor(
                datatype=mps_data_type,
                num_dims=len(dims),
                dims=dims,
                constant_buffer_size=0,
                constant_buffer=Buffer(storage=b""),
            )
            logging.debug(f"  Serialized tensor: {mps_tensor}")
            mps_graph.mps_values.append(mps_tensor)
        return ids

    def hash_tensor(self, tensor):
        """Return ``hash`` of the tensor's flattened values.

        Not used for deduplication inside this class: equal hashes do not imply
        equal tensors (shape and dtype are ignored, and e.g. ``hash(-1) ==
        hash(-2)`` in CPython). Kept in case external code calls it.
        """
        return hash(tuple(tensor.reshape(-1).tolist()))

    @staticmethod
    def _tensor_bytes(tensor: torch.Tensor) -> bytes:
        """Copy exactly the bytes of ``tensor``'s own elements.

        Reads ``numel * element_size`` bytes starting at ``tensor.data_ptr()``,
        which accounts for any storage offset. Expects a dense CPU tensor.
        """
        nbytes = tensor.numel() * tensor.element_size()
        if nbytes == 0:
            # Empty tensors may have a NULL data pointer; never dereference it.
            return b""
        return ctypes.string_at(tensor.data_ptr(), nbytes)

    @staticmethod
    def _pack_int4(tensor: torch.Tensor) -> torch.Tensor:
        """Pack a 2-D int8 tensor of 4-bit values and run the MPS int4 packer.

        Raises:
            RuntimeError: if ``tensor`` is not 2-D or has an odd column count.
        """
        if tensor.dim() != 2:
            raise RuntimeError(f"Unexpected tensor shape {tensor.shape}")
        if tensor.shape[1] % 2 != 0:
            # Otherwise a single column silently broadcasts to an empty result.
            raise RuntimeError(
                f"int4 packing needs an even number of columns, got {tensor.shape}"
            )

        tensor = tensor.to(dtype=torch.int32)
        # Two 4-bit values per byte: even column -> high nibble, odd -> low.
        nibbles = (((tensor[::, ::2] & 0x0F) << 4) | (tensor[::, 1::2] & 0x0F)).to(
            torch.uint8
        )
        return (
            torch._convert_weight_to_int4pack(nibbles.to("mps"), 2)
            .cpu()
            .view(dtype=torch.uint8)
        )

    def define_constant(
        self,
        constant_tensor: torch.Tensor,
        mps_graph: MPSGraph,
    ) -> int:
        """Defines a constant tensor into the MPSGraph serialization schema.

        Identical constants (same dtype, shape and bytes) share one id.

        Args:
            constant_tensor (torch.Tensor): CPU tensor to embed in mps_graph.
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer.

        Returns:
            The id of the constant in ``mps_graph.mps_values``.
        """
        constant_tensor = constant_tensor.contiguous()
        # Key on exact content: a value hash ignores shape/dtype and can
        # collide for different data, which would alias distinct constants.
        key = (
            "constant",
            constant_tensor.dtype,
            tuple(constant_tensor.shape),
            self._tensor_bytes(constant_tensor),
        )
        if key in self.tensor_to_id:
            return self.tensor_to_id[key]

        constant_id = self.get_serialized_id(key, mps_graph, key)

        mps_data_type = edge_dtype_to_mps_dtype(constant_tensor.dtype)
        constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data(
            constant_tensor, mps_graph, mps_data_type, constant_id
        )
        dims = list(constant_tensor.shape)

        mps_tensor = MPSTensor(
            datatype=mps_data_type,
            num_dims=len(dims),
            dims=dims,
            constant_buffer_size=constant_buffer_size,
            constant_buffer=constant_buffer,
        )

        mps_graph.mps_values.append(mps_tensor)
        return constant_id

    def define_scalar(
        self,
        val: Union[float, int],
        mps_data_type: MPSDataType,
        mps_graph: MPSGraph,
        torch_dtype: Union[torch.dtype, None] = None,
    ) -> int:
        """Defines a scalar value into the MPSGraph serialization schema.

        The scalar is stored as a constant tensor with ``dims == [1]``.

        Args:
            val: int or float value.
            mps_data_type (MPSDataType): Declared data type of the constant.
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer.
            torch_dtype: Optional torch dtype used to materialize ``val``. Pass
                the dtype matching ``mps_data_type`` so the serialized bytes
                have the declared width; ``None`` keeps ``torch.tensor(val)``'s
                default (int64 / float32).

        Returns:
            The id of the scalar constant in ``mps_graph.mps_values``.
        """
        assert isinstance(val, int) or isinstance(val, float)

        # type(val) keeps 1 / 1.0 / True apart; float.hex() keeps 0.0 and -0.0
        # apart; the dtypes keep the same value used with different types apart.
        key = (
            "scalar",
            type(val),
            val.hex() if isinstance(val, float) else val,
            mps_data_type,
            torch_dtype,
        )
        if key in self.tensor_to_id:
            return self.tensor_to_id[key]

        scalar_id = self.get_serialized_id(key, mps_graph, key)

        if torch_dtype is None:
            tensor = torch.tensor(val)
        else:
            tensor = torch.tensor(val, dtype=torch_dtype)
        constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data(
            tensor, mps_graph, mps_data_type, scalar_id
        )

        mps_tensor = MPSTensor(
            datatype=mps_data_type,
            num_dims=1,
            dims=[1],
            constant_buffer_size=constant_buffer_size,
            constant_buffer=constant_buffer,
        )

        mps_graph.mps_values.append(mps_tensor)
        return scalar_id

    def get_serialized_buffer(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
        node_id: int,
        mps_data_type: MPSDataType = None,
    ) -> Tuple[int, Buffer, MPSDataType]:
        """
        Serialize the constant data of ``node`` if it is a lifted parameter.

        Args:
            node (torch.fx.Node): Node being defined.
            mps_graph (MPSGraph): Graph that records constant ids.
            node_id (int): Id already assigned to ``node``.
            mps_data_type (MPSDataType): Optional override; derived from the
                node when ``None``.

        Returns:
            ``(size_in_bytes, Buffer, mps_data_type)``; for a non-parameter node
            the size is 0 and the buffer is empty.
        """
        mps_data_type = (
            self.get_serialized_dtype(node) if mps_data_type is None else mps_data_type
        )

        # Check if this node is a lifted parameter
        if not is_parameter(self.exported_program, node):
            return 0, Buffer(storage=b""), mps_data_type

        tensor = get_param_tensor(self.exported_program, node)
        assert tensor is not None and isinstance(tensor, torch.Tensor)
        tensor = tensor.contiguous()

        return self.get_serialized_data(tensor, mps_graph, mps_data_type, node_id)

    def get_serialized_data(
        self,
        tensor: torch.Tensor,
        mps_graph: MPSGraph,
        mps_data_type: MPSDataType,
        id: int,
    ) -> Tuple[int, Buffer, MPSDataType]:
        """Serialize ``tensor``'s data and record ``id`` as a graph constant.

        Args:
            tensor: CPU tensor holding the constant data.
            mps_graph: Graph whose ``constant_ids`` receives ``id``.
            mps_data_type: Requested data type. float32 becomes float16 when
                ``convert_model_to_fp16`` is set; int4 with an int8 tensor
                triggers int4 packing.
            id: Id of the value being serialized.

        Returns:
            ``(size_in_bytes, Buffer, effective_mps_data_type)``.

        Raises:
            RuntimeError: if int4 packing gets a tensor that is not 2-D or has
                an odd number of columns.
        """
        if (
            self.convert_model_to_fp16
            and mps_data_type == MPSDataType.mps_data_type_float32
        ):
            tensor = tensor.half()
            mps_data_type = MPSDataType.mps_data_type_float16

        if id not in mps_graph.constant_ids:
            mps_graph.constant_ids.append(id)

        # The callers in this class make tensors contiguous before serializing;
        # do it here too so a strided tensor is not written in storage order.
        tensor = tensor.contiguous()

        if (
            mps_data_type == MPSDataType.mps_data_type_int4
            and tensor.dtype == torch.int8
        ):
            tensor = self._pack_int4(tensor)

        # Copy only this tensor's own bytes; dumping the whole untyped storage
        # would include unrelated data when the tensor is a view of a larger one.
        data = self._tensor_bytes(tensor)
        return len(data), Buffer(storage=data), mps_data_type

    def get_serialized_id(
        self, node: object, mps_graph: MPSGraph, hash=None
    ) -> int:
        """
        Map a value to a unique id. If it was already mapped, return the
        existing id.

        Args:
            node: Value to map (fx node, or a dedup key).
            mps_graph (MPSGraph): Graph whose ``mps_values`` length becomes the
                new id (the caller appends the value right after).
            hash: Optional key to record instead of ``node``.

        Returns:
            int: The id for ``node`` / ``hash``.
        """
        if hash is not None and hash in self.tensor_to_id:
            return self.tensor_to_id[hash]
        elif node in self.tensor_to_id:
            return self.tensor_to_id[node]

        new_id = len(mps_graph.mps_values)
        if hash is not None:
            self.tensor_to_id[hash] = new_id
        else:
            self.tensor_to_id[node] = new_id

        return new_id

    def torch_dtype_to_mps_dtype(self, torch_dtype: torch.dtype) -> MPSDataType:
        """Translate a torch dtype with edge_dtype_to_mps_dtype."""
        return edge_dtype_to_mps_dtype(torch_dtype)

    def get_serialized_dtype(
        self,
        node: torch.fx.Node,
    ) -> MPSDataType:
        """Return the MPS data type of ``node.meta["val"]``."""
        return self.torch_dtype_to_mps_dtype(node.meta["val"].dtype)

    def create_tertiary_node(
        self, node: torch.fx.Node, mps_graph: MPSGraph, tertiary_op: MPSNodeUnion
    ):
        """Build an MPSNode for a three-input op from the node's first 3 inputs."""
        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        input2_id = self.define_tensor(get_input_node(node, 1), mps_graph)
        input3_id = self.define_tensor(get_input_node(node, 2), mps_graph)
        output_id = self.define_tensor(node, mps_graph)
        return MPSNode(
            mpsnode_union=tertiary_op(
                input1_id=input1_id,
                input2_id=input2_id,
                input3_id=input3_id,
                output_id=output_id,
            )
        )

    def create_binary_node(
        self, node: torch.fx.Node, mps_graph: MPSGraph, binary_op: MPSNodeUnion
    ) -> MPSNode:
        """Build an MPSNode for a two-input op.

        When the second argument is not an fx node it is serialized as a scalar
        constant using the first input's dtype (both the declared MPS data type
        and the torch dtype of the stored bytes).
        """
        input1_node = get_input_node(node, 0)
        input1_id = self.define_tensor(input1_node, mps_graph)

        # Handle both tensor and scalar variants of the op.
        if isinstance(node.args[1], torch.fx.Node):
            # Second argument is a node.
            input2_id = self.define_tensor(get_input_node(node, 1), mps_graph)
        else:
            # Second argument is a scalar: serialize it as a FlatBuffer constant.
            input1_dtype = input1_node.meta["val"].dtype
            scalar_val = get_scalar_val(node, 1)
            if input1_dtype == torch.float32:
                scalar_val = float(scalar_val)
            # Passing the torch dtype makes the bytes match the declared type
            # (e.g. 2 bytes for a float16 input rather than 4 or 8).
            input2_id = self.define_scalar(
                scalar_val,
                self.get_serialized_dtype(input1_node),
                mps_graph,
                torch_dtype=input1_dtype,
            )

        output_id = self.define_tensor(node, mps_graph)
        return MPSNode(
            mpsnode_union=binary_op(
                input1_id=input1_id, input2_id=input2_id, output_id=output_id
            )
        )

    def create_unary_node(
        self, node: torch.fx.Node, mps_graph: MPSGraph, unary_op: MPSNodeUnion
    ) -> MPSNode:
        """Build an MPSNode for a single-input op."""
        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        output_id = self.define_tensor(node, mps_graph)
        return MPSNode(mpsnode_union=unary_op(input1_id=input1_id, output_id=output_id))
378
379
380# This will hold mapping of all node names to the visitor class.
381_node_visitor_dict = {}
382
383
384def register_node_visitor(visitor):
385    assert (
386        isinstance(visitor, type)
387        and issubclass(visitor, NodeVisitor)
388        and hasattr(visitor, "target")
389    ), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
390    if isinstance(visitor.target, list):
391        for elem in visitor.target:
392            _node_visitor_dict[elem] = visitor
393    else:
394        _node_visitor_dict[visitor.target] = visitor
395
396
397def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
398    node_visitors = {}
399    """
400    Create a new class instance at runtime, and put them in a dict
401    """
402    for target, visitor in _node_visitor_dict.items():
403        assert callable(
404            visitor
405        ), f"Expecting a callable class, but got {visitor} of type {type(visitor)}"
406        node_visitors[target] = visitor(*args)
407
408    placeholder_output_visitor = NodeVisitor(*args)
409    node_visitors["placeholder"] = placeholder_output_visitor
410    node_visitors["output"] = placeholder_output_visitor
411    return node_visitors
412
413
414def process_placeholder_nodes(
415    exported_program: ExportedProgram,
416    edge_graph_module: torch.fx.GraphModule,
417    mps_graph: MPSGraph,
418    placeholder_visitor: NodeVisitor,
419) -> None:
420    # Visit the placeholder nodes in the same order they are passed to the
421    # forward function - forward(*args). When lifted graphs are being used,
422    # parameters/buffers are lifted as placeholders and the order of the args
423    # is not matching anymore with the original graph. We can retrieve the
424    # original order by parsing all the placeholder nodes, and check if they are
425    # constant tensors.
426    #
427    # Constant tensors will be bundled directly in the FlatBuffer and they won't be
428    # provided by ExecuTorch during runtime.
429
430    for node in edge_graph_module.graph.nodes:
431        if node.op == "placeholder" and not is_parameter(
432            exp_prog=exported_program, node=node
433        ):
434            if node.meta["val"] is None:
435                continue
436
437            input_id = placeholder_visitor.define_tensor(node, mps_graph)
438            mps_graph.input_ids.append(input_id)
439
440            if (
441                placeholder_visitor.convert_model_to_fp16
442                and node.meta["val"].dtype == torch.float32
443            ):
444                mps_node = MPSNode(
445                    mpsnode_union=MPSCast(
446                        input1_id=input_id,
447                        output_id=input_id,
448                        dtype=MPSDataType.mps_data_type_float16,
449                    )
450                )
451                mps_graph.mps_nodes.append(mps_node)
452
453
454def process_output_node(
455    output_node,
456    mps_graph: MPSGraph,
457    output_visitor: NodeVisitor,
458) -> None:
459    output_id = output_visitor.define_tensor(output_node, mps_graph)
460    mps_graph.output_ids.append(output_id)
461
462    if (
463        output_visitor.convert_model_to_fp16
464        and output_node.meta["val"].dtype == torch.float32
465    ):
466        mps_node = MPSNode(
467            mpsnode_union=MPSCast(
468                input1_id=output_id,
469                output_id=output_id,
470                dtype=MPSDataType.mps_data_type_float32,
471            )
472        )
473        mps_graph.mps_nodes.append(mps_node)
474